使用图神经网络预测药物-药物相互作用

使用图神经网络预测药物-药物相互作用,第1张

使用图神经网络预测药物-药物相互作用

了解药物如何相互作用是医学研究和实践的关键问题。图形机器学习(GraphML)领域可用于以高可信度回答有关药物 - 药物相互作用的问题。本次学习通过GraphSAGE图卷积来预测DDI(药物-药物相互作用)。

使用图神经网络预测药物-药物相互作用
  • 使用图神经网络预测药物-药物相互作用
    • 1.数据
    • 2.定义词汇意义
    • 3.学习目标
    • 4.加载数据集
    • 5.消息传递概述
    • 6.数据集拆分
    • 7.模型构建
      • GraphSAGE构建
      • 链接预测器
    • 8.训练

1.数据

数据来源于Open Graph Benchmark (OGB) 是用于图形机器学习任务的开源基准数据集的集合。我们在本文中的重点是ogbl-ddi数据集,如上所述,它由单个药物 - 药物相互作用(DDI)网络组成。

我们在数学上将 DDI 图定义为 G = (V, E),其中 V 是节点集,E 是边的集合。图中的每个节点 v ∈ V 代表 FDA 批准或实验药物。两个节点u和v之间存在边缘(u,v)表明两种药物相互作用,使得同时服用两种药物的效果与药物彼此独立作用的预期效果有很大不同[1]。例如,靶向相同蛋白质的两种药物可能具有显着的相互作用。

2.定义词汇意义

首先,一些词汇,‘正边’是数据集中存在的边。每个正边代表由边缘端点表示的两种药物之间的已知显着相互作用。
如果两种药物不相互作用(即,它们一起服用与单独服用时具有相同的效果),则图表中不会存在边缘。这些“单独的节点”被称为负边;换句话说,就是图中不存在的边。

3.学习目标

本次学习的目标是开发一个图形机器学习模型来解决链接预测任务:给定两个药物作为输入,我们希望预测两种药物是否相互作用,即图中的这两个节点之间是否应该存在边缘。这应该允许我们通过将缺失的边缘理解为正边或负边来完成数据集。

4.加载数据集

按照 OGB 网站的示例,我们可以将 DDI 数据集加载到 PyTorch Geometric (PyG) 中:

from ogb.linkproppred import PygLinkPropPredDataset

dataset_name = 'ogbl-ddi'
dataset = PygLinkPropPredDataset(name=dataset_name)
print(f'The {dataset_name} dataset has {len(dataset)} graph(s).')
ddi_G = dataset[0]
print(f'DDI 图: {ddi_G}')
print(f'节点数量 |V|: {ddi_G.num_nodes}')
print(f'边的数量 |E|: {ddi_G.num_edges}')
print(f'无向图?: {ddi_G.is_undirected()}')
print(f'节点平均度: {ddi_G.num_edges / ddi_G.num_nodes:.2f}')
print(f'节点特征: {ddi_G.num_node_features}')
print(f'有孤立点?: {ddi_G.has_isolated_nodes()}')
print(f'有自循环?: {ddi_G.has_self_loops()}')
输出
Using backend: pytorch
Downloading http://snap.stanford.edu/ogb/data/linkproppred/ddi.zip
Downloaded 0.04 GB: 100%|██████████████████████████████████████████████████████████████| 46/46 [00:31<00:00,  1.45it/s]
Extracting dataset\ddi.zip
Processing...
Loading necessary files...
This might take a while.
Processing graphs...
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 58.81it/s]
Converting graphs into PyG objects...
100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s]
Saving...
The ogbl-ddi dataset has 1 graph(s).
Done!
DDI 图: Data(num_nodes=4267, edge_index=[2, 2135822])
节点数量 |V|: 4267
边的数量 |E|: 2135822
无向图?: True
节点平均度: 500.54
节点特征: 0
有孤立点?: False
有自循环?: False

注意:DDI数据集中,图没有节点特征,后续需要进行加工处理。

5.消息传递概述

每个 GNN“层”由在每个节点与其邻居之间传递的一轮此消息、每个节点接收的邻居消息的聚合以及使用聚合消息计算更新的嵌入来定义的。
消息传递是节点能够合并来自其局部邻域结构的信息以确定其自身嵌入的机制。它由生成消息的每个节点组成,该消息在回合中沿该节点的传出边缘传递给其他节点。
消息传递的PyG实现,以及其他知识可以看我这一篇PyG消息传递

