torch.optim.Optimizer


Understanding the torch.optim.Optimizer base class
  • torch.optim.Optimizer
    • Parameters
      • .params
      • .defaults
    • Class attributes
    • Class methods
      • 1. Optimizer.state_dict
      • 2. Optimizer.load_state_dict
      • 3. Optimizer.zero_grad
      • 4. Optimizer.add_param_group(self, param_group)

torch.optim.Optimizer

torch.optim.Optimizer(params, defaults)
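As a minimal sketch of how a subclass uses these two arguments, the hypothetical PlainSGD below passes its hyperparameters to the base class as defaults and reads them back from each entry of self.param_groups inside step. It implements only plain gradient descent (p ← p − lr·grad); it illustrates the pattern and is not how torch.optim.SGD is actually written.

```python
import torch
from torch.optim import Optimizer


class PlainSGD(Optimizer):
    """Hypothetical minimal optimizer: p <- p - lr * grad (no momentum, no weight decay)."""

    def __init__(self, params, lr=0.01):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = dict(lr=lr)  # fallback values for every param group
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:  # each group carries its own "lr"
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.add_(p.grad, alpha=-group["lr"])
        return loss


# Usage: one step on f(w) = sum(w**2), whose gradient is 2w
w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
opt = PlainSGD([w], lr=0.1)
(w ** 2).sum().backward()
opt.step()
print(w.detach())  # tensor([0.8000, 1.6000])
```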

Parameters

.params
```python
param_groups[group_name] = {
    "lr_scale": this_scale,
    "weight_decay": this_decay,
    "params": [],
}
```

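In this example, each param_group is stored under a name while the groups are being built; the optimizer itself receives them as a list of dicts, and Optimizer.param_groups is a list. Each group holds both optimizer options (the hyperparameters used in the update computation) and the model parameters they apply to.
The example also includes lr_scale, which scales the learning rate layer by layer: the lower a layer sits in the network (the closer to the input), the smaller its learning rate. lr_scale is not an option the built-in optimizers read; it is kept in the group so that the training code can apply it.
params must be an iterable of torch.Tensor objects or of dicts. It determines which of the model's parameters are optimized: whatever is passed as params is exactly the set of parameters the optimizer will update.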
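Below is a self-contained sketch of the layer-wise idea, under stated assumptions: a hypothetical three-layer model, a decay factor of 0.5 chosen only for illustration, and group names such as "layer_0" that are purely local bookkeeping. Because torch.optim.SGD does not read lr_scale, the sketch multiplies it into each group's lr by hand, standing in for whatever scheduler the original training script uses.

```python
import torch
import torch.nn as nn

# Hypothetical model; the first Linear is the lowest layer (closest to the input)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
layers = [m for m in model if isinstance(m, nn.Linear)]

base_lr = 0.1
decay = 0.5  # assumed layer-wise decay factor
num_layers = len(layers)

param_groups = {}  # keyed by group name only while building
for i, layer in enumerate(layers):
    group_name = f"layer_{i}"
    this_scale = decay ** (num_layers - 1 - i)  # top layer: 1.0, lower layers: smaller
    param_groups[group_name] = {
        "lr_scale": this_scale,
        "weight_decay": 0.0,
        "params": list(layer.parameters()),
    }

# The optimizer receives a plain list of dicts; the names are not passed on
optimizer = torch.optim.SGD(list(param_groups.values()), lr=base_lr)

# SGD keeps the extra "lr_scale" key but ignores it, so apply it manually
for group in optimizer.param_groups:
    group["lr"] = base_lr * group["lr_scale"]

print([g["lr_scale"] for g in optimizer.param_groups])  # [0.25, 0.5, 1.0]
```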

.defaults
```python
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
```

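A dict holding default values for the optimization options, such as learning rate, momentum, and weight decay. When a parameter group does not set one of these options itself, the value from defaults is filled into that group.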
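A short sketch of this fallback using torch.optim.SGD (the two Linear layers are placeholders): the first group sets no lr, so it inherits lr=0.1 from defaults, while the second overrides it. momentum, given only to the constructor, ends up in both groups.

```python
import torch
import torch.nn as nn

backbone = nn.Linear(4, 4)  # placeholder modules
head = nn.Linear(4, 2)

optimizer = torch.optim.SGD(
    [
        {"params": backbone.parameters()},          # no "lr": taken from defaults
        {"params": head.parameters(), "lr": 0.01},  # group-specific override
    ],
    lr=0.1,
    momentum=0.9,
)

for group in optimizer.param_groups:
    print(group["lr"], group["momentum"])
# 0.1 0.9
# 0.01 0.9
```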

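Class attributes

| Attribute | Initialization in __init__ |
|---|---|
| self.defaults | self.defaults = defaults |
| (none) | self._hook_for_profile() |
| self.state | self.state = defaultdict(dict) |
| self.param_groups | self.param_groups = [] |
| (method call) | self.add_param_group(param_group), called once per group to fill self.param_groups |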
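To see these attributes on a live optimizer, the sketch below uses torch.optim.SGD with momentum (chosen only as an example). It shows that state starts as an empty defaultdict keyed by the parameter tensors themselves and is filled lazily on the first step. The momentum_buffer key is specific to SGD; other optimizers store different entries.

```python
from collections import defaultdict

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

print(type(optimizer.state).__name__, len(optimizer.state))  # defaultdict 0
size_before_step = len(optimizer.state)

loss = model(torch.randn(5, 3)).pow(2).mean()
loss.backward()
optimizer.step()

# After one step, SGD has stored a momentum buffer for each parameter tensor
buf = optimizer.state[model.weight]["momentum_buffer"]
print(len(optimizer.state), tuple(buf.shape))  # 2 (1, 3)
```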
Class methods

1. Optimizer.state_dict

| Method | Purpose |
|---|---|
| Optimizer.state_dict | Returns the optimizer's state as a dict |

Source code (from a PyTorch 1.x release; newer versions differ in details):

```python
def state_dict(self):
    r"""Returns the state of the optimizer as a :class:`dict`.

    It contains two entries:

    * state - a dict holding current optimization state. Its content
        differs between optimizer classes.
    * param_groups - a list containing all parameter groups where each
        parameter group is a dict
    """
    # Save order indices instead of Tensors
    param_mappings = {}
    start_index = 0

    def pack_group(group):
        nonlocal start_index
        packed = {k: v for k, v in group.items() if k != 'params'}
        param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                               if id(p) not in param_mappings})
        packed['params'] = [param_mappings[id(p)] for p in group['params']]
        start_index += len(packed['params'])
        return packed
    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use order indices as keys
    packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }
```
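A quick check of this packing, assuming two placeholder layers placed in two groups and torch.optim.SGD with momentum so that there is per-parameter state to save. Tensors are replaced by consecutive integers, and because start_index carries over between calls to pack_group, the numbering continues from one group into the next.

```python
import torch
import torch.nn as nn

a = nn.Linear(2, 1)  # parameters: a.weight, a.bias
b = nn.Linear(1, 1)  # parameters: b.weight, b.bias
optimizer = torch.optim.SGD(
    [{"params": a.parameters()}, {"params": b.parameters(), "lr": 0.01}],
    lr=0.1,
    momentum=0.9,
)

b(a(torch.randn(4, 2))).sum().backward()
optimizer.step()

sd = optimizer.state_dict()
print(sd["param_groups"][0]["params"])  # [0, 1]
print(sd["param_groups"][1]["params"])  # [2, 3]  numbering continues across groups
print(sorted(sd["state"].keys()))       # [0, 1, 2, 3]
```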
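2. Optimizer.load_state_dict

| Method | Purpose |
|---|---|
| Optimizer.load_state_dict | Loads a state_dict (as returned by state_dict) into the optimizer, replacing its state and group hyperparameters while keeping its own parameter tensors |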

Source code:

```python
def load_state_dict(self, state_dict):
    r"""Loads the optimizer state.

    Args:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    # deepcopy, to be consistent with module API
    state_dict = deepcopy(state_dict)
    # Validate the state_dict
    groups = self.param_groups
    saved_groups = state_dict['param_groups']

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of "
                         "parameter groups")
    param_lens = (len(g['params']) for g in groups)
    saved_lens = (len(g['params']) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError("loaded state dict contains a parameter group "
                         "that doesn't match the size of optimizer's group")

    # Update the state
    id_map = {old_id: p for old_id, p in
              zip(chain.from_iterable((g['params'] for g in saved_groups)),
                  chain.from_iterable((g['params'] for g in groups)))}

    def cast(param, value):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            # Floating-point types are a bit special here. They are the only ones
            # that are assumed to always match the type of params.
            if param.is_floating_point():
                value = value.to(param.dtype)
            value = value.to(param.device)
            return value
        elif isinstance(value, dict):
            return {k: cast(param, v) for k, v in value.items()}
        elif isinstance(value, container_abcs.Iterable):
            return type(value)(cast(param, v) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state = defaultdict(dict)
    for k, v in state_dict['state'].items():
        if k in id_map:
            param = id_map[k]
            state[param] = cast(param, v)
        else:
            state[k] = v

    # Update parameter groups, setting their 'params' value
    def update_group(group, new_group):
        new_group['params'] = group['params']
        return new_group
    param_groups = [
        update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    self.__setstate__({'state': state, 'param_groups': param_groups})
```
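The cast helper in this source converts floating-point state to the dtype (and device) of the parameter it is attached to. A small sketch of that behavior, assuming torch.optim.SGD with momentum and a float64 source model whose state is loaded into a float32 model of the same layout:

```python
import torch
import torch.nn as nn

# Source: a float64 model whose SGD momentum buffers are float64
src = nn.Linear(2, 1).double()
src_opt = torch.optim.SGD(src.parameters(), lr=0.1, momentum=0.9)
src(torch.randn(4, 2, dtype=torch.float64)).sum().backward()
src_opt.step()

# Target: an ordinary float32 model with the same layout
dst = nn.Linear(2, 1)
dst_opt = torch.optim.SGD(dst.parameters(), lr=0.1, momentum=0.9)
dst_opt.load_state_dict(src_opt.state_dict())

src_buf = src_opt.state[src.weight]["momentum_buffer"]
dst_buf = dst_opt.state[dst.weight]["momentum_buffer"]
print(src_buf.dtype, dst_buf.dtype)  # torch.float64 torch.float32
```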

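First, state_dict is deep-copied, so that the caller's dict is not modified (update_group later writes into the saved groups), and the saved groups are taken from it:

```python
state_dict = deepcopy(state_dict)
saved_groups = state_dict['param_groups']
```

Then the optimizer's current param_groups are fetched:

```python
groups = self.param_groups
```

After checking that the number of groups and the size of each group match, the state is updated. The saved state_dict refers to parameters only by integer index (see state_dict above), so id_map pairs each saved index with the current tensor at the same position, walking both group lists in order:

```python
# Update the state
id_map = {old_id: p for old_id, p in
          zip(chain.from_iterable((g['params'] for g in saved_groups)),
              chain.from_iterable((g['params'] for g in groups)))}
```

Each saved state entry is then re-keyed by its tensor and cast to that tensor's device (and dtype, for floating-point parameters).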
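A round-trip sketch of this positional pairing, using torch.optim.SGD with momentum; the helper make_model_and_optimizer is hypothetical and exists only to build two identical layouts. The fresh optimizer's weight picks up the saved momentum buffer purely because it sits at the same position, the group's lr comes from the saved group, and a state_dict with a different group layout is rejected by the validation step.

```python
import torch
import torch.nn as nn


def make_model_and_optimizer():
    model = nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


model, opt = make_model_and_optimizer()
model(torch.randn(4, 2)).sum().backward()
opt.step()
for group in opt.param_groups:
    group["lr"] = 0.05  # e.g. lowered by a scheduler before saving
saved = opt.state_dict()

# A fresh optimizer over different tensors with the same layout
new_model, new_opt = make_model_and_optimizer()
new_opt.load_state_dict(saved)

# Saved index 0 was paired with new_model.weight purely by position
restored = new_opt.state[new_model.weight]["momentum_buffer"]
original = opt.state[model.weight]["momentum_buffer"]
print(torch.equal(restored, original), new_opt.param_groups[0]["lr"])  # True 0.05

# A different layout (two groups instead of one) fails validation
other = nn.Linear(2, 1)
two_groups = torch.optim.SGD([{"params": [other.weight]}, {"params": [other.bias]}], lr=0.1)
try:
    two_groups.load_state_dict(saved)
    layout_error = None
except ValueError as err:
    layout_error = err
```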
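3. Optimizer.zero_grad

Sets the gradient of every parameter managed by the optimizer to zero, or to None when set_to_none=True.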
```python
def zero_grad(self, set_to_none: bool = False):
    r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

    Args:
        set_to_none (bool): instead of setting to zero, set the grads to None.
            This will in general have lower memory footprint, and can modestly improve performance.
            However, it changes certain behaviors. For example:
            1. When the user tries to access a gradient and perform manual ops on it,
            a None attribute or a Tensor full of 0s will behave differently.
            2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
            are guaranteed to be None for params that did not receive a gradient.
            3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
            (in one case it does the step with a gradient of 0 and in the other it skips
            the step altogether).
    """
    if not hasattr(self, "_zero_grad_profile_name"):
        self._hook_for_profile()
    with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if set_to_none:
                        p.grad = None
                    else:
                        if p.grad.grad_fn is not None:
                            p.grad.detach_()
                        else:
                            p.grad.requires_grad_(False)
                        p.grad.zero_()
```
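The difference between the two modes is easy to observe directly. The sketch below uses torch.optim.SGD on a placeholder nn.Linear and passes set_to_none explicitly, because the default shown in this 1.x source (False) was changed to True in PyTorch 2.0.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# set_to_none=False: the existing .grad tensor is kept and filled with zeros
model(torch.randn(2, 3)).sum().backward()
opt.zero_grad(set_to_none=False)
grad_after_zeroing = model.weight.grad.clone()  # snapshot before the next backward

# set_to_none=True: .grad is dropped entirely
model(torch.randn(2, 3)).sum().backward()
opt.zero_grad(set_to_none=True)
grad_after_clearing = model.weight.grad

print(grad_after_zeroing)   # tensor([[0., 0., 0.]])
print(grad_after_clearing)  # None
```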

4. Optimizer.add_param_group(self, param_group)

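This is mainly useful when fine-tuning a pre-trained network: layers that were frozen can be made trainable again and added to the optimizer as training progresses.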

param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.

```python
def add_param_group(self, param_group):
    r"""Add a param group to the :class:`Optimizer` s `param_groups`.

    This can be useful when fine tuning a pre-trained network as frozen layers can be made
    trainable and added to the :class:`Optimizer` as training progresses.

    Args:
        param_group (dict): Specifies what Tensors should be optimized along with group
        specific optimization options.
    """
    assert isinstance(param_group, dict), "param group must be a dict"

    params = param_group['params']
    if isinstance(params, torch.Tensor):
        param_group['params'] = [params]
    elif isinstance(params, set):
        raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                        'the ordering of tensors in sets will change between runs. Please use a list instead.')
    else:
        param_group['params'] = list(params)

    for param in param_group['params']:
        if not isinstance(param, torch.Tensor):
            raise TypeError("optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param))
        if not param.is_leaf:
            raise ValueError("can't optimize a non-leaf Tensor")

    for name, default in self.defaults.items():
        if default is required and name not in param_group:
            raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                             name)
        else:
            param_group.setdefault(name, default)

    params = param_group['params']
    if len(params) != len(set(params)):
        warnings.warn("optimizer contains a parameter group with duplicate parameters; "
                      "in future, this will cause an error; "
                      "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

    param_set = set()
    for group in self.param_groups:
        param_set.update(set(group['params']))

    if not param_set.isdisjoint(set(param_group['params'])):
        raise ValueError("some parameters appear in more than one parameter group")

    self.param_groups.append(param_group)
```
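A sketch of the fine-tuning use case described above, assuming torch.optim.SGD with momentum and placeholder backbone and head layers. The backbone starts frozen and is later unfrozen and added as a second group with a smaller lr; momentum is filled in from defaults. The last two calls show the disjointness check and the rejection of sets.

```python
import torch
import torch.nn as nn

backbone = nn.Linear(4, 4)  # placeholder "pre-trained" part
head = nn.Linear(4, 2)      # placeholder new head

# Stage 1: backbone frozen, only the head is optimized
for p in backbone.parameters():
    p.requires_grad_(False)
opt = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)

# Stage 2: unfreeze the backbone and add it as a new group with a smaller lr;
# momentum is not given, so it is filled in from opt.defaults
for p in backbone.parameters():
    p.requires_grad_(True)
opt.add_param_group({"params": backbone.parameters(), "lr": 0.01})
print([g["lr"] for g in opt.param_groups])  # [0.1, 0.01]

# Adding a tensor that is already in a group fails the disjointness check
try:
    opt.add_param_group({"params": [backbone.weight]})
    duplicate_error = None
except ValueError as err:
    duplicate_error = err

# A set is rejected because its iteration order is not stable across runs
try:
    opt.add_param_group({"params": {nn.Parameter(torch.zeros(1))}})
    set_error = None
except TypeError as err:
    set_error = err
```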

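Key lines:

```python
# Normalize the group's 'params' entry (a single Tensor is wrapped in a list,
# other iterables are converted to a list, sets are rejected)
params = param_group['params']

# Fill in every option the group did not set from self.defaults
# (simplified: the full code first raises if a `required` option is missing)
for name, default in self.defaults.items():
    param_group.setdefault(name, default)

# Collect every tensor already owned by an existing group,
# so the new group can be checked for overlap
param_set = set()
for group in self.param_groups:
    param_set.update(set(group['params']))

# Register the new group
self.param_groups.append(param_group)
```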
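To tie the key lines together, here is a stripped-down, hypothetical MiniOptimizer that mirrors only the group bookkeeping of the base class: normalizing params, filling defaults, checking disjointness, and appending. It omits the required sentinel, the leaf and type checks, the duplicate warning, and all step logic, so it is an illustration rather than a drop-in replacement.

```python
import torch


class MiniOptimizer:
    """Hypothetical sketch of Optimizer's param-group bookkeeping (no step logic)."""

    def __init__(self, params, defaults):
        self.defaults = defaults
        self.param_groups = []
        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]  # a bare list of tensors becomes one group
        for group in param_groups:
            self.add_param_group(group)

    def add_param_group(self, param_group):
        # 1. Normalize 'params' to a list
        params = param_group["params"]
        if isinstance(params, torch.Tensor):
            param_group["params"] = [params]
        elif isinstance(params, set):
            raise TypeError("use an ordered collection such as a list, not a set")
        else:
            param_group["params"] = list(params)
        # 2. Fill in options the group did not specify
        for name, default in self.defaults.items():
            param_group.setdefault(name, default)
        # 3. Reject tensors that already belong to another group
        param_set = set()
        for group in self.param_groups:
            param_set.update(group["params"])
        if not param_set.isdisjoint(param_group["params"]):
            raise ValueError("some parameters appear in more than one parameter group")
        # 4. Register the group
        self.param_groups.append(param_group)


# Usage
w1 = torch.zeros(2, requires_grad=True)
w2 = torch.zeros(3, requires_grad=True)
opt = MiniOptimizer([w1], defaults={"lr": 0.1, "momentum": 0.0})
opt.add_param_group({"params": w2, "lr": 0.01})  # a single tensor is wrapped in a list

try:
    opt.add_param_group({"params": [w1]})  # w1 already belongs to group 0
    rejected = False
except ValueError:
    rejected = True

print([g["lr"] for g in opt.param_groups], rejected)  # [0.1, 0.01] True
```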

Feel free to share; when reposting, please credit the source: 内存溢出 (outofmemory.cn)

Original article: http://outofmemory.cn/langs/568336.html
