torch.flatten与torch.nn.flatten的区别

torch.flatten与torch.nn.flatten的区别,第1张

torch.flatten:默认从第0维开始拉伸,变成一个向量.

x = torch.ones((2,3))
print(x)
print(torch.flatten(x))
# 结果
tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([1., 1., 1., 1., 1., 1.])

torch.nn.flatten:默认从第1维开始拉伸,变成一个矩阵.

(因为它默认你是在神经网络中使用,而神经网络中第0维表示的是batch_size,所以从第1维开始拉伸)

x = torch.ones(64, 1, 5, 5)

model = nn.Sequential(
    # same卷积
    nn.Conv2d(1, 6, 3, 1, 1),  # (64,6,5,5)
    nn.Flatten()  # (64,6*5*5)
)

output = model(x)
print(output.size())
# 结果
torch.Size([64, 150])

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/916739.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-16
下一篇 2022-05-16

发表评论

登录后才能评论

评论列表(0条)

保存