每天讲解一点PyTorch 【15】model.load

每天讲解一点PyTorch 【15】model.load,第1张

每天讲解一点PyTorch 【15】model.load

今天我们讲解:

state_dict = torch.load('checkpoint.pt')
#或者
state_dict = torch.load('checkpoint.pth') #torch.load加载**模型参数**
model.load_state_dict(state_dict) #把模型参数加载到模型中

model.cuda()
model.eval() #model.eval()关闭Batch Normalization和Dropout层
#加载模型结构和模型参数
model = torch.load(path)
output = model(x)

torch.save(model.state_dict(), ‘checkpoint.pt’) #仅保存模型参数
torch.save(model,‘checkpoint.pt’) #保存模型结构和模型参数

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5495682.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-12
下一篇 2022-12-12

发表评论

登录后才能评论

评论列表(0条)

保存