Pytorch基础 *** 作 —— 7. torch.cat 拼接 *** 作

Pytorch基础 *** 作 —— 7. torch.cat 拼接 *** 作,第1张

Pytorch基础 *** 作 —— 7. torch.cat 拼接 *** 作

文章目录
  • torch.cat
  • 例程
    • 低维度时的拼接
    • 高维度时的拼接
  • 验证

torch.cat

张量拼接是非常常见的 *** 作,以OpenCV为例,有时候我们需要把彩色图片(通常为3通道数据)分别进行处理,然后再重新组合在一起,生成新的图片。对于类似的框架来说,也提供了类似的函数。

现在,让我们来看看Torch的张量拼接函数的原型:

    torch.cat(tensors, dim=0, *, out=None) -> Tensor
  • tensors,通常是一组张量,要求大小维度相同,否则会导致拼接失败。
  • dim,是拼接的方向,默认是0.
例程 低维度时的拼接

这个函数本质上并不难理解,但是唯一比较麻烦的就是dim,也就是轴方向,这是来自Numpy的概念,我觉得对于这个概念最好的理解,还是直接看源码最好。

>>> import torch
>>> ones = torch.ones(4, 5)
>>> zeros = torch.zeros(4, 5)

>>> torch.cat((ones, zeros), dim=0)
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
        
>>> torch.cat((ones, zeros), dim=1)
tensor([[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]])
        
>>> torch.cat((ones, zeros), dim=2)
Traceback (most recent call last):
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)        

可以看到,对于低维度的拼接时,torch.cat 仅支持把两个张量按列、行方向进行拼接。那么对于高维度的拼接,比如三维度时,又是怎样的呢?

高维度时的拼接
>>> ones = torch.ones(2, 3, 4, 5)
>>> zeros = torch.zeros(2, 3, 4, 5)

>>> cat1 = torch.cat((ones, zeros), dim=0)
>>> cat1.shape
torch.Size([4, 3, 4, 5])

>>> cat2 = torch.cat((ones, zeros), dim=1)
>>> cat2.shape
torch.Size([2, 6, 4, 5])

>>> cat3 = torch.cat((ones, zeros), dim=2)
>>> cat3.shape
torch.Size([2, 3, 8, 5])

>>> cat4 = torch.cat((ones, zeros), dim=3)
>>> cat4.shape
torch.Size([2, 3, 4, 10])

从这个例子可以很明显的看出,cat *** 作的dim,是依据张量围度从左往右进行计算的。所以很多网上所说的dim=0是沿着X轴、dim=1是沿着Y轴,dim=2是沿着Z轴这种说法是十分不准确的。

所以更准确的说法应该是:dim,指定 *** 作沿着张量的第N位执行指令。 对于上面这个例子来说,执行cat *** 作时,dim=0,即指定以张量维度第0位,执行拼接 *** 作。

验证

为了证明结果,这里执行一个小程序片段,我们分别打印 cat1 和 cat4,看看同样的执行顺序会分别输出什么内容

dim0, dim1, dim2, dim3 = cat1.shape
for i in range(dim0):
    print("i:", i, end=" ")
    for j in range(dim1):
        print("j:", j)
        for k in range(dim2):
            for l in range(dim3):
                print(cat1[i, j, k, l].item(), end=" ")
            print("")
        print("")

print("-----------------------------")

dim0, dim1, dim2, dim3 = cat4.shape
for i in range(dim0):
    print("i:", i, end=" ")
    for j in range(dim1):
        print("j:", j)
        for k in range(dim2):
            for l in range(dim3):
                print(cat4[i, j, k, l].item(), end=" ")
            print("")
        print("")

i: 0 j: 0
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

j: 1
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

j: 2
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

i: 1 j: 0
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

j: 1
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

j: 2
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

i: 2 j: 0
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

j: 1
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

j: 2
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

i: 3 j: 0
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

j: 1
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

j: 2
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

-----------------------------
i: 0 j: 0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

j: 1
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

j: 2
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

i: 1 j: 0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

j: 1
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

j: 2
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/3973771.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-10-21
下一篇 2022-10-21

发表评论

登录后才能评论

评论列表(0条)

保存