参考链接
- https://docs.dgl.ai/guide/nn.html#guide-nn
如果DGL没有你想要的GNN模块,可以根据自己的需求定义(感觉应该放在后面讲,容易劝退像我这种小白)。本节以GraphSAGE为例。
与pytorch类似,构造函数完成以下几个任务:
- 设置选项
- 注册可学习的参数或者子模块
- 初始化参数
演示代码如下:
import torch.nn as nn
from dgl.utils import expand_as_pair
class SAGEConv(nn.Module):
# Simplified re-implementation of DGL's SAGEConv (GraphSAGE layer).
# NOTE(review): indentation was lost in extraction; the original is a normal class body.
def __init__(self,
in_feats,
out_feats,
aggregator_type,
bias=True,
norm=None,
activation=None):
super(SAGEConv, self).__init__()
# A node can act as source and as destination, so the input feature size
# is expanded into a (src, dst) pair.
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
self.activation = activation
由于一个节点既可以是源节点,亦可以是目标节点,所以需要(“self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)”)把特征分成两部分,一部分作为源节点的时候用,一部分作为目标节点的时候用。
注册可学习的参数或者子模块
“self._aggre_type”是消息聚合函数类型,常见的有“mean”、“sum”、“max”、“min”和“lstm”等。
# Aggregator types: mean, max_pool, lstm, gcn
if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
# Extra parameters needed only by specific aggregators.
if aggregator_type == 'max_pool':
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
# 'gcn' folds the self features into the neighbor message, so it is
# deliberately excluded from having an fc_self transform.
if aggregator_type in ['mean', 'max_pool', 'lstm']:
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()
初始化参数
如果是“gcn”聚合类型,需要单独考虑。这一点在上面的代码“if aggregator_type in ['mean', 'max_pool', 'lstm']”以及后面的代码都有体现。
def reset_parameters(self):
"""Re-initialize the learnable parameters."""
gain = nn.init.calculate_gain('relu')
if self._aggre_type == 'max_pool':
nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
if self._aggre_type == 'lstm':
self.lstm.reset_parameters()
# 'gcn' has no fc_self (see __init__), so skip it for that type.
if self._aggre_type != 'gcn':
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
接下来是forward()函数:
消息产生和传递
公式为

$$h_{N(dst)}^{(l+1)} = \mathrm{aggregate}\left(\left\{ h_{src}^{(l)},\ \forall src \in N(dst) \right\}\right)$$
消息产生和消息传递过程一起写在了“graph.update_all()”
从代码上看:
- “mean”聚合类型,就是将目标节点邻居的特征求平均作为邻居的消息;
- “gcn”聚合类型,就是将目标节点邻居特征的和与自身的特征求和后除以入度,作为邻居的消息;
- “max_pool”聚合类型,就是将原始特征经过一次非线性变换后作为源节点特征,然后将目标节点邻居中最大的特征作为邻居的消息。
def forward(self, graph, feat):
with graph.local_scope():
# Expand the input feature into a (src, dst) pair according to the graph type
feat_src, feat_dst = expand_as_pair(feat, graph)
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape
if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
check_eq_shape(feat)
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
# Divide by the in-degree; the +1 accounts for the node itself
degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'max_pool':
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
else:
# NOTE(review): 'lstm' is accepted by __init__ but not handled here,
# so it falls through to this error.
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
消息聚合
公式为

$$h_{dst}^{(l+1)} = \sigma\left(W \cdot \mathrm{concat}\left(h_{dst}^{(l)},\ h_{N(dst)}^{(l+1)}\right) + b\right)$$

