PyTorch的合并与分割

PyTorch的合并与分割,第1张

PyTorch的合并与分割 一、合并 1. cat 函数
    规则:
      所合并的数据的dim一致要合并的维度上shape可以不一致,其余的shape必须一致理解:[class,student]合并=>[class,student],合并后的班级在含以上相同
    例子
a = torch.rand(4,3,16,32)
b = torch.rand(4,3,16,32)
# 第一参数,list,第二参数,再哪个维度合并
print(torch.cat([a,b],dim=2).shape)
# out: torch.Size([4, 3, 32, 32])
2.stack函数
    规则:
      要合并的两个维度必须一致会在合并的维度前插入一个新的维度理解:[class,student]合并=>[class_id,class,student],相当于合并后两个班级分开,意义上不同,比如dim=0维度上的班级是尖子班,dim=1维度上的班级是普通班。
    例子
print(' stack ')
print(torch.stack([a,b],dim=2).shape) # torch.Size([4, 3, 2, 16, 32]) ,在dim2插入一个新的维度

a = torch.rand(3,5)
b = torch.rand(3,5)
print(torch.stack([a,b],dim=0).shape)
# out : torch.Size([2, 3, 5])
二、分割 1. split函数
    规则:
      给定在某一维度拆分后长度给定在某一维度拆分后的每个长度
    例子
a = torch.rand(32,8)
b = torch.rand(32,8)
c = torch.stack([a,b],dim=0)
aa,bb = c.split([1,1],dim=0) # 给定要拆分的dim=0的shape
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
# aa.shape= torch.Size([1, 32, 8])  bb.shape= torch.Size([1, 32, 8])
aa,bb = c.split(1,dim=0) #在0维度差分成长度为1
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
# aa.shape= torch.Size([1, 32, 8])  bb.shape= torch.Size([1, 32, 8])

c = torch.rand(9,4)
aa,bb,cc = c.split(3,dim=0) # 拆分成3个size[3,4]
print(aa.shape,bb.shape,cc.shape)
# torch.Size([3, 4]) torch.Size([3, 4]) torch.Size([3, 4])
aa,bb,cc = c.split([2,3,4],dim=0)
print(aa.shape,bb.shape,cc.shape)
# torch.Size([2, 4]) torch.Size([3, 4]) torch.Size([4, 4])
2. chuck函数
    规则
      参数:要分割的数量(不足的向上取整),和维度
    例子
a = torch.rand(4,5,2)
aa,bb = a.chunk(2,dim=0)
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
# aa.shape= torch.Size([2, 5, 2])  bb.shape= torch.Size([2, 5, 2])
aa,bb = a.chunk(2,dim=1)
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
# aa.shape= torch.Size([4, 3, 2])  bb.shape= torch.Size([4, 2, 2])

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5720396.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-18
下一篇 2022-12-18

发表评论

登录后才能评论

评论列表(0条)

保存