今天我们讲解:
state_dict = torch.load('checkpoint.pt') #或者 state_dict = torch.load('checkpoint.pth') #torch.load加载**模型参数** model.load_state_dict(state_dict) #把模型参数加载到模型中 model.cuda() model.eval() #model.eval()关闭Batch Normalization和Dropout层
#加载模型结构和模型参数 model = torch.load(path) output = model(x)
torch.save(model.state_dict(), ‘checkpoint.pt’) #仅保存模型参数
torch.save(model,‘checkpoint.pt’) #保存模型结构和模型参数
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)