DeepCTR-Torch 如何保存模型

DeepCTR-Torch 如何保存模型,第1张

官网给出的示例:

DeepCTR Documentation, Release 0.9.0
from tensorflow.python.keras.models import save_model,load_model
model = DeepFM()
save_model(model, 'DeepFM.h5')# save_model, same as before
from deepctr.layers import custom_objects
model = load_model('DeepFM.h5',custom_objects)# load_model,just add a parameter

实际执行的时候会报错:

AttributeError: 'DeepFM' object has no attribute 'outputs'

解决办法:

采用 torch的模型保存办法,下述示例是 Deep-torch 的示例文件 run_classification_criteo.py

from deepctr_torch.models import *

#创建model1

model = DeepFM(linear_feature_columns, dnn_feature_columns, task='binary')

#model 处理

#...

保存以及读取模型:

#model 保存

import torch

torch.save(model.state_dict(),"a.txt")

model2 = DeepFM(linear_feature_columns, dnn_feature_columns, task='binary')

model2.load_state_dict(torch.load("a.txt"))

pred_ans2 = model2.predict(test_model_input, batch_size=256)

参考:

https://deepctr-doc.readthedocs.io/_/downloads/en/latest/pdf/

https://www.pytorch123.com/ThirdSection/SaveModel/

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/742714.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-28
下一篇 2022-04-28

发表评论

登录后才能评论

评论列表(0条)

保存