6.数据集拆分

OGB 为我们提供了数据集拆分。正边是图中存在的边:集合 {(u,v)∈E} ,其中 u,v∈V ,负边是图中不存在的边:集合 {(u,v)∉E} 。 我们需要正边和负边来训练我们的链接预测模型。 数据集的拆分详情可以查看ogb官网,

split_edges = dataset.get_edge_split()
train_edges, valid_edges, test_edges = split_edges['train'], split_edges['valid'], split_edges['test']
print(train_edges)
print(f'训练集正边的数量: {train_edges["edge"].shape[0]}')
print(f'验证集正边的数量: {valid_edges["edge"].shape[0]}')
print(f'验证集负边的数量: {valid_edges["edge_neg"].shape[0]}')
print(f'测试集正边的数量: {test_edges["edge"].shape[0]}')
print(f'测试集负边的数量: {valid_edges["edge_neg"].shape[0]}')
输出
{'edge': tensor([[4039, 2424],
         [4039,  225],
         [4039, 3901],
         ...,
         [ 647,  708],
         [ 708,  338],
         [ 835, 3554]])}
训练集正边的数量: 1067911
验证集正边的数量: 133489
验证集负边的数量: 101882
测试集正边的数量: 133489
测试集负边的数量: 101882

ddi_graph.edge_index 的形状为 [2, 2 * E] 因为我们在 GNN 中使用 edge_index,这需要从 u 向 v 和 v 向 u 发送信息(因为我们将在下面看到)。药物1对药物2有作用,相反药物2对药物1也有作用,作用是相互的。

7.模型构建

模型有两个部分: 1)图神经网络生成节点嵌入 2) 输出链接预测概率的深度神经网络

GraphSAGE构建

import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import negative_sampling
from tqdm import trange

class GraphSAGE(torch.nn.Module):
    """使用 GraphSAGE 架构构建的图神经网络。"""

    def __init__(self, conv, in_channels, hidden_channels, out_channels, num_layers, dropout):
        '''
        in_channels:初始节点嵌入的维度。由于药物没有节点特征,我们将随机初始化这些向量。
        hidden_channels:中间节点嵌入的维度。隐藏层的维度。
        out_channels:输出节点嵌入的维度。
        num_layers:我们的 GNN 中的层数K。这是应用 GraphSAGE 运算符的次数。
        dropout:Dropout 应用于权重矩阵 W1 和 W2。
        '''
        super(GraphSAGE, self).__init__()
        
        self.convs = torch.nn.ModuleList()
        assert (num_layers >= 2), 'Have at least 2 layers'##至少两层卷积
        # 在每一个layer中增加conv,上一层与下一层的维度必须一致
        # 我们还应用了归一化,之后输出节点嵌入。每个卷积层都是 L2 归一化的。
        self.convs.append(conv(in_channels, hidden_channels, normalize=True))
        for l in range(num_layers - 2):
            self.convs.append(conv(hidden_channels, hidden_channels, normalize=True))
        self.convs.append(conv(hidden_channels, out_channels, normalize=True))

        self.num_layers = num_layers
        self.dropout = dropout

    def forward(self, x, edge_index, edge_attr):
        if edge_attr is not None: ## 如果有edge_attr
            return self.forward_with_edge_attr(x, edge_index, edge_attr)

        # x 是初始节点嵌入的矩阵,形状 [N, in_channels]
        for i in range(self.num_layers - 1):
            # 第 i 层进行消息传递和聚合
            x = self.convs[i](x, edge_index)
            # x 的形状为 [N, hidden_channels]
            # 通过非线性激活函数relu
            x = F.relu(x)

            x = F.dropout(x, p=self.dropout, training=self.training)

        # 生成最终嵌入, x 的形状为 [N, out_channels]
        x = self.convs[self.num_layers - 1](x, edge_index)
        return x
  
    def forward_with_edge_attr(self, x, edge_index, edge_attr):
        # x 是初始节点嵌入的矩阵,形状 [N, in_channels]
        for i in range(self.num_layers - 1):
            # 第 i 层进行消息传递和聚合
            x = self.convs[i](x, edge_index, edge_attr)
            # x 的形状为 [N, hidden_channels]
            # 通过非线性激活函数relu
            x = F.relu(x)

            x = F.dropout(x, p=self.dropout,training=self.training)

        # 生成最终嵌入, x 的形状为 [N, out_channels]
        x = self.convs[self.num_layers - 1](x, edge_index, edge_attr)
        return x
