pytorch geometric教程三 GraphSAGE源码详解+实战

pytorch geometric教程三 GraphSAGE源码详解+实战,第1张

pytorch geometric教程三 GraphSAGE源码详解+实战

pytorch geometric教程三 GraphSAGE源码详解+实战
  • pytorch geometric教程三 GraphSAGE源码详解&实战
  • 原理回顾
    • paper公式
    • 代码实现
  • SAGE代码(SAGEConv)
    • __init__
      • 邻域聚合方式
      • 参数含义
      • init主体
    • forward函数
      • 参数
      • forward主体
    • 消息传递
      • 一,edge_index为Tensor
      • 二,edge_index为SparseTensor
  • 实战
    • 定义模型
    • 模型调用

pytorch geometric教程三 GraphSAGE源码详解&实战

这一篇是建立在你已经对pytorch geometric消息传递&跟新的原理有一定了解的基础上。如果没有的话,也没关系,可以看看这篇关于pytorch geometric消息传递&更新的博文(pytorch geometric 消息传递源码详解(MESSAGE PASSING)+实例)。
GraphSAGE图算法中的SAGE是sample and aggregate的缩写。这一篇博文讲的是SAGEConv代码中实现的卷积部分,也就是SAGE中aggregate部分。sample的部分在pytorch geometric中是在torch_geometric.loader.NeighborSampler实现的(老一点的版本是在torch_geometric.data.NeighborSampler)。sample方法极大地增加了图算法的可拓展性,实现了minibatch跟新梯度,使得图的大规模计算成为可能。sample部分指路这篇博文(码字ing)。下文提到的SAGE都默认只有aggregate部分。

原理回顾

先回顾一下SAGE的原理:

paper公式

x i ′ = W ⋅ c o n c a t ( A g g r e g a t e j ∈ N ( i ) x j , x i ) mathbf{x}^{prime}_i = mathbf{W} cdot concat( mathrm{Aggregate}_{j in mathcal{N(i)}} mathbf{x}_j, mathbf{x}_i) xi′​=W⋅concat(Aggregatej∈N(i)​xj​,xi​)
其中点 j j j是点 i i i的邻居。直观理解,对邻居特征进行某种方式聚合后,和节点自身特征concat,再进行一个维度变换。其中聚合方式有mean, Pooling, LSTM。

代码实现

x i ′ = W 1 x i + W 2 ⋅ A g g r e g a t e j ∈ N ( i ) x j mathbf{x}^{prime}_i = mathbf{W}_1 mathbf{x}_i + mathbf{W}_2 cdot mathrm{Aggregate}_{j in mathcal{N(i)}} mathbf{x}_j xi′​=W1​xi​+W2​⋅Aggregatej∈N(i)​xj​
这是pytorch geometric实现的方式。注意到代码实现中,concat消失了, W 1 mathbf{W}_1 W1​, W 2 mathbf{W}_2 W2​取代了 W mathbf{W} W。其实这两个结果是一样的。可以看一下示意图:

这里直观理解SAGE,将邻居对应的特征聚合后,进行一个维度变换,再加上节点自身经过维度变换的特征,就是节点最终生成的embedding。相比GCN,节点 i i i和邻居 j j j使用了不同的 W mathbf{W} W,投射到了不同的特征空间,这一点大大加强了模型的表达能力。mean是代码中默认的邻居聚合方式,还可以改成max, add。原paper中的Pooling, LSTM聚合方式,SAGEConv是没有实现的。
再次强调一下,SAGEConv代码中的邻居就是你传入的邻居,不管是使用NeighborSampler等方式对邻居进行采样过的邻居还是未采样的所有邻居,它只管接收你传入的邻居,邻居采样不在这里实现。

