- 前言
- fan_in和fan_out的含义是什么呢?
- Pytorch中如何计算fan_int和fan_out?
- 方法一
- 方法二
- 小结
在前一篇文章中介绍了kaiming均匀初始化方法,其中有一个mode
参数,可以是fan_in
和fan_out
,本文研究一下这两个参数的含义。
官方关于这个参数作用的解释:
选择"fan_in"可以保留前向计算中权重方差的大小。
选择"fan_out"将保留后向传播的方差大小。
#kaiming均匀初始化
torch.nn.init.kaiming_uniform_(
tensor,
a=0,
mode='fan_in',
nonlinearity='leaky_relu'
)
fan_in和fan_out的含义是什么呢?
为了理解fan_in
和fan_out
,先要理解什么是fan
。
词典中的解释是这样的:
disperse or radiate from a central point to cover a wide area:
从一个中心点分散或辐射到一个大范围。
看一个例句:
“the arriving passengers began to fan out through the town in search of lodgings”
到达的乘客开始在城里散开,寻找住处。
既然fan代表从一个中心点分散到一个大范围,那么在神经网络中,这个中心点对应什么呢?
一种简单的理解方式是,这个中心点对应一个层
。
比如下面这个简单的Linear模型。
上图就代表了一个层,对应一个nn.Linear,包含一个权重矩阵。
关于层。
有人按照绿色节点来算层,有人用中间的权重来算层,其实只是不同的视角,本质都是一样的。
我们观察Linear的实现,里面只有一个权重,并没有所谓节点。
最终训练得到的模型,也只是保留各层的权重。
但在画网络图的时候,节点又是重要的组成部分。
把层作为fan对应的中心点,那么:
这个层的输入就是fan_in
;
这个层的输出就是fan_out
。
如果从节点的角度来理解fan_in和fan_out:
fan_in
:本层每个节点的输入个数。
答案是4个,因为上一层有4个节点。
fan_out
:就是本层的输出节点个数(也就是本层节点数)
Pytorch中如何计算fan_int和fan_out?答案是6,因为本层有6个节点。
仍然用一个线性层为例来说明。
使用_calculate_fan_in_and_fan_out
同时计算fan_in
和fan_out
。
>>> linear = nn.Linear(4, 6)
>
>>> linear
Linear(
in_features=4,
out_features=6,
bias=True)
>>> w = linear.weight
>>> w.shape
torch.Size([6, 4])
# 注意,这里weight的形状
# 对于矩阵左乘,Wx
# 将4维输入x映射为6维输出
# 所以w是 6 * 4
# 如果是矩阵右乘,则相反
>>> print(
... nn.init._calculate_fan_in_and_fan_out(
... linear.weight)
... )
(4, 6)
# fan_in: 4
# fan_out: 6
方法二
单独计算fan_in
和fan_out
。
>>> w = linear.weight
>>> w.shape
torch.Size([6, 4])
>>> torch.nn.init._calculate_correct_fan(
... w,
... mode='fan_in')
4
从这个简单的例子可以看到,对于Linear模型,fan_in就是输入size,fan_out就是输出size。
本文侧重于理解fan_in和fan_out的含义,通过对fan的原始含义进行分析,发现用层对应fan的中心点是一种理解方式。
文章修改记录
2022.04.06:修改网络层的说明。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)