我在他们的github仓库中找到了此页面,我将内容粘贴在这里。
推荐的模型保存方法
序列化和还原模型有两种主要方法。
第一个(推荐)仅保存和加载模型参数:
torch.save(the_model.state_dict(), PATH)
然后再:
the_model = TheModelClass(*args, **kwargs)the_model.load_state_dict(torch.load(PATH))
第二个保存并加载整个模型:
torch.save(the_model, PATH)
然后再:
the_model = torch.load(PATH)
但是,在这种情况下,序列化的数据将绑定到所使用的特定类和确切的目录结构,因此在其他项目中使用时或经过一些严重的重构后,它可能以各种方式中断。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)