"""Train a Wide & Deep (WDL) model from CSV tables.

Feature definitions are read from ``fg.json`` through FeatureGenerator, the
train/eval input functions come from DataLoad, and training (with evaluation)
is delegated to WDL. Command-line flags are defined in the ``__main__`` guard
at the bottom of this file.
"""
import warnings

# Ignore FutureWarnings whose message starts with "Passing". The filter is
# installed before the tensorflow import so that it is active while that
# import runs (presumably where the warning originates -- confirm).
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)

import tensorflow as tf
from FeatureGenerator import FeatureGenerator
from DataLoad import DataLoad
from WideDeep import WDL
import numpy as np  # NOTE(review): not used in this script.

# Replaced with tf.app.flags.FLAGS in the __main__ guard before main() runs.
FLAGS = None
# NOTE(review): not read anywhere in this script.
is_train = True


def main(unused_argv):
    """Build feature columns, load the input tables and train the WDL model.

    Args:
        unused_argv: argument list forwarded by ``tf.app.run``; ignored.

    Raises:
        ValueError: if ``--tables`` is not exactly ``train_file,eval_file``.
    """
    print('....................start generate feature from json file...................')
    fg = FeatureGenerator('fg.json')
    # fg.cross_feature(["user_id", "cate_id"], 1000000)
    print('....................end generate feature from json file...................')

    print(fg.features_name)
    # Wide part: one column per 'sparse_indicator' feature.
    # Deep part: one column per 'numeric' feature.
    # The 's' / 'n' type codes presumably mean sparse / numeric -- confirm
    # against FeatureGenerator.add_feature.
    wide_columns = [fg.add_feature('s', k)
                    for k in fg.features_name['sparse_indicator']]
    deep_columns = [fg.add_feature('n', k)
                    for k in fg.features_name['numeric']]

    # --tables must name exactly two files: "train_file,eval_file".
    tables = FLAGS.tables.split(',')
    if len(tables) != 2:
        raise ValueError(
            "--tables must be 'train_file,eval_file', got %r" % (FLAGS.tables,))
    train_file, eval_file = tables

    dl = DataLoad({
        "train_file": train_file,
        "eval_file": eval_file,
        "num_epochs": 100,
        'batch_size': 128,
        "_CSV_COLUMNS": fg._CSV_COLUMNS,
        "_CSV_COLUMN_DEFAULTS": fg._CSV_COLUMN_DEFAULTS,
        "worker_hosts": FLAGS.worker_hosts,
        "task_index": FLAGS.task_index,
        "remote_train": FLAGS.remote_train,
    })

    # The step number of the saved checkpoint must be greater than
    # train_max_steps, otherwise the evaluator never finishes.
    # get_cur_max_steps presumably derives the limit from the checkpoints
    # already present in checkpointDir -- confirm in DataLoad.
    max_steps = dl.get_cur_max_steps(FLAGS.checkpointDir)

    model = WDL(wide_columns, deep_columns, {
        'checkpointDir': FLAGS.checkpointDir,
        'rtp_dfs_path': FLAGS.rtp_dfs_path,
        'rtp_model_name': FLAGS.rtp_model_name,
        'learning_rate': 0.1,
        'max_steps': max_steps,
        'save_checkpoints_steps': 100,
        'save_summary_steps': 10,
        'protocol': FLAGS.protocol,
    })

    model.train(dl.train_input_fn, dl.eval_input_fn)

    # Model export is currently disabled:
    # model.save_model(fg.save_model_feature())


if __name__ == "__main__":
    tf.app.flags.DEFINE_string("tables", "train.csv,test.csv", "tables info")
    # NOTE(review): job_name is defined but not read in this script.
    tf.app.flags.DEFINE_string("job_name", None, "job name: worker or ps")
    tf.app.flags.DEFINE_integer("task_index", None, "Worker or server index")
    tf.app.flags.DEFINE_string("worker_hosts", "", "worker hosts")
    tf.app.flags.DEFINE_string("checkpointDir", '../model', "oss checkpointDir")
    tf.app.flags.DEFINE_string("protocol", None, "protocol option forwarded to the WDL model")
    tf.app.flags.DEFINE_string("rtp_dfs_path", "../model_rtp", "rtp_dfs_path")
    tf.app.flags.DEFINE_string("rtp_model_name", "wdl", "rtp_model_name")
    tf.app.flags.DEFINE_string("remote_train", "localhost", "remote_train setting forwarded to DataLoad")
    FLAGS = tf.app.flags.FLAGS
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)