欢迎访问个人网络日志🌹🌹知行空间🌹🌹
文章目录
- 0.简介
- 1.BatchNorm1d
- 2.BatchNorm2d
- 3.BatchNorm3d
- 参考资料
Batch Normalization
在训练过程中对网络的输入输出进行归一化,可有效防止梯度爆炸和梯度消失,能加快网络的收敛速度。
y = x − E ( x ) ( V a r ( x ) + ϵ ) γ + β y = \frac{x-E(x)}{\sqrt(Var(x)+\epsilon)}\gamma+\beta y=( Var(x)+ϵ)x−E(x)γ+β
如上式,x
表示的是输入变量,E(x)
和Var(x)
分别表示x
的那每个特征维度在batch size
上所求得的梯度及方差。
ϵ
\epsilon
ϵ是为了防止除以0,通常为1e-5
,
γ
\gamma
γ和
β
\beta
β是可学习的参数,在torch BatchNorm API
中,可通过设置affine=True/False
来设置这两个参数是固定还是可学习的。True
表示可学习,False
表示不可学习,默认
γ
=
1
\gamma=1
γ=1,
β
=
0
\beta=0
β=0。
BatchNorm1d
是对NXC
或NXCXL
维度的向量做Batch Normalization
,N
表示Batch Size
的大小,C
表示数据的维度,L
表示每个维度又有多少维组成。
如上图,表示了一组NXCXL=3X2X3
的数据,
使用BatchNorm1d
后的输出为:
from torch import nn
batch = nn.BatchNorm1d(2, affine=False)
t = torch.tensor([[[7,4,6],[1,2,3]],[[3,4,2],[2,4,6]],[[9,0,7],[3,8,5]]])
t = t.float()
batch(t)
"""
输出为:
tensor([[[ 0.8750, -0.2500, 0.5000],
[-1.3250, -0.8480, -0.3710]],
[[-0.6250, -0.2500, -1.0000],
[-0.8480, 0.1060, 1.0600]],
[[ 1.6250, -1.7500, 0.8750],
[-0.3710, 2.0140, 0.5830]]])
"""
上述的计算过程等价为:
因为affine=False
因此
γ
=
1
,
β
=
0
\gamma=1,\beta=0
γ=1,β=0,期望的计算是单独在每个维度上对Batch
计算的,等价为
在特征维度0上的均值
E
(
x
)
=
7
+
4
+
6
+
3
+
4
+
2
+
9
+
0
+
7
3
×
3
=
4.6667
E(x) = \frac{7+4+6+3+4+2+9+0+7}{3\times3} = 4.6667
E(x)=3×37+4+6+3+4+2+9+0+7=4.6667
同理可计算方差为:‵Var(X) = 2.6667`
tmp = t[:,0,:]
print(tmp.mean())
print(tmp.var(unbiased=False).sqrt())
print((tmp-tmp.mean())/(tmp.var(unbiased=False).sqrt()+1e-5))
"""
Output:
tensor(4.6667)
tensor(2.6667)
tensor([[ 0.8750, -0.2500, 0.5000],
[-0.6250, -0.2500, -1.0000],
[ 1.6250, -1.7500, 0.8750]])
"""
注意在上述计算方差的过程中没有使用Bessel’s correction
贝塞尔校正,除以的是n
而不是n-1
,因此通过这种方式计算的方差是有偏的。上面的结果与BatchNorm1d
的输出是一致的。
from torch import nn
batch = nn.BatchNorm2d(2, affine=False)
img = torch.randint(0, 255, (2,2,3,3))
img = img.float()
print(img)
print(batch(img))
t = img[:,0,:,:]
print(t.mean())
print(t.var().sqrt())
print((t-t.mean())/(t.var(unbiased=False).sqrt()+1e-5))
"""
Output:
tensor([[[[ 97., 163., 130.],
[ 26., 83., 183.],
[165., 108., 242.]],
[[113., 184., 236.],
[159., 223., 247.],
[ 48., 104., 111.]]],
[[[110., 93., 115.],
[237., 168., 120.],
[149., 115., 48.]],
[[117., 22., 43.],
[202., 63., 209.],
[104., 135., 99.]]]])
tensor([[[[-0.6115, 0.5873, -0.0121],
[-1.9012, -0.8658, 0.9506],
[ 0.6236, -0.4117, 2.0223]],
[[-0.3169, 0.7350, 1.5054],
[ 0.3646, 1.3128, 1.6683],
[-1.2798, -0.4502, -0.3465]]],
[[[-0.3754, -0.6842, -0.2846],
[ 1.9315, 0.6781, -0.1938],
[ 0.3330, -0.2846, -1.5016]],
[[-0.2576, -1.6650, -1.3539],
[ 1.0016, -1.0576, 1.1054],
[-0.4502, 0.0091, -0.5243]]]])
tensor(130.6667)
tensor(56.6486)
tensor([[[-0.6115, 0.5873, -0.0121],
[-1.9012, -0.8658, 0.9506],
[ 0.6236, -0.4117, 2.0223]],
[[-0.3754, -0.6842, -0.2846],
[ 1.9315, 0.6781, -0.1938],
[ 0.3330, -0.2846, -1.5016]]])
"""
BatchNorm2d
的输入维度是NCHW
形式的4维变量,计算均值和方差时是以C
为标准逐各通道上计算的,每个通道上有一个均值和方差。在NHW
上进行计算。
batch = nn.BatchNorm3d(2, affine=False)
t = torch.randint(0, 3, (2,2,3,3,3))
t = t.float()
print(batch(t))
tmp = t[:,0,:,:,:]
print(tmp.mean())
print(tmp.var().sqrt())
print((tmp-tmp.mean())/(tmp.var(unbiased=False).sqrt()+1e-5))
参考资料
- 1.torch.nn.BatchNorm
- 2.详解pytorch中nn模块的BatchNorm2d()函数
欢迎访问个人网络日志🌹🌹知行空间🌹🌹
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)