$$h_{dst}^{(l+1)} = \mathrm{norm}\left(h_{dst}^{(l+1)}\right)$$
“gcn”聚合类型并不会将目标节点自己的特征做一次线性变换,而是只将邻居的消息做一次线性变换作为新的消息。
# In GraphSAGE the 'gcn' aggregator needs no fc_self
if self._aggre_type == 'gcn':
rst = self.fc_neigh(h_neigh)
else:
# NOTE(review): h_self is never defined in this snippet — it should be feat_dst.
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
似乎首次出现了“graph.srcdata['h']”和“graph.dstdata['h']”访问节点数据的方式,印证了前面提到的输入特征分为两部分的思想。猜测这样也是为了与“g.ndata['h']”的方式访问同名特征区分。
输出更新后的特征
对新消息加上激活函数和归一化作为新的特征
# Activation function
if self.activation is not None:
rst = self.activation(rst)
# Normalization
if self.norm is not None:
rst = self.norm(rst)
return rst
“完整”代码
将文档提供的代码排版后如下,还存在几个问题:
- 不支持“lstm”聚合类型
- 目标节点自身的消息“h_self”并没有定义
- 文档中公式书写有些许问题,已经更正
import torch.nn as nn
from dgl.utils import expand_as_pair
class SAGEConv(nn.Module):
    """Simplified GraphSAGE convolution layer (after DGL's SAGEConv).

    Fixes relative to the snippet quoted above:
    - ``h_self`` is now defined (it is the destination nodes' own features,
      ``feat_dst``); previously ``forward`` raised ``NameError`` for the
      'mean', 'max_pool' and 'lstm' aggregators.
    - The 'lstm' aggregator, accepted by ``__init__``, is implemented
      instead of falling through to the ``KeyError`` branch.

    Parameters
    ----------
    in_feats : int or pair of ints
        Input feature size; expanded into (src, dst) sizes by ``expand_as_pair``.
    out_feats : int
        Output feature size.
    aggregator_type : str
        One of 'mean', 'max_pool', 'lstm', 'gcn'.
    bias : bool
        Whether the linear layers use a bias term.
    norm : callable, optional
        Normalization applied to the output features.
    activation : callable, optional
        Activation applied to the output features.
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()
        # A node can act both as source and as destination, so the input
        # feature size is split into a (src, dst) pair.
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
        # Aggregator types: mean, max_pool, lstm, gcn
        if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'max_pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        # 'gcn' folds the self features into the neighbor aggregation, so it
        # needs no separate self-transform.
        if aggregator_type in ['mean', 'max_pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes):
        """Aggregate the batched neighbor messages with the LSTM.

        Fix: 'lstm' was accepted in __init__ but previously unhandled in
        forward. Mirrors DGL's own SAGEConv LSTM reducer.
        """
        m = nodes.mailbox['m']  # assumed (batch, num_neighbors, in_src_feats)
        batch_size = m.shape[0]
        # Zero-initialized hidden and cell states.
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat):
        """Run one GraphSAGE layer and return updated destination features."""
        import dgl.function as fn
        import torch.nn.functional as F
        from dgl.utils import check_eq_shape
        with graph.local_scope():
            # Expand the input feature into a (src, dst) pair per graph type.
            feat_src, feat_dst = expand_as_pair(feat, graph)
            # Fix: h_self (the destination nodes' own features) was used
            # below but never defined in the original snippet.
            h_self = feat_dst
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
                # Divide by the in-degree; +1 accounts for the node itself.
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'max_pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'lstm':
                # Fix: implement the 'lstm' aggregator.
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_u('h', 'm'), self._lstm_reducer)
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
            # In GraphSAGE the 'gcn' aggregator has no fc_self.
            if self._aggre_type == 'gcn':
                rst = self.fc_neigh(h_neigh)
            else:
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
            # Activation function
            if self.activation is not None:
                rst = self.activation(rst)
            # Normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst
附
代码中有一个“expand_as_pair()”函数,用于根据图类型,将输入特征分为两部分。具体细节如下(目前只需要看同构图的情况就行,其实就是简单的复制了一份):
def expand_as_pair(input_, g=None):
# Split the input feature into a (src, dst) pair according to the graph type.
# For a homogeneous graph this simply returns the same object twice.
# NOTE(review): Mapping and F (DGL's backend module) come from dgl.utils's
# own imports; they are not defined in this snippet.
if isinstance(input_, tuple):
# Bipartite graph: the caller already supplies a (src, dst) pair
return input_
elif g is not None and g.is_block:
# Block (sampled sub-graph) case: take the dst-node prefix of the rows
if isinstance(input_, Mapping):
input_dst = {
k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
for k, v in input_.items()}
else:
input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
return input_, input_dst
else:
# Homogeneous graph: reuse the same feature for src and dst
return input_, input_