知识图到文本的生成——贰

知识图到文本的生成——贰,第1张

知识图到文本的生成——贰

2021SC@SDUSC

总体的代码结构如下:

经过我们的第二次小组会议、第三次小组会议讨论后,我们确定了关键代码为eval.py、generator.py、lastDataset.py、pargs.py、train.py、vectorize.py。而在第一次讨论后,我负责分析的关键代码为train.py、lastDataset.py、pargs.py,而此篇主要分析train.py中的部分代码。

整个train.py中共有4个函数,分别为update_lr、train、evaluate、main。

首先,我们先分析main(args)函数。

def main(args):
  try:
    os.stat(args.save)
    input("Save File Exists, OverWrite?  for no")
  except:
    os.mkdir(args.save)
  ds = dataset(args)
  args = dynArgs(args,ds)
  m = model(args)
  print(args.device)
  m = m.to(args.device)
  if args.ckpt:
    '''
    with open(args.save+"/commandLineArgs.txt") as f:
      clargs = f.read().strip().split("n") 
      argdif =[x for x in sys.argv[1:] if x not in clargs]
      assert(len(argdif)==2); 
      assert([x for x in argdif if x[0]=='-']==['-ckpt'])
    '''
    cpt = torch.load(args.ckpt)
    m.load_state_dict(cpt)
    starte = int(args.ckpt.split("/")[-1].split(".")[0])+1
    args.lr = float(args.ckpt.split("-")[-1])
    print('ckpt restored')
  else:
    with open(args.save+"/commandLineArgs.txt",'w') as f:
      f.write("n".join(sys.argv[1:]))
    starte=0
  o = torch.optim.SGD(m.parameters(),lr=args.lr, momentum=0.9)

  # early stopping based on Val Loss
  lastloss = 1000000
  
  for e in range(starte,args.epochs):
    print("epoch ",e,"lr",o.param_groups[0]['lr'])
    train(m,o,ds,args)
    vloss = evaluate(m,ds,args)
    if args.lrwarm:
      update_lr(o,args,e)
    print("Saving model")
    torch.save(m.state_dict(),args.save+"/"+str(e)+".vloss-"+str(vloss)[:8]+".lr-"+str(o.param_groups[0]['lr']))
    if vloss > lastloss:
      if args.lrdecay:
        print("decay lr")
        o.param_groups[0]['lr'] *= 0.5
    lastloss = vloss
  try:
    os.stat(args.save)
    input("Save File Exists, OverWrite?  for no")
  except:
    os.mkdir(args.save)

 根据程序的运行结果来看,运行“python train.py -save S”语句后,会将运行后的结果保存在名为“S”的文件夹中。如果说路径下有一个叫“S”的文件名,它会提示“Save File Exists, OverWrite?”,回车键即可重写进数据。若没有一个叫“S”的文件名,则会自动创建,用来保存数据。

  ds = dataset(args)
  args = dynArgs(args,ds)
  m = model(args)

后面我们定义了三个变量。dataset、dynArgs、model都是定义的类。首先我们看dataset类:

class dataset:

  def __init__(self, args):

__init__函数类似于C++中的构造函数,self为原始图实例,args为自定义参数

    args.path = args.datadir + args.data
    print("Loading Data from ",args.path)
    self.args = args
    self.mkVocabs(args)
    print("Vocab sizes:")

这里是dataset的__init__类函数中的一部分,args.path即原始图的参数的路径为参数的数据路径+数据,在运行过程中显示“Loading Data from”+参数的路径。mkVocabs(args)是类dataset中的一个函数,用于构造文本。

下面我们来看mkVocabs(args)函数(其中一部分)。

  def mkVocabs(self,args):
    args.path = args.datadir + args.data
    self.INP = data.Field(sequential=True, batch_first=True,init_token="", eos_token="",include_lengths=True)
    self.OUTP = data.Field(sequential=True, batch_first=True,init_token="", eos_token="",include_lengths=True)
    self.TGT = data.Field(sequential=True, batch_first=True,init_token="", eos_token="")
    self.NERD = data.Field(sequential=True, batch_first=True,eos_token="")
    self.ENT = data.RawField()
    self.REL = data.RawField()
    self.SORDER = data.RawField()
    self.SORDER.is_target = False
    self.REL.is_target = False 
    self.ENT.is_target = False 
    self.fields=[("src",self.INP),("ent",self.ENT),("nerd",self.NERD),("rel",self.REL),("out",self.OUTP),("sorder",self.SORDER)]

首先它对参数的路径进行了定义。

之后的INP、OUTP等等,都是对它各种属性进行赋值 *** 作。而Field和RawField则是定义的两个类。每个数据集由一种或多种类型的数据组成。每种类型的数据都由一个RawField对象表示。RawField对象不采用数据类型和它包含与数据类型应如何处理相关的参数。

RawField类除了__init__函数外,还内含了两个函数,作为数据的预处理函数和处理函数。下篇博客将继续讨论。

 

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/3972864.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-10-21
下一篇 2022-10-21

发表评论

登录后才能评论

评论列表(0条)

保存