在使用pytorch框架训练深度学习网络的时候,我们可以很方便地使用torch.save()方法对训练过程中的网络参数等信息进行保存。比如这里,我们保存成的文件格式为pth.tar(如下图所示),咋一看以为是一个压缩包,需要先解压一下。其实不用,直接使用torch.load就可以了。下面讲一下其数据查看方法。
首先,在这些pth.tar文件的目录下打开一个终端,执行:
python
进入python命令行模式。比如我想查看net_5000_checkpoint.pth.tar文件里的内容,那么,依次执行如下命令:
import torch checkpoint = torch.load('net_5000_checkpoint.pth.tar') #这里checkpoint的类型其实是一个字典 print(checkpoint.keys()) #输出该文件里保存的内容的keys #比如我这里的输出为:dict_keys(['epoch', 'state_dict']) #表示这个文件里保存了名为epoch和state_dict这两个内容 print(checkpoint['epoch']) #查看key为epoch的内容 #我这里输出为5001
此外,上述 *** 作是可以反复进行的。就比如,我这里state_dict又是一个字典,它的里头还包含有很多不同的项,那么,可以这样:
state_dict = checkpoint['state_dict'] print(state_dict.keys()) #输出state_dict的keys #我这里输出有点多,形如:odict_keys(['feature_extraction.firstconv.0.0.weight', 'feature_extraction.firstconv.0.1.weight', 'feature_extraction.firstconv.0.1.bias', ...) print(state_dict['feature_extraction.firstconv.0.0.weight']) #查看key名为'feature_extraction.firstconv.0.0.weight'的内容
以上。其实输出的内容和结构都是你在save时候自己写入的内容,而查看keys目的就是为了看一下保存了哪些项,名称是啥。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)