SAGE代码(SAGEConv) init
class SAGEConv(MessagePassing):    
        def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, normalize: bool = False,
                 root_weight: bool = True,
                 bias: bool = True, **kwargs):  # yapf: disable
        kwargs.setdefault('aggr', 'mean')
        super(SAGEConv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
        if self.root_weight:
            self.lin_r = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()
邻域聚合方式

kwargs.setdefault('aggr', 'mean')检查关键字参数中是否定义了邻域聚合方式,也就是是否包含名为aggr的key。如果没有的话,采用默认的mean聚合方式,也就是邻居特征求和平均。在我们定义model的时候,我们可以通过参数aggr = add, mean, max来选择邻居特征聚合方式。

参数含义

另外解释下各个参数的含义:

  • in_channels: Union[int, Tuple[int, int]]:输入原始特征或者隐含层embedding的维度。如果是-1,则根据传入的x来推断特征维度。注意in_channels可以是一个整数,也可以是两个整数组成的tuple,分别对应source节点和target节点的特征维度,也是参数 W 2 mathbf{W}_2 W2​和 W 1 mathbf{W}_1 W1​的shape[0]。
  • out_channels:输出embedding的维度
  • normalize:默认是False, 如果是True的话,对output进行 ℓ 2 ell_2 ℓ2​归一化, x i ′ ∥ x i ′ ∥ 2 frac{mathbf{x}^{prime}_i} {| mathbf{x}^{prime}_i |_2} ∥xi′​∥2​xi′​​
  • improved: 默认是False, 如果是True的话,则 A ^ = A + 2 I mathbf{hat{A}} = mathbf{A} + 2mathbf{I} A^=A+2I,增强了自身的权重。
  • root_weight: 默认是True。如果是False的话,output不会加上节点自身特征转换维度后的值,就是代码实现公式中的 W 1 x i mathbf{W}_1 mathbf{x}_i W1​xi​。
  • bias:默认是True,如果是False的话,layer中没有bias项。
init主体

init函数包括了 W 1 mathbf{W}_1 W1​和 W 2 mathbf{W}_2 W2​的定义。

forward函数
	def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=x, size=size)
        out = self.lin_l(out)

        x_r = x[1]
        if self.root_weight and x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out
参数
  • x:Union[Tensor, OptPairTensor]:可以是Tensor,也可以是OptPairTensor (pyg定义的tuple of Tensor)。
    当图是bipartite的时候,x是OptPairTensor ,source节点特征对应x[0],在代码中赋值给x_l变量,target节点特征对应x[1],在代码中赋值给 x_r。这也是作者把 W 1 mathbf{W}_1 W1​定义为lin_r, W 2 mathbf{W}_2 W2​定义为lin_l的原因。
    另外pyg中实现sample是通过NeighborSampler返回bipartite子图完成的。所以SAGEConv支持邻居采样的结果,与NeighborSampler一起使用,可以实现minibatch训练大规模图数据。
  • edge_index: Adj: Adj是pyg定义的邻接矩阵类型,可以是Tensor,也可以是SparseTensor。
forward主体

代码的逻辑是比较清晰的:

  • 如果x是Tensor的话,将x: OptPairTensor = (x, x)
    为了兼容同构图与bipartite图,将x写成bipartite图的形式。而在同构图中,source节点和target节点对应的x是一样的,所以有(x, x)。
  • 调用propagate函数对进行消息传递和聚合,输出的out对应公式中的 A g g r e g a t e j ∈ N ( i ) x j mathrm{Aggregate}_{j in mathcal{N(i)}} mathbf{x}_j Aggregatej∈N(i)​xj​。对progagate函数底层逻辑有疑问的小伙伴,可以参考之前的博文。
  • self.lin_l(out)对应 W 2   o u t mathbf{W}_2 , mathbf{out} W2​out
  • self.lin_r(x_r)对应 W 1   X mathbf{W}_1 , mathbf{X} W1​X
消息传递
    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor) -> Tensor:
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)
一,edge_index为Tensor

这里不明白的小伙伴可以先看这篇博文(pytorch geometric 消息传递原理详解(MESSAGE PASSING)+实例)
edge_index为Tensor的时候,propagate调用message和aggregate实现消息传递和更新。
这里message函数对邻居特征没有任何处理,只是进行了传递,所以最终propagate函数只是对邻居特征进行了aggregate。

二,edge_index为SparseTensor

edge_index为SparseTensor的时候,propagate函数会在message_and_aggregate被定义的情况下被优先调用,代替message和aggregate。
这里message_and_aggregate直接调用类似矩阵计算matmul(adj_t, x[0], reduce=self.aggr)。x[0]是source节点的特征。matmul来自于torch_sparse,除了类似常规的矩阵相乘外,还给出了可选的reduce,所以除了add,mean和max也是可以在这里实现的。

实战 定义模型

pytorch geometric的卷积层调用还是挺简单的,下面是一个两层的SAGE。

import torch
import torch.nn.functional as F
from torch_geometric.nn.conv import SAGEConv

