PyTorch 7.保存和加载pytorch模型的方法

PyTorch 7.保存和加载pytorch模型的方法,第1张

PyTorch 7.保存加载pytorch模型的方法

PyTorch 7.保存和加载pytorch模型的方法
  • 保存和加载模型
  • torch.save&torch.load
  • 跨设备保存加载模型

保存和加载模型

python的对象都可以通过torch.save和torch.load函数进行保存和加载

x1 = {"d":"df","dd":"ddf"}
torch.save(x1,'a1.pt')
x2 = torch.load('a1.pt')

下面来谈模型的state_dict(),该函数返回模型的所有参数

class MLP(nn.Module):
	def __init__(self):
		super(MLP,self).__init__()
		self.hidden = nn.Linear(3,2)
		self.act = nn.ReLU()
		self.output = nn.Linear(2,1)
	def forward(self,x):
		a = self.act(self.hidden(x))
		return self.output(a)
net = MLP()
net.state_dict()

输出

OrderedDict([('hidden.weight',
              tensor([[-0.4195,  0.2609,  0.4325],
                      [-0.4031,  0.2078,  0.2077]])),
             ('hidden.bias', tensor([ 0.0755, -0.1408])),
             ('output.weight', tensor([[0.2473, 0.6614]])),
             ('output.bias', tensor([0.6191]))])
torch.save&torch.load
  1. 保存整个模型
    如果选择保存模型,那么可以不需要预先创建模型的实例,可以直接加载模型及其参数
torch.save(net,path)
  1. 保存模型参数
    选择保存模型参数,在加载时需要先创建模型实例
state_dict = net.state_dict()
torch.save(state_dict,path)
  1. 模型finetune
    如果模型训练中不小心中断了,或者需要用该模型去其他模型进行finetune。我们不仅要保存模型参数,还需要保存模型的训练周期及优化器参数。
    这里,我们经常会看到一个叫checkpoint的东东,它其实是一个字典
checkpoint = {
	"model_state_dict":net.state_dict(),
	"optimizer_state_dict":optimizer.state_dict(),
	"epoch":epoch
}
# 保存
torch.save(checkpoint, path_checkpoint)
# 加载
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch
跨设备保存加载模型
  1. 在CPU上加载在GPU上训练并保存的模型:
device = torch.device('cpu')
model = MyModel()
model.load_state_dict(torch.load('net_params.pth', map_location=device))
  1. 在GPU上加载在GPU上训练并保存的模型:
device = torch.device('cuda')
model = MyModel()
model.load_state_dict(torch.load('net_params.pth'))
model.to(device)

在这里使用map_location参数不起作用,要使用model.to(torch.device(“cuda”))将模型转换为CUDA优化的模型
数据也要转换到GPU
由于my_tensor.to(device)会返回一个my_tensor在GPU上的副本,它不会覆盖my_tensor

my_tensor = my_tensor.to(device)

存在多个GPU设备时
map_location指定tensor加载的GPU序号

model.load_state_dict(torch.load('net_params.pth'),map_location='cuda:0')

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5593290.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-15
下一篇 2022-12-15

发表评论

登录后才能评论

评论列表(0条)

保存