- torch.cat
- 例程
- 低维度时的拼接
- 高维度时的拼接
- 验证
张量拼接是非常常见的 *** 作,以OpenCV为例,有时候我们需要把彩色图片(通常为3通道数据)分别进行处理,然后再重新组合在一起,生成新的图片。对于类似的框架来说,也提供了类似的函数。
现在,让我们来看看Torch的张量拼接函数的原型:
torch.cat(tensors, dim=0, *, out=None) -> Tensor
- tensors,通常是一组张量,要求大小维度相同,否则会导致拼接失败。
- dim,是拼接的方向,默认是0.
这个函数本质上并不难理解,但是唯一比较麻烦的就是dim,也就是轴方向,这是来自Numpy的概念,我觉得对于这个概念最好的理解,还是直接看源码最好。
>>> import torch >>> ones = torch.ones(4, 5) >>> zeros = torch.zeros(4, 5) >>> torch.cat((ones, zeros), dim=0) tensor([[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]]) >>> torch.cat((ones, zeros), dim=1) tensor([[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]) >>> torch.cat((ones, zeros), dim=2) Traceback (most recent call last): IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
可以看到,对于低维度的拼接时,torch.cat 仅支持把两个张量按列、行方向进行拼接。那么对于高维度的拼接,比如三维度时,又是怎样的呢?
高维度时的拼接>>> ones = torch.ones(2, 3, 4, 5) >>> zeros = torch.zeros(2, 3, 4, 5) >>> cat1 = torch.cat((ones, zeros), dim=0) >>> cat1.shape torch.Size([4, 3, 4, 5]) >>> cat2 = torch.cat((ones, zeros), dim=1) >>> cat2.shape torch.Size([2, 6, 4, 5]) >>> cat3 = torch.cat((ones, zeros), dim=2) >>> cat3.shape torch.Size([2, 3, 8, 5]) >>> cat4 = torch.cat((ones, zeros), dim=3) >>> cat4.shape torch.Size([2, 3, 4, 10])
从这个例子可以很明显的看出,cat *** 作的dim,是依据张量围度从左往右进行计算的。所以很多网上所说的dim=0是沿着X轴、dim=1是沿着Y轴,dim=2是沿着Z轴这种说法是十分不准确的。
所以更准确的说法应该是:dim,指定 *** 作沿着张量的第N位执行指令。 对于上面这个例子来说,执行cat *** 作时,dim=0,即指定以张量维度第0位,执行拼接 *** 作。
验证为了证明结果,这里执行一个小程序片段,我们分别打印 cat1 和 cat4,看看同样的执行顺序会分别输出什么内容
dim0, dim1, dim2, dim3 = cat1.shape for i in range(dim0): print("i:", i, end=" ") for j in range(dim1): print("j:", j) for k in range(dim2): for l in range(dim3): print(cat1[i, j, k, l].item(), end=" ") print("") print("") print("-----------------------------") dim0, dim1, dim2, dim3 = cat4.shape for i in range(dim0): print("i:", i, end=" ") for j in range(dim1): print("j:", j) for k in range(dim2): for l in range(dim3): print(cat4[i, j, k, l].item(), end=" ") print("") print("") i: 0 j: 0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 j: 1 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 j: 2 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 i: 1 j: 0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 j: 1 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 j: 2 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 i: 2 j: 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 j: 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 j: 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 i: 3 j: 0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 j: 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 j: 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ----------------------------- i: 0 j: 0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 j: 1 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 j: 2 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 i: 1 j: 0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 j: 1 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 j: 2 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)