Pytorch学习笔记——fan

Pytorch学习笔记——fan,第1张

文章目录
  • 前言
  • fan_in和fan_out的含义是什么呢?
  • Pytorch中如何计算fan_int和fan_out?
    • 方法一
    • 方法二
  • 小结

前言

在前一篇文章中介绍了kaiming均匀初始化方法,其中有一个mode参数,可以是fan_infan_out,本文研究一下这两个参数的含义。


官方关于这个参数作用的解释:
选择"fan_in"可以保留前向计算中权重方差的大小。



选择"fan_out"将保留后向传播的方差大小。


#kaiming均匀初始化
torch.nn.init.kaiming_uniform_(
    tensor, 
    a=0, 
    mode='fan_in', 
    nonlinearity='leaky_relu'
)
fan_in和fan_out的含义是什么呢?

为了理解fan_infan_out,先要理解什么是fan


词典中的解释是这样的:

disperse or radiate from a central point to cover a wide area:
从一个中心点分散或辐射到一个大范围。


看一个例句:

“the arriving passengers began to fan out through the town in search of lodgings”
到达的乘客开始在城里散开,寻找住处。


既然fan代表从一个中心点分散到一个大范围,那么在神经网络中,这个中心点对应什么呢?

一种简单的理解方式是,这个中心点对应一个层


比如下面这个简单的Linear模型。


上图就代表了一个层,对应一个nn.Linear,包含一个权重矩阵。


关于层。



有人按照绿色节点来算层,有人用中间的权重来算层,其实只是不同的视角,本质都是一样的。



我们观察Linear的实现,里面只有一个权重,并没有所谓节点。



最终训练得到的模型,也只是保留各层的权重。



但在画网络图的时候,节点又是重要的组成部分。


把层作为fan对应的中心点,那么:
这个层的输入就是fan_in
这个层的输出就是fan_out


如果从节点的角度来理解fan_in和fan_out:

fan_in:本层每个节点的输入个数。


答案是4个,因为上一层有4个节点。


fan_out:就是本层的输出节点个数(也就是本层节点数)

答案是6,因为本层有6个节点。


Pytorch中如何计算fan_int和fan_out?

仍然用一个线性层为例来说明。


方法一

使用_calculate_fan_in_and_fan_out同时计算fan_infan_out


>>> linear = nn.Linear(4, 6)
>
>>> linear
Linear(
    in_features=4, 
    out_features=6,
    bias=True)

>>> w = linear.weight
>>> w.shape
torch.Size([6, 4])
# 注意,这里weight的形状
# 对于矩阵左乘,Wx
# 将4维输入x映射为6维输出
# 所以w是 6 * 4
# 如果是矩阵右乘,则相反

>>> print(
... nn.init._calculate_fan_in_and_fan_out(
... linear.weight)
... )
(4, 6)
# fan_in: 4
# fan_out: 6
方法二

单独计算fan_infan_out


>>> w = linear.weight
>>> w.shape
torch.Size([6, 4])

>>> torch.nn.init._calculate_correct_fan(
... w,
... mode='fan_in')
4

从这个简单的例子可以看到,对于Linear模型,fan_in就是输入size,fan_out就是输出size。


小结

本文侧重于理解fan_in和fan_out的含义,通过对fan的原始含义进行分析,发现用层对应fan的中心点是一种理解方式。



文章修改记录
2022.04.06:修改网络层的说明。


欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/langs/571123.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存