Learning the DGL Framework from the Official Docs, Day 5: Custom GNN Modules (a GraphSAGE Implementation)





References

  1. https://docs.dgl.ai/guide/nn.html#guide-nn

If DGL does not provide the GNN module you need, you can define one yourself (this topic probably belongs later in the guide, since it can easily scare off beginners like me). This section uses GraphSAGE as the example.

Similar to PyTorch, the constructor performs the following tasks:

  1. Set options
  2. Register learnable parameters or submodules
  3. Initialize parameters

The demo code is as follows:

import torch.nn as nn

from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation

Since a node can act both as a source node and as a destination node, the line "self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)" splits the input feature dimensions into two parts: one used when the node acts as a source and one used when it acts as a destination.



Registering learnable parameters or submodules

"self._aggre_type" is the message-aggregation type; common choices include "mean", "sum", "max", "min", and "lstm".

        # aggregator types: mean, max_pool, lstm, gcn
        if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'max_pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'max_pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
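To summarize what each branch above registers, here is a plain-Python sketch (not DGL code; the helper name is made up for illustration):

```python
def registered_submodules(aggregator_type):
    """Which learnable submodules SAGEConv's constructor creates
    for a given aggregator type (mirrors the branches above)."""
    mods = ['fc_neigh']                      # always registered
    if aggregator_type == 'max_pool':
        mods.append('fc_pool')               # extra projection before pooling
    if aggregator_type == 'lstm':
        mods.append('lstm')                  # LSTM aggregator
    if aggregator_type in ['mean', 'max_pool', 'lstm']:
        mods.append('fc_self')               # 'gcn' uses only fc_neigh instead
    return sorted(mods)

registered_submodules('gcn')       # ['fc_neigh']
registered_submodules('max_pool')  # ['fc_neigh', 'fc_pool', 'fc_self']
```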



Initializing parameters

The "gcn" aggregator type needs separate handling. This shows up in the line "if aggregator_type in ['mean', 'max_pool', 'lstm']" above, as well as in the code that follows.

    def reset_parameters(self):
        """Reinitialize the learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
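For reference, calculate_gain('relu') is sqrt(2), and xavier_uniform_ samples from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)). A torch-free sketch of that bound (the helper name is ours, for illustration):

```python
import math

def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    # nn.init.xavier_uniform_ draws from U(-a, a) with this bound a
    return gain * math.sqrt(6.0 / (fan_in + fan_out))

relu_gain = math.sqrt(2.0)  # what nn.init.calculate_gain('relu') returns
bound = xavier_uniform_bound(64, 32, relu_gain)
```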

Next comes the forward() function:



Message generation and passing

The formula is

h_{N\left( dst\right)}^{\left( l+1\right)} = aggregate\left( \left\{ h_{src}^{\left( l\right)}, \forall src \in N\left( dst\right) \right\} \right)
Message generation and message passing are written together in "graph.update_all()".

Reading the code:

  1. The "mean" aggregator type averages the features of the destination node's neighbors to form the neighbor message;
  2. The "gcn" aggregator type sums the neighbors' features together with the node's own features and divides by the in-degree plus one (counting the node itself) to form the neighbor message;
  3. The "max_pool" aggregator type first applies a nonlinear transformation to the raw features to get the source features, then takes the element-wise maximum over the destination node's neighbors as the neighbor message.
    def forward(self, graph, feat):
        # these imports would normally sit at the top of the file
        import dgl.function as fn
        import torch.nn.functional as F
        from dgl.utils import check_eq_shape

        with graph.local_scope():
            # determine the graph type, then expand the input features accordingly
            feat_src, feat_dst = expand_as_pair(feat, graph)

            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
                # divide by (in-degree + 1), counting the node itself
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'max_pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
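As a sanity check on the "mean" branch, here is a pure-Python sketch (no DGL needed) of what fn.copy_u('h', 'm') followed by fn.mean('m', 'neigh') computes; the tiny graph and feature values are made up for illustration:

```python
def mean_aggregate(in_neighbors, feats):
    """For each destination node, average the features of its in-neighbors.
    in_neighbors: dst id -> list of src ids; feats: node id -> feature list."""
    out = {}
    for dst, srcs in in_neighbors.items():
        columns = zip(*(feats[s] for s in srcs))
        out[dst] = [sum(col) / len(srcs) for col in columns]
    return out

feats = {0: [1.0, 2.0], 1: [3.0, 4.0], 2: [5.0, 6.0]}
# node 2 has in-edges from nodes 0 and 1
mean_aggregate({2: [0, 1]}, feats)  # {2: [2.0, 3.0]}
```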



Message aggregation

The formulas are

h_{dst}^{\left( l+1\right)} = \sigma\left( W \cdot concat\left( h_{dst}^{\left( l\right)}, h_{N\left( dst\right)}^{\left( l+1\right)} \right) + b \right)

h_{dst}^{\left( l+1\right)} = norm\left( h_{dst}^{\left( l+1\right)} \right)
The "gcn" aggregator type does not apply a separate linear transformation to the destination node's own features; it applies a single linear transformation to the neighbor message to produce the new message.

        # gcn aggregation in GraphSAGE does not need fc_self
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

This seems to be the first appearance of "graph.srcdata['h']" and "graph.dstdata['h']" as ways of accessing node data, which echoes the earlier idea of splitting the input features into two parts. Presumably it also distinguishes them from accessing a feature of the same name via "g.ndata['h']".



Outputting the updated features

Apply the activation function and normalization to the new message to obtain the new features:

        # activation function
        if self.activation is not None:
            rst = self.activation(rst)
        # normalization
        if self.norm is not None:
            rst = self.norm(rst)
        return rst



The "complete" code

After reformatting the code provided by the documentation, it looks as follows. A few problems remain:

  1. The "lstm" aggregator type is not supported
  2. The destination node's own message "h_self" is never defined
  3. The formulas in the documentation contain minor errors; they are corrected above

import torch.nn as nn

from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation


        # aggregator types: mean, max_pool, lstm, gcn
        if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'max_pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'max_pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()



    def reset_parameters(self):
        """Reinitialize the learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def forward(self, graph, feat):
        # these imports would normally sit at the top of the file
        import dgl.function as fn
        import torch.nn.functional as F
        from dgl.utils import check_eq_shape

        with graph.local_scope():
            # determine the graph type, then expand the input features accordingly
            feat_src, feat_dst = expand_as_pair(feat, graph)

            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
                # divide by (in-degree + 1), counting the node itself
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'max_pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # gcn aggregation in GraphSAGE does not need fc_self
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

        # activation function
        if self.activation is not None:
            rst = self.activation(rst)
        # normalization
        if self.norm is not None:
            rst = self.norm(rst)
        return rst
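To make the "gcn" arithmetic concrete, here is a pure-Python sketch (no DGL needed) of (sum of neighbor features + the node's own features) / (in-degree + 1), with made-up numbers; the helper name is ours:

```python
def gcn_aggregate(neighbor_feats, self_feat):
    """(sum of in-neighbor features + node's own features) / (in-degree + 1),
    mirroring the h_neigh computation in the 'gcn' branch above."""
    deg = len(neighbor_feats)
    totals = [sum(col) for col in zip(self_feat, *neighbor_feats)]
    return [t / (deg + 1) for t in totals]

# two neighbors plus the node itself
gcn_aggregate([[1.0, 2.0], [3.0, 4.0]], [5.0, 6.0])  # [3.0, 4.0]
```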




The code uses an "expand_as_pair()" function, which splits the input features into two parts according to the graph type. The details are below (for now only the homogeneous-graph case matters; it simply duplicates the features):

from collections.abc import Mapping
from dgl import backend as F  # in DGL's source, F is the framework-agnostic backend, not torch.nn.functional

def expand_as_pair(input_, g=None):
    if isinstance(input_, tuple):
        # bipartite-graph case
        return input_
    elif g is not None and g.is_block:
        # block (subgraph) case
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()}
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        # homogeneous-graph case
        return input_, input_
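A stripped-down sketch covering just the two cases relevant here (the helper name is ours; the block-subgraph branch is omitted):

```python
def expand_as_pair_sketch(input_):
    if isinstance(input_, tuple):
        # bipartite case: (src_feats, dst_feats) pass through unchanged
        return input_
    # homogeneous case: the same object serves as both src and dst features
    return input_, input_

feats = [1.0, 2.0, 3.0]
src, dst = expand_as_pair_sketch(feats)
assert src is dst  # just a duplicate reference, no copy is made
```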



Copyright notice: this is an original article by beilizhang, licensed under CC 4.0 BY-SA. Please include a link to the original source and this notice when reposting.