2021SC@SDUSC
总体的代码结构如下:
经过我们的第二次小组会议、第三次小组会议讨论后,我们确定了关键代码为eval.py、generator.py、lastDataset.py、pargs.py、train.py、vectorize.py。而在第一次讨论后,我负责分析的关键代码为train.py、lastDataset.py、pargs.py,而此篇主要分析train.py中的部分代码。
整个train.py中共有4个函数,分别为update_lr、train、evaluate、main。
首先,我们先分析main(args)函数。
def main(args): try: os.stat(args.save) input("Save File Exists, OverWrite?for no") except: os.mkdir(args.save) ds = dataset(args) args = dynArgs(args,ds) m = model(args) print(args.device) m = m.to(args.device) if args.ckpt: ''' with open(args.save+"/commandLineArgs.txt") as f: clargs = f.read().strip().split("n") argdif =[x for x in sys.argv[1:] if x not in clargs] assert(len(argdif)==2); assert([x for x in argdif if x[0]=='-']==['-ckpt']) ''' cpt = torch.load(args.ckpt) m.load_state_dict(cpt) starte = int(args.ckpt.split("/")[-1].split(".")[0])+1 args.lr = float(args.ckpt.split("-")[-1]) print('ckpt restored') else: with open(args.save+"/commandLineArgs.txt",'w') as f: f.write("n".join(sys.argv[1:])) starte=0 o = torch.optim.SGD(m.parameters(),lr=args.lr, momentum=0.9) # early stopping based on Val Loss lastloss = 1000000 for e in range(starte,args.epochs): print("epoch ",e,"lr",o.param_groups[0]['lr']) train(m,o,ds,args) vloss = evaluate(m,ds,args) if args.lrwarm: update_lr(o,args,e) print("Saving model") torch.save(m.state_dict(),args.save+"/"+str(e)+".vloss-"+str(vloss)[:8]+".lr-"+str(o.param_groups[0]['lr'])) if vloss > lastloss: if args.lrdecay: print("decay lr") o.param_groups[0]['lr'] *= 0.5 lastloss = vloss
try: os.stat(args.save) input("Save File Exists, OverWrite?for no") except: os.mkdir(args.save)
根据程序的运行结果来看,运行“python train.py -save S”语句后,会将运行后的结果保存在名为“S”的文件夹中。如果说路径下有一个叫“S”的文件名,它会提示“Save File Exists, OverWrite?”,回车键即可重写进数据。若没有一个叫“S”的文件名,则会自动创建,用来保存数据。
ds = dataset(args) args = dynArgs(args,ds) m = model(args)
后面我们定义了三个变量。dataset、dynArgs、model都是定义的类。首先我们看dataset类:
class dataset: def __init__(self, args):
__init__函数类似于C++中的构造函数,self为原始图实例,args为自定义参数。
args.path = args.datadir + args.data print("Loading Data from ",args.path) self.args = args self.mkVocabs(args) print("Vocab sizes:")
这里是dataset的__init__类函数中的一部分,args.path即原始图的参数的路径为参数的数据路径+数据,在运行过程中显示“Loading Data from”+参数的路径。mkVocabs(args)是类dataset中的一个函数,用于构造文本。
下面我们来看mkVocabs(args)函数(其中一部分)。
def mkVocabs(self,args): args.path = args.datadir + args.data self.INP = data.Field(sequential=True, batch_first=True,init_token="", eos_token=" ",include_lengths=True) self.OUTP = data.Field(sequential=True, batch_first=True,init_token=" ", eos_token=" ",include_lengths=True) self.TGT = data.Field(sequential=True, batch_first=True,init_token=" ", eos_token=" ") self.NERD = data.Field(sequential=True, batch_first=True,eos_token=" ") self.ENT = data.RawField() self.REL = data.RawField() self.SORDER = data.RawField() self.SORDER.is_target = False self.REL.is_target = False self.ENT.is_target = False self.fields=[("src",self.INP),("ent",self.ENT),("nerd",self.NERD),("rel",self.REL),("out",self.OUTP),("sorder",self.SORDER)]
首先它对参数的路径进行了定义。
之后的INP、OUTP等等,都是对它各种属性进行赋值 *** 作。而Field和RawField则是定义的两个类。每个数据集由一种或多种类型的数据组成。每种类型的数据都由一个RawField对象表示。RawField对象不采用数据类型和它包含与数据类型应如何处理相关的参数。
RawField类除了__init__函数外,还内含了两个函数,作为数据的预处理函数和处理函数。下篇博客将继续讨论。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)