WDL-训练模型

WDL-训练模型,第1张

WDL-训练模型
import warnings
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)

import tensorflow as tf
from FeatureGenerator import FeatureGenerator
from DataLoad import DataLoad
from WideDeep import WDL
import numpy as np
# Placeholder for the parsed command-line flags; rebound to tf.app.flags.FLAGS
# in the ``__main__`` block before main() runs.
FLAGS = None

# NOTE(review): not referenced anywhere in the visible code — confirm whether
# another module reads it before removing.
is_train = True

def main(unused_argv):
    """Assemble feature columns and input pipelines, then train the WDL model.

    The feature spec is loaded from ``fg.json``. Every name listed under
    ``features_name['sparse_indicator']`` is registered with the ``'s'`` tag
    and collected as a wide column; every name under
    ``features_name['numeric']`` is registered with the ``'n'`` tag and
    collected as a deep column. The train/eval table paths and the remaining
    runtime settings are read from the module-level ``FLAGS``.

    Args:
        unused_argv: Argument list handed over by the application runner;
            it is not used here.
    """
    print('....................start generat feature from json file...................')
    feature_gen = FeatureGenerator('fg.json')
    # Optional crossed feature, currently disabled:
    # feature_gen.cross_feature(["user_id","cate_id"], 1000000)
    print('....................end generat feature from json file...................')

    # Wide & deep column construction.
    print(feature_gen.features_name)
    wide_columns = [
        feature_gen.add_feature('s', name)
        for name in feature_gen.features_name['sparse_indicator']
    ]
    deep_columns = [
        feature_gen.add_feature('n', name)
        for name in feature_gen.features_name['numeric']
    ]

    # Input tables: exactly two comma-separated paths, "<train>,<eval>".
    train_path, eval_path = FLAGS.tables.split(',')

    loader_config = {
        "train_file": train_path,
        "eval_file": eval_path,
        "num_epochs": 100,
        'batch_size': 128,
        "_CSV_COLUMNS": feature_gen._CSV_COLUMNS,
        "_CSV_COLUMN_DEFAULTS": feature_gen._CSV_COLUMN_DEFAULTS,
        "worker_hosts": FLAGS.worker_hosts,
        "task_index": FLAGS.task_index,
        "remote_train": FLAGS.remote_train,
    }
    data_loader = DataLoad(loader_config)

    # The step number of the saved checkpoint must exceed train_max_steps,
    # otherwise the evaluator can never finish.
    max_steps = data_loader.get_cur_max_steps(FLAGS.checkpointDir)

    model_config = {
        'checkpointDir': FLAGS.checkpointDir,
        'rtp_dfs_path': FLAGS.rtp_dfs_path,
        'rtp_model_name': FLAGS.rtp_model_name,
        'learning_rate': 0.1,
        'max_steps': max_steps,
        'save_checkpoints_steps': 100,
        'save_summary_steps': 10,
        'protocol': FLAGS.protocol,
    }
    wdl_model = WDL(wide_columns, deep_columns, model_config)

    # Train the model.
    wdl_model.train(data_loader.train_input_fn, data_loader.eval_input_fn)

    # Model export, currently disabled:
    # wdl_model.save_model(feature_gen.save_model_feature())




if __name__ == "__main__":
    # Script entry point: declare the command-line flags, publish the parsed
    # flag container through the module-level FLAGS, and hand control to the
    # TensorFlow application runner, which parses argv and then calls main().

    # Comma-separated pair "<train table>,<eval table>".
    tf.app.flags.DEFINE_string("tables", "train.csv,test.csv", "tables info")
    tf.app.flags.DEFINE_string("job_name", None, "job name: worker or ps")
    tf.app.flags.DEFINE_integer("task_index", None, "Worker or server index")
    tf.app.flags.DEFINE_string("worker_hosts", "", "worker hosts")
    tf.app.flags.DEFINE_string("checkpointDir", '../model', "oss checkpointDir")
    # Help text previously read "oss checkpointDir" (copy-paste slip).
    tf.app.flags.DEFINE_string("protocol", None, "protocol")
    tf.app.flags.DEFINE_string("rtp_dfs_path", "../model_rtp", "rtp_dfs_path")
    tf.app.flags.DEFINE_string("rtp_model_name", "wdl", "rtp_model_name")
    # Help text previously read "rtp_model_name" (copy-paste slip).
    tf.app.flags.DEFINE_string("remote_train", "localhost", "remote_train")
    FLAGS = tf.app.flags.FLAGS
    tf.logging.set_verbosity(tf.logging.INFO)
    # Pass main explicitly instead of relying on the runner's lookup of
    # ``__main__.main``.
    tf.app.run(main=main)

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5658804.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存