mmdetection在bbox head中获取epoch值

mmdetection在bbox head中获取epoch值,第1张

最近需要在模型的bbox head中进行一些epoch相关的 *** 作,比如依据epoch数目更改某个模块的层数。本文主要参考了https://github.com/open-mmlab/mmdetection/issues/7425, 将具体做法进行整理。

1. 新建set_epoch_info_hook

新建mmdetection/mmdet/core/hook/set_epoch_info_hook.py,内容填充如下,并在相应的__Init__.py添加

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.parallel import is_module_wrapper
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class SetEpochInfoHook(Hook):
    """Set runner's epoch information to the model."""

    def before_train_epoch(self, runner):
        epoch = runner.epoch
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        model.set_epoch(epoch)
2. 修改config

mmdetection/configs/my_model/my_model_r50_fpn_1x_coco.py中新增如下行,调用新建的set_epoch_info_hook:

# custom hooks
custom_hooks = [dict(type='SetEpochInfoHook')]
3. 调用

在自己的检测器模型mmdetection/mmdet/models/detectors/my_model.py中新增set_epoch类函数,

 def set_epoch(self, epoch): 
     self.bbox_head.epoch = epoch 

完成以上步骤,就可以在bbox_head中直接通过self.epoch直接获取当前的epoch值。

如果对你有帮助请不要吝啬,点赞收藏评论一键三连,谢谢!

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/727552.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-26
下一篇 2022-04-26

发表评论

登录后才能评论

评论列表(0条)

保存