- 前言
- 一、Lambda层如何保存
- 1.Lambda层定义
- 1.Lambda层保存
- 2.多进程训练
- 总结
前言
记录:如何在Keras同时使用Lambda自定义层以及多进程训练
一、Lambda层如何保存 1.Lambda层定义
Lambda层可以自定义函数(如Truncation)以及函数传参(如clip_value_min, clip_value_max, zero_thld)来构建网络层,代码示例:
L1_1 = Lambda(lambda x: Truncation(x,clip_value_min,clip_value_max,zero_thld),name='max_min_truncat_{:d}'.format(truncat_cnt))(L1_1); truncat_cnt+=11.Lambda层保存
Lambda一般使用keras.Model.save()函数来存储为’.h5’文件,但要注意该函数只会保存全局import的函数或者类,如果保存import在函数体内部的函数或者类,load_model时会报错:‘str’ object is not callable,代码如下(示例):
#-----------User Files-------------# from get_dataset_utils import get_dataset_pp from get_info_utils import Get_An_Na_Acc from user_def_utils import Truncation,DDQC,MDQC,DC_MDQC,OUT_MDQC,Activate_Quantization,Get_Quant_Weights,Get_Quant_Model_Weights def Tfun_Train_LY(start_thld,end_thld): # from user_def_utils import Truncation (will cause error) ... model.save(model_dir+'/save_model/ecg-layer2-echo{:0>3d}.h5'.format((i+1)*eval_echo_after_train)) print('save model successfully!') ...2.多进程训练
注意对于keras这样的训练库一定要定义在目标进程函数体内,否则会因为多进程访问同样keras对象而报错,代码如下(示例):
#-----------User Files-------------# from get_dataset_utils import get_dataset_pp from get_info_utils import Get_An_Na_Acc from user_def_utils import Truncation,DDQC,MDQC,DC_MDQC,OUT_MDQC,Activate_Quantization,Get_Quant_Weights,Get_Quant_Model_Weights def Tfun_Train_LY(start_thld,end_thld): # from user_def_utils import Truncation (will cause error) ... model.save(model_dir+'/save_model/ecg-layer2-echo{:0>3d}.h5'.format((i+1)*eval_echo_after_train)) print('save model successfully!') ... import multiprocessing import os if __name__ == '__main__': min_tar_num = 0 max_tar_num = 29 process_total = 15 # 设置起飞并行度参数 per_process_num = round((max_tar_num-min_tar_num + 1) / process_total) process_remain_tar_num = (max_tar_num-min_tar_num + 1) - per_process_num * (process_total - 1) #Tfun_Threshold_Acc(min_tar_num, max_tar_num) process_list = [] # 创建进程 try: for i in range(process_total-1): process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = (i*per_process_num, (i+1)*per_process_num))) process_list[i].start() if process_total == 1: process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = (min_tar_num, max_tar_num+1))) process_list[0].start() else: process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = ((i+1)*per_process_num, (i+1)*per_process_num+process_remain_tar_num))) process_list[i+1].start() # 等待所有进程结束 for prs in process_list: prs.join() except: print ("Error: 无法启动线程")
总结
以上就是今天要讲的内容,本文仅仅简单介绍了Keras中Lambda自定义层保存和多进程训练,仅供参考。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)