官网给出的示例:
DeepCTR Documentation, Release 0.9.0
from tensorflow.python.keras.models import save_model,load_model
model = DeepFM()
save_model(model, 'DeepFM.h5')# save_model, same as before
from deepctr.layers import custom_objects
model = load_model('DeepFM.h5',custom_objects)# load_model,just add a parameter
实际执行的时候会报错:
AttributeError: 'DeepFM' object has no attribute 'outputs'
解决办法:
采用 torch的模型保存办法,下述示例是 Deep-torch 的示例文件 run_classification_criteo.py
from deepctr_torch.models import *
#创建model1
model = DeepFM(linear_feature_columns, dnn_feature_columns, task='binary')
#model 处理
#...
保存以及读取模型:
#model 保存
import torch
torch.save(model.state_dict(),"a.txt")
model2 = DeepFM(linear_feature_columns, dnn_feature_columns, task='binary')
model2.load_state_dict(torch.load("a.txt"))
pred_ans2 = model2.predict(test_model_input, batch_size=256)
参考:
https://deepctr-doc.readthedocs.io/_/downloads/en/latest/pdf/
https://www.pytorch123.com/ThirdSection/SaveModel/
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)