torch.flatten
:默认从第0维开始拉伸,变成一个向量
.
x = torch.ones((2,3))
print(x)
print(torch.flatten(x))
# 结果
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([1., 1., 1., 1., 1., 1.])
torch.nn.flatten
:默认从第1维开始拉伸,变成一个矩阵
.
(因为它默认你是在神经网络中使用,而神经网络中
第0维表示的是batch_size
,所以从第1维开始拉伸)
x = torch.ones(64, 1, 5, 5)
model = nn.Sequential(
# same卷积
nn.Conv2d(1, 6, 3, 1, 1), # (64,6,5,5)
nn.Flatten() # (64,6*5*5)
)
output = model(x)
print(output.size())
# 结果
torch.Size([64, 150])
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)