>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) >>> x tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) >>> torch.roll(x, 1, 0) tensor([[7, 8], [1, 2], [3, 4], [5, 6]]) >>> torch.roll(x, -1, 0) tensor([[3, 4], [5, 6], [7, 8], [1, 2]]) >>> torch.roll(x, shifts=(2, 1), dims=(0, 1)) tensor([[6, 5], [8, 7], [2, 1], [4, 3]])自己理解及示例
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]).view(4, 4) print(x) >>>tensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [13, 14, 15, 16]]) # dims=0, 滚动行;dims=1, 滚动列。shifts>0, 向上或向左滚动;shifts<0, 向下或向右滚动。 y = torch.roll(x, shifts=(2, 2), dims=(0, 1)) print(y) >>>tensor([[11, 12, 9, 10], [15, 16, 13, 14], [ 3, 4, 1, 2], [ 7, 8, 5, 6]]) 过程详解: (1)dims=(0, 1), shifts=(2, 2), 表示向上滚动行2次,向左滚动列2次。 (2)向上滚动行2次: tensor([[ 1, 2, 3, 4], tensor([[ 9, 10, 11, 12], [ 5, 6, 7, 8], [13, 14, 15, 16], [ 9, 10, 11, 12], ---> [ 1, 2, 3, 4], [13, 14, 15, 16]]) [ 5, 6, 7, 8]]) (3)向左滚动列2次: tensor([[ 9, 10, 11, 12], tensor([[11, 12, 9, 10], [13, 14, 15, 16], [15, 16, 13, 14], [ 1, 2, 3, 4], ---> [ 3, 4, 1, 2], [ 5, 6, 7, 8]]) [ 7, 8, 5, 6]]) # roll reverse,变回原样 z = torch.roll(y, shifts=(-2, -2), dims=(0, 1)) print(z) >>>tensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [13, 14, 15, 16]])
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)