Saving Lambda Custom Layers and Multiprocessing Training in Keras



Saving Lambda Custom Layers in Keras
  • Preface
  • I. How to Save a Lambda Layer
    • 1. Defining a Lambda layer
    • 2. Saving a Lambda layer
    • 3. Multiprocessing training
  • Summary


Preface

A note on how to use Lambda custom layers together with multiprocess training in Keras.


I. How to Save a Lambda Layer

1. Defining a Lambda layer

A Lambda layer builds a network layer from a custom function (e.g. Truncation) and its extra arguments (e.g. clip_value_min, clip_value_max, zero_thld). Example:

L1_1 = Lambda(lambda x: Truncation(x, clip_value_min, clip_value_max, zero_thld),
              name='max_min_truncat_{:d}'.format(truncat_cnt))(L1_1)
truncat_cnt += 1
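
For context, here is a minimal sketch of what a function like Truncation could look like. The real implementation lives in user_def_utils and is not shown in the original post, so the body below is only an assumption: clip the tensor to [clip_value_min, clip_value_max] and zero out small values.

import tensorflow as tf

def Truncation(x, clip_value_min, clip_value_max, zero_thld):
    # Hypothetical sketch: clip activations to the given range, then zero out
    # entries whose magnitude falls below zero_thld.
    x = tf.clip_by_value(x, clip_value_min, clip_value_max)
    return tf.where(tf.abs(x) < zero_thld, tf.zeros_like(x), x)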
2. Saving a Lambda layer

A model containing Lambda layers is usually stored as an '.h5' file with keras.Model.save(). Note, however, that only functions or classes imported at module (global) scope can be resolved when the model is reloaded; if the function or class is imported inside a function body, load_model will fail with: 'str' object is not callable. Example:

#-----------User Files-------------#
from get_dataset_utils import get_dataset_pp
from get_info_utils import Get_An_Na_Acc
from user_def_utils import Truncation,DDQC,MDQC,DC_MDQC,OUT_MDQC,Activate_Quantization,Get_Quant_Weights,Get_Quant_Model_Weights

def Tfun_Train_LY(start_thld,end_thld):
	# from user_def_utils import Truncation (will cause error)
	...
	model.save(model_dir+'/save_model/ecg-layer2-echo{:0>3d}.h5'.format((i+1)*eval_echo_after_train))
	print('save model successfully!')
	...
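
For completeness, a sketch of the matching load step (not part of the original post). Keeping the import at module scope and passing the function through load_model's custom_objects argument is the usual way to make the saved Lambda resolvable again; saved_model_path here is a placeholder for one of the .h5 files written above.

from keras.models import load_model
from user_def_utils import Truncation   # must stay at module (global) scope

# custom_objects maps the name referenced by the Lambda back to the callable
model = load_model(saved_model_path, custom_objects={'Truncation': Truncation})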
3. Multiprocessing training

Note that a training library like Keras must be imported inside the target process's function body; if it is imported at module scope, the worker processes all access the same Keras objects and training fails. Example:

#-----------User Files-------------#
from get_dataset_utils import get_dataset_pp
from get_info_utils import Get_An_Na_Acc
from user_def_utils import Truncation,DDQC,MDQC,DC_MDQC,OUT_MDQC,Activate_Quantization,Get_Quant_Weights,Get_Quant_Model_Weights

def Tfun_Train_LY(start_thld,end_thld):
	# from user_def_utils import Truncation (will cause error)
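	# Illustration of the multiprocessing point above (not in the original snippet):
	# importing Keras here, inside the worker, gives each process its own
	# session instead of sharing one inherited from the parent.
	import keras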
	...
	model.save(model_dir+'/save_model/ecg-layer2-echo{:0>3d}.h5'.format((i+1)*eval_echo_after_train))
	print('save model successfully!')
	...

import multiprocessing
import os

if __name__ == '__main__':
    min_tar_num = 0
    max_tar_num = 29
    process_total = 15 # number of parallel worker processes
    per_process_num = round((max_tar_num-min_tar_num + 1) / process_total)
    process_remain_tar_num = (max_tar_num-min_tar_num + 1) - per_process_num * (process_total - 1)

    #Tfun_Threshold_Acc(min_tar_num, max_tar_num)
    
    process_list = []
    # spawn the worker processes
    try:
        for i in range(process_total-1):
            process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = (i*per_process_num, (i+1)*per_process_num)))
            process_list[i].start()
        if process_total == 1:
            # a single process handles the whole target range
            process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = (min_tar_num, max_tar_num+1)))
            process_list[0].start()
        else:
            # the last process takes whatever targets remain after the even split
            process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = ((i+1)*per_process_num, (i+1)*per_process_num+process_remain_tar_num)))
            process_list[i+1].start()

        # wait for all worker processes to finish
        for prs in process_list:
            prs.join()
    except Exception as e:
        print("Error: failed to start worker processes:", e)

Summary

That is all for this post: a brief introduction to saving Lambda custom layers and multiprocessing training in Keras, for reference only.

Original post (Chinese): https://outofmemory.cn/zaji/5443749.html