class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.):
        super(SAGE, self).__init__()
        
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
        
        self.dropout = dropout
        
    def reset_parameters():
        for conv in self.convs:
            conv.reset_parameters()
            
    def forward(self, x, edge_index):
        x = self.convs[0](x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[1](x, edge_index)
        
        return x.log_softmax(dim=-1)
模型调用

接下来,我们用Cora数据集尝试一下。

#读取数据
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

transform = T.ToSparseTensor()
# 这里加上了ToSparseTensor(),所以边信息是以adj_t形式存储的,如果没有这个变换,则是edge_index
dataset = Planetoid(name='Cora', root=r'./dataset/Cora', transform=transform)
data = dataset[0]
data.adj_t = data.adj_t.to_symmetric()

model = SAGE(in_channels=dataset.num_features, hidden_channels=128, out_channels=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train():
    model.train()
    
    optimizer.zero_grad()
    out = model(data.x, data.adj_t)[data.train_mask] #前面我们提到了,SAGE是实现了edge_index和adj_t两种形式的
    loss = F.nll_loss(out, data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    return loss.item()

@torch.no_grad()
def test():
    model.eval()
    
    out = model(data.x, data.adj_t)
    y_pred = out.argmax(axis=-1)
    
    correct = y_pred == data.y
    train_acc = correct[data.train_mask].sum().float()/data.train_mask.sum()
    valid_acc = correct[data.val_mask].sum().float()/data.val_mask.sum()
    test_acc = correct[data.test_mask].sum().float()/data.test_mask.sum()
    
    return train_acc, valid_acc, test_acc 

#跑10个epoch看一下模型效果
for epoch in range(20):
    loss = train()
    train_acc, valid_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train_acc: {100 * train_acc:.3f}%, '
                              f'Valid_acc: {100 * valid_acc:.3f}% '
                              f'Test_acc: {100 * test_acc:.3f}%')
Epoch: 00, Loss: 1.9485, Train_acc: 67.143%, Valid_acc: 39.400% Test_acc: 37.200%
Epoch: 01, Loss: 1.8761, Train_acc: 90.714%, Valid_acc: 51.400% Test_acc: 51.700%
Epoch: 02, Loss: 1.8071, Train_acc: 99.286%, Valid_acc: 59.600% Test_acc: 60.800%
Epoch: 03, Loss: 1.7365, Train_acc: 99.286%, Valid_acc: 66.200% Test_acc: 67.100%
Epoch: 04, Loss: 1.6615, Train_acc: 100.000%, Valid_acc: 70.000% Test_acc: 70.900%
Epoch: 05, Loss: 1.5816, Train_acc: 100.000%, Valid_acc: 73.800% Test_acc: 73.700%
Epoch: 06, Loss: 1.4967, Train_acc: 100.000%, Valid_acc: 76.800% Test_acc: 74.900%
Epoch: 07, Loss: 1.4075, Train_acc: 100.000%, Valid_acc: 77.600% Test_acc: 76.400%
Epoch: 08, Loss: 1.3151, Train_acc: 100.000%, Valid_acc: 77.400% Test_acc: 76.700%
Epoch: 09, Loss: 1.2208, Train_acc: 100.000%, Valid_acc: 77.800% Test_acc: 77.100%
Epoch: 10, Loss: 1.1255, Train_acc: 100.000%, Valid_acc: 77.800% Test_acc: 77.400%
Epoch: 11, Loss: 1.0306, Train_acc: 100.000%, Valid_acc: 78.000% Test_acc: 78.000%
Epoch: 12, Loss: 0.9371, Train_acc: 100.000%, Valid_acc: 78.000% Test_acc: 77.700%
Epoch: 13, Loss: 0.8464, Train_acc: 100.000%, Valid_acc: 78.200% Test_acc: 77.500%
Epoch: 14, Loss: 0.7591, Train_acc: 100.000%, Valid_acc: 77.800% Test_acc: 77.600%
Epoch: 15, Loss: 0.6764, Train_acc: 100.000%, Valid_acc: 77.800% Test_acc: 77.600%
Epoch: 16, Loss: 0.5990, Train_acc: 100.000%, Valid_acc: 77.800% Test_acc: 77.400%
Epoch: 17, Loss: 0.5273, Train_acc: 100.000%, Valid_acc: 77.400% Test_acc: 77.400%
Epoch: 18, Loss: 0.4617, Train_acc: 100.000%, Valid_acc: 77.400% Test_acc: 77.000%
Epoch: 19, Loss: 0.4024, Train_acc: 100.000%, Valid_acc: 77.200% Test_acc: 77.100%

这样我们一个GraphSAGE模型就初步完成啦!我们看到,在经过10个epoch后,test集的acc最高达到了78.0%。对比上篇的GCN代码,可以看到改动只是把卷积从GCNConv换到了SAGEConv,所以pytorch geometric用起来还是很方便的。
欢迎评论交流,转载请注明出处哦!

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5571729.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-14
下一篇 2022-12-14

发表评论

登录后才能评论

评论列表(0条)

保存