使用RGCN对异构图进行分类遇到的问题(待解决)

使用RGCN对异构图进行分类遇到的问题(待解决),第1张

最近在尝试使用RGCN对异构图进行分类的时候遇到了一些问题,目前尚未解决。
源码:

import torch
from torch.utils.data.dataloader import DataLoader
import torch.nn as nn
import dgl.nn.pytorch.conv.relgraphconv
import torch.nn.functional as F
import numpy as np
import dgl.nn.pytorch as dglnn
def create_graph():
    n_users = 1000
    n_items = 500
    n_follows = 3000
    n_clicks = 5000
    n_dislikes = 500
    n_hetero_features = 10

    follow_src = np.random.randint(0, n_users, n_follows)
    follow_dst = np.random.randint(0, n_users, n_follows)
    click_src = np.random.randint(0, n_users, n_clicks)
    click_dst = np.random.randint(0, n_items, n_clicks)
    dislike_src = np.random.randint(0, n_users, n_dislikes)
    dislike_dst = np.random.randint(0, n_items, n_dislikes)

    hetero_graph = dgl.heterograph({
        ('user', 'follow', 'user'): (follow_src, follow_dst),
        ('user', 'followed-by', 'user'): (follow_dst, follow_src),
        ('user', 'click', 'item'): (click_src, click_dst),
        ('item', 'clicked-by', 'user'): (click_dst, click_src),
        ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
        ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})

    hetero_graph.nodes['user'].data['feat'] = torch.randn(n_users, n_hetero_features)
    hetero_graph.nodes['item'].data['feat'] = torch.randn(n_items, n_hetero_features)

    label = np.random.randint(0, 5, 1)
    return hetero_graph, torch.LongTensor(label)

dataset = [create_graph() for i in range(1000)]

def collate(samples):
    # graphs, labels = map(list, zip(*samples))
    graphs, labels = zip(*samples)
    batched_graph = dgl.batch(graphs)
    batched_labels = torch.tensor(labels)
    return batched_graph, batched_labels

dataloader = DataLoader(
    dataset,
    batch_size=1024,
    collate_fn=collate,
    drop_last=False,
    shuffle=True)

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats,weight=False)
            for rel in rel_names}, aggregate='sum')       #(5,10)
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats,weight=False)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs是节点的特征
        print("测试inputs")
        print(inputs)    #{'item':tensor,'user':tensor}
        h = self.conv1(graph, inputs)  #
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

class HeteroClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
        super().__init__()

        self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        h = g.ndata['feat']
        h = self.rgcn(g, h)
        with g.local_scope():
            g.ndata['h'] = h
            hg = 0
            for ntype in g.ntypes:
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
            return self.classify(hg)


etypes = dataset[0][0].etypes   #['clicked-by', 'disliked-by', 'click', 'dislike', 'follow', 'followed-by']
model = HeteroClassifier(10, 20, 5, etypes)
opt = torch.optim.Adam(model.parameters())
for epoch in range(20):
    for batched_graph, labels in dataloader:
        logits = model(batched_graph)
        loss = F.cross_entropy(logits, labels.squeeze(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())


报错信息:

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/792330.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-05
下一篇 2022-05-05

发表评论

登录后才能评论

评论列表(0条)

保存