PReLU激活函数,内部源码实现
def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None: self.num_parameters = num_parameters super(PReLU, self).__init__() self.weight = Parameter(torch.Tensor(num_parameters).fill_(init)) def forward(self, input: Tensor) -> Tensor: return F.prelu(input, self.weight)
def prelu(input, weight): # type: (Tensor, Tensor) -> Tensor r"""prelu(input, weight) -> Tensor Applies element-wise the function :math:`text{PReLU}(x) = max(0,x) + text{weight} * min(0,x)` where weight is a learnable parameter. See :class:`~torch.nn.PReLU` for more details. """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function(prelu, (input,), input, weight) return torch.prelu(input, weight)
上边的公式搞到typora里就是如下形式:
weight这里是个可学习的参数,调用的时候nn.PReLU(planes)只输入了通道数,实现里将weight参数初始化为了0.25.
这和我们熟知的PRelu公式是一致的,
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)