最近需要在模型的bbox head中进行一些epoch相关的 *** 作,比如依据epoch数目更改某个模块的层数。本文主要参考了https://github.com/open-mmlab/mmdetection/issues/7425, 将具体做法进行整理。
1. 新建set_epoch_info_hook新建mmdetection/mmdet/core/hook/set_epoch_info_hook.py
,内容填充如下,并在相应的__Init__.py添加
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.parallel import is_module_wrapper
from mmcv.runner import HOOKS, Hook
@HOOKS.register_module()
class SetEpochInfoHook(Hook):
"""Set runner's epoch information to the model."""
def before_train_epoch(self, runner):
epoch = runner.epoch
model = runner.model
if is_module_wrapper(model):
model = model.module
model.set_epoch(epoch)
2. 修改config
在mmdetection/configs/my_model/my_model_r50_fpn_1x_coco.py
中新增如下行,调用新建的set_epoch_info_hook
:
# custom hooks
custom_hooks = [dict(type='SetEpochInfoHook')]
3. 调用
在自己的检测器模型mmdetection/mmdet/models/detectors/my_model.py
中新增set_epoch
类函数,
def set_epoch(self, epoch):
self.bbox_head.epoch = epoch
完成以上步骤,就可以在bbox_head中直接通过self.epoch
直接获取当前的epoch值。
如果对你有帮助请不要吝啬,点赞收藏评论一键三连,谢谢!
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)