使用 draw_model()
函数时出现错误:
AttributeError: module 'torch.onnx' has no attribute 'set_training'
错误原因
PyTorch
版本过高造成的, 我安装的是 1.6
版本
解决方案
打开文件:Anaconda\data\envs\torch\lib\site-packages\tensorwatch\model_graph\hiddenlayer\summary_graph.py
找到错误行:
with torch.onnx.set_training(model_clone, False):
将函数改成:
with torch.onnx.select_model_mode_for_export(model_clone, False):
即: set_training
改成 select_model_mode_for_export
即可
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)