【AM-GCN】代码解读之主程序

【AM-GCN】代码解读之主程序,第1张

[前篇]:
初了解(一)



一、导入库

import torch.nn.functional as F #常用函数
import torch.optim as optim     #优化算法
from utils import *
from models import SFGCN        #主模型
from sklearn.metrics import f1_score #计算F1分数,也称为平衡F分数或F测度
import os
import argparse
from config import Config
import torch
import numpy as np

解释说明

  1. torch.optim
  2. f1_score

二、参数读取和设置

if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    parse = argparse.ArgumentParser()
    parse.add_argument("-d", "--dataset", help="dataset", type=str, required=True)
    parse.add_argument("-l", "--labelrate", help="labeled data for train per class", type = int, required = True)
    args = parse.parse_args()
    config_file = "./config/" + str(args.labelrate) + str(args.dataset) + ".ini"
    config = Config(config_file)

    cuda = not config.no_cuda and torch.cuda.is_available() #cuda可否能用

    use_seed = not config.no_seed # cuda
    if use_seed:
        np.random.seed(config.seed)
        torch.manual_seed(config.seed)
        if cuda:
            torch.cuda.manual_seed(config.seed)

解释说明

  1. os.environ[“CUDA_VISIBLE_DEVICES”] = "2"设置使用的标号为"2"的显卡。


    os.environ[‘环境变量名称’]=‘环境变量值’ :其中key和value均为string类型

  2. argparse 参数配置的方法,可以打开cmd来设置参数的值,具体使用方法见5中链接。


    1)import argparse 首先导入模块
    ​2)parser = argparse.ArgumentParser() 创建一个解析器
    3)parser.add_argument() 向该解析器中添加你要关注的命令行参数和选项
    4)parser.parse_args() 进行解析

  3. config = Config(config_file) 是调用的config.py中的class类函数并实例化。


    其可以读取 'config_file '文件路径中的参数文件并配置参数。


  4. 'config_file '文件路径,涉及了 标签率args.labelrate 数据集args.dataset.通过原文可以知道是

    三个标签率(即每类20、40、60个标签节点)
    六个数据集(见《初了解一》)

  5. 区别:import argparse 和 import configparser,↙️具体见链接。


    1. 可以理解为都是参数配置的模块,两者并非不可互相替代。


    2. 显著区别是使用方法不同。


      通俗来讲,
      configparser更像是把参数进行分类归纳整理在一个.ini的文件中,通过读取文件的方式获得文件中的各项参数。



      argparse 更像是机器的开关,在开机运作的时候设置各项功能。


6.参数来源以及对照值

对象读取命令读取文件参数值
config.no_cudaconf.getboolean("Model_Setup", "no_cuda")20acmFalse
config.no_seedgetboolean("Model_Setup", "no_seed")20acmFalse
config.seedgetint("Model_Setup", "seed")20acm123
  1. 函数解析(具体见超链接):

np.random.seed()函数:用于生成指定随机数。



torch.manual_seed()函数:CPU生成随机数的种子,方便下次复现实验结果
torch.cuda.manual_seed()函数:固定生成随机数的种子,使得每次运行该 .py 文件时生成的随机数相同

  1. cuda–至结束。


    含义为是否使用cuda以及有可用的cuda,如果使用cuda则设置随机数种子123,否则使用cpu设置随机数种子123.


二、数据读取

    sadj, fadj = load_graph(args.labelrate, config)
    features, labels, idx_train, idx_test = load_data(config)
  • fadj是特征图矩阵 【3025,3025】–> 在acm中有7282个非零值
  • sadj是结构图矩阵 【3025,3025】–>有26256个非零值
  • features是特征向量—>(3025, 1870)
  • labels是标签向量 —>(3025,1)
  • idx_train是训练数据索引 -->(1000,)
  • idx_test是测试数据索引 -->(60,)
    以上数据结构以acm数据为例。


    且都是torch中的结构



三、模型准备

  1. 模型实例化
  2. 如果cuda可用,将数据放置到cuda上。


   model = SFGCN(nfeat = config.fdim,
              nhid1 = config.nhid1,
              nhid2 = config.nhid2,
              nclass = config.class_num,
              n = config.n,
              dropout = config.dropout)
    if cuda:
        model.cuda()
        features = features.cuda()
        sadj = sadj.cuda()
        fadj = fadj.cuda()
        labels = labels.cuda()
        idx_train = idx_train.cuda()
        idx_test = idx_test.cuda()
    optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

四、模型训练

    acc_max = 0
    f1_max = 0
    epoch_max = 0
    for epoch in range(config.epochs):
        loss, acc_test, macro_f1, emb = train(model, epoch)
        if acc_test >= acc_max:
            acc_max = acc_test
            f1_max = macro_f1
            epoch_max = epoch
    print('epoch:{}'.format(epoch_max),
          'acc_max: {:.4f}'.format(acc_max),
          'f1_max: {:.4f}'.format(f1_max))

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/568835.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存