- 规则:
- 所合并的数据的dim一致要合并的维度上shape可以不一致,其余的shape必须一致理解:[class,student]合并=>[class,student],合并后的班级在含以上相同
a = torch.rand(4,3,16,32) b = torch.rand(4,3,16,32) # 第一参数,list,第二参数,再哪个维度合并 print(torch.cat([a,b],dim=2).shape) # out: torch.Size([4, 3, 32, 32])2.stack函数
- 规则:
- 要合并的两个维度必须一致会在合并的维度前插入一个新的维度理解:[class,student]合并=>[class_id,class,student],相当于合并后两个班级分开,意义上不同,比如dim=0维度上的班级是尖子班,dim=1维度上的班级是普通班。
print(' stack ') print(torch.stack([a,b],dim=2).shape) # torch.Size([4, 3, 2, 16, 32]) ,在dim2插入一个新的维度 a = torch.rand(3,5) b = torch.rand(3,5) print(torch.stack([a,b],dim=0).shape) # out : torch.Size([2, 3, 5])二、分割 1. split函数
- 规则:
- 给定在某一维度拆分后长度给定在某一维度拆分后的每个长度
a = torch.rand(32,8) b = torch.rand(32,8) c = torch.stack([a,b],dim=0) aa,bb = c.split([1,1],dim=0) # 给定要拆分的dim=0的shape print('aa.shape=',aa.shape,' bb.shape=',bb.shape) # aa.shape= torch.Size([1, 32, 8]) bb.shape= torch.Size([1, 32, 8]) aa,bb = c.split(1,dim=0) #在0维度差分成长度为1 print('aa.shape=',aa.shape,' bb.shape=',bb.shape) # aa.shape= torch.Size([1, 32, 8]) bb.shape= torch.Size([1, 32, 8]) c = torch.rand(9,4) aa,bb,cc = c.split(3,dim=0) # 拆分成3个size[3,4] print(aa.shape,bb.shape,cc.shape) # torch.Size([3, 4]) torch.Size([3, 4]) torch.Size([3, 4]) aa,bb,cc = c.split([2,3,4],dim=0) print(aa.shape,bb.shape,cc.shape) # torch.Size([2, 4]) torch.Size([3, 4]) torch.Size([4, 4])2. chuck函数
- 规则
- 参数:要分割的数量(不足的向上取整),和维度
a = torch.rand(4,5,2) aa,bb = a.chunk(2,dim=0) print('aa.shape=',aa.shape,' bb.shape=',bb.shape) # aa.shape= torch.Size([2, 5, 2]) bb.shape= torch.Size([2, 5, 2]) aa,bb = a.chunk(2,dim=1) print('aa.shape=',aa.shape,' bb.shape=',bb.shape) # aa.shape= torch.Size([4, 3, 2]) bb.shape= torch.Size([4, 2, 2])
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)