#设置参数
graphsage_in_channels = 256 
graphsage_hidden_channels = 256 
graphsage_out_channels = 256 
graphsage_num_layers = 2 
dropout = 0.5 

###注意,因为数据库ddi本身没有附带节点特征矩阵,所以我们要创立初始嵌入。torch.nn.Embedding
initial_node_embeddings = torch.nn.Embedding(ddi_graph.num_nodes, graphsage_in_channels).to(device)

##图节点特征向量形状为[N,in_channels]
initial_node_embeddings
输出
Embedding(4267, 256)
## 实例化模型GraphSAGE
graphsage_model = GraphSAGE(SAGEConv, graphsage_in_channels, 
                                 graphsage_hidden_channels,
                                 graphsage_out_channels,
                                 graphsage_num_layers, 
                                 dropout).to(device)
链接预测器
link_predictor_in_channels = graphsage_out_channels
link_predictor_hidden_channels = link_predictor_in_channels


class LinkPredictor(torch.nn.Module):
    """将两个输入转换为单个输出的通用网络。"""

    def __init__(self, in_channels, hidden_channels, dropout, out_channels=1,
                concat=lambda x, y: x * y):
        super(LinkPredictor, self).__init__()

        self.model = nn.Sequential(nn.Linear(in_channels, hidden_channels), nn.ReLU(), 
                                  nn.Dropout(p=dropout), nn.Linear(hidden_channels, out_channels), nn.Sigmoid())
        
        self.concat = concat
    
    def forward(self, u, v):
        x = self.concat(u, v)
        return self.model(x)

link_predictor = LinkPredictor(in_channels=link_predictor_in_channels, 
                               hidden_channels=link_predictor_hidden_channels, 
                               dropout=dropout).to(device)
8.训练
##训练我们的完整模型(GraphSAGE + LinkPredictor)
def train(graphsage_model, link_predictor, initial_node_embeddings, edge_index, 
          pos_train_edges, optimizer, batch_size, edge_attr=None):
    
    total_loss, total_examples = 0, 0

    # 设置我们的模型进行训练
    graphsage_model.train()
    link_predictor.train()

    # 迭代成批的训练边(“正边”)
    # (最后一次迭代的边数可能比 batch_size 少)
    for pos_samples in DataLoader(pos_train_edges, batch_size, shuffle=True):
        optimizer.zero_grad()

        # 运行 GraphSAGE 前向传递
        node_embeddings = graphsage_model(initial_node_embeddings, edge_index, edge_attr)
        
        #对由 attr:'edge_index'给出的图的随机对负边进行采样。
        # neg_samples 是一个尺寸为 [2, batch_size] 的张量
        neg_samples = negative_sampling(edge_index, 
                                        num_nodes=initial_node_embeddings.size(0),
                                        num_neg_samples=len(pos_samples),
                                        method='dense')
        
        # 在正边嵌入上运行链接预测器前向传递
        pos_preds = link_predictor(node_embeddings[pos_samples[:, 0]], 
                                    node_embeddings[pos_samples[:, 1]])
        
        # 在负边嵌入上运行链接预测器前向传递
        neg_preds = link_predictor(node_embeddings[neg_samples[0]], 
                                    node_embeddings[neg_samples[1]])

        preds = torch.concat((pos_preds, neg_preds))
        labels = torch.concat((torch.ones_like(pos_preds), 
                                torch.zeros_like(neg_preds)))

        loss = F.binary_cross_entropy(preds, labels)

        loss.backward()
        optimizer.step()

        num_examples = len(pos_preds)
        total_loss += loss.item() * num_examples
        total_examples += num_examples
    
    return total_loss / total_examples
##参数
lr = 0.005 
batch_size = 65536 
epochs = 2  
eval_steps = 5 
optimizer = torch.optim.Adam(list(graphsage_model.parameters()) + list(link_predictor.parameters()),lr=lr)

我们根据 OGB 数据集中提供的验证和测试正负边缘来评估我们的模型:

pos_valid_edges = valid_edges['edge'].to(device)
neg_valid_edges = valid_edges['edge_neg'].to(device)
pos_test_edges = test_edges['edge'].to(device)
neg_test_edges = test_edges['edge_neg'].to(device)

from ogb.linkproppred import Evaluator

evaluator = Evaluator(name = dataset_name)
@torch.no_grad()
def test(graphsage_model, link_predictor, initial_node_embeddings, edge_index, pos_valid_edges, neg_valid_edges, pos_test_edges, neg_test_edges, batch_size, evaluator, edge_attr=None):
    graphsage_model.eval()
    link_predictor.eval()

    final_node_embeddings = graphsage_model(initial_node_embeddings, edge_index, edge_attr)

    pos_valid_preds = []
    for pos_samples in DataLoader(pos_valid_edges, batch_size):
        pos_preds = link_predictor(final_node_embeddings[pos_samples[:, 0]], 
                                    final_node_embeddings[pos_samples[:, 1]])
        pos_valid_preds.append(pos_preds.squeeze())
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)
    

    neg_valid_preds = []
    for neg_samples in DataLoader(neg_valid_edges, batch_size):
        neg_preds = link_predictor(final_node_embeddings[neg_samples[:, 0]], 
                                    final_node_embeddings[neg_samples[:, 1]])
        neg_valid_preds.append(neg_preds.squeeze())
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for pos_samples in DataLoader(pos_test_edges, batch_size):
        pos_preds = link_predictor(final_node_embeddings[pos_samples[:, 0]], 
                                    final_node_embeddings[pos_samples[:, 1]])
        pos_test_preds.append(pos_preds.squeeze())
    pos_test_pred = torch.cat(pos_test_preds, dim=0)
    
    neg_test_preds = []
    for neg_samples in DataLoader(neg_test_edges, batch_size):
        neg_preds = link_predictor(final_node_embeddings[neg_samples[:, 0]], 
                                    final_node_embeddings[neg_samples[:, 1]])
        neg_test_preds.append(neg_preds.squeeze())
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    # Calculate Hits@20
    evaluator.K = 20
    valid_hits = evaluator.eval({'y_pred_pos': pos_valid_pred, 'y_pred_neg': neg_valid_pred})
    test_hits = evaluator.eval({'y_pred_pos': pos_test_pred, 'y_pred_neg': neg_test_pred})

    return valid_hits, test_hits
import matplotlib.pyplot as plt

epochs_bar = trange(1, epochs + 1, desc='Loss n/a')

edge_index = ddi_graph.edge_index.to(device)
pos_train_edges = train_edges['edge'].to(device)

losses = []
valid_hits_list = []
test_hits_list = []
for epoch in epochs_bar:
    loss = train(graphsage_model, link_predictor, initial_node_embeddings.weight, edge_index, pos_train_edges, optimizer, batch_size)
    losses.append(loss)

    epochs_bar.set_description(f'Loss {loss:0.4f}')

    if epoch % eval_steps == 0:
        valid_hits, test_hits = test(graphsage_model, link_predictor, initial_node_embeddings.weight, edge_index, pos_valid_edges, neg_valid_edges, pos_test_edges, neg_test_edges, batch_size, evaluator)
        print()
        print(f'Epoch: {epoch}, Validation Hits@20: {valid_hits["hits@20"]:0.4f}, Test Hits@20: {test_hits["hits@20"]:0.4f}')
        valid_hits_list.append(valid_hits['hits@20'])
        test_hits_list.append(test_hits['hits@20'])
    else:
        valid_hits_list.append(valid_hits_list[-1] if valid_hits_list else 0)
        test_hits_list.append(test_hits_list[-1] if test_hits_list else 0)

plt.title(dataset.name + ": GraphSAGE")
plt.xlabel("Epoch")
plt.plot(losses, label="Training loss")
plt.plot(valid_hits_list, label="Validation Hits@20")
plt.plot(test_hits_list, label="Test Hits@20")
plt.legend()
plt.show()
输出
Loss 0.4814: 100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:27<00:00, 13.65s/it]


我自己电脑太拉了就跑了2个epoch,看看能不能跑通代码试一下。
白嫖gpu跑了一下50epoch

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/langs/717666.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-25
下一篇 2022-04-25

发表评论

登录后才能评论

评论列表(0条)

保存