pytorch 23 hook的使用与介绍 及基于hook实现即插即用的DropBlock

pytorch 23 hook的使用与介绍 及基于hook实现即插即用的DropBlock,第1张

Hook(钩子):字面意思就是勾住一个函数,在其执行前修改输入数据或其他 *** 作,或者在其执行后修改输出输出或其他 *** 作。


通过hook *** 作可以锁定一个layer对象(model中的模块)的生命周期,监视其执行状态和执行结果。


在pytorch中,提供丰富的hook api,让我们可以监听并修改tensor在模型forword中的状态。


在本博文中,基于hook *** 作实现了即插即用的Dropout *** 作,支持dorpblock、dropout2d等系列Dropout *** 作。


在pytorch中layer的生命周期可以简单的理解为:定义->初始化->前向传播->反向传播->销毁。


其中前向传播和反向传播是其执行流程中最重要的部分,通过对前向传播和反向传播进行hook *** 作,可以获取数据在模型中的执行状态(比如实现CAM)并进行修改 *** 作(比如实现梯度裁剪)。


Grad-CAM:基于梯度的类别响应特征可视化。


通过hook *** 作获取相应layer的forward流程中feature map的output和backward流程中grad的output,然后将feature_map_output与backward_grad_output中相应的元素相乘,然后实现类别响应特征可视化。


梯度裁剪:对layer的backward流程中grad的output的值进行约束,使其不能大于特定值从而导致梯度爆炸。


1、pytorch中的hook接口

pytorch针对Model、modules.module、ScriptModule和Tensor 4种模块分别提供了相应的hook接口。


其中针对于Tensor只提供了一个用于反向传播的hook,因为tensor的前向传播流程对用户是完全可见的。


针对于Model、modules.module、ScriptModul

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/571805.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存