torch.gather

torch.gather,第1张

torch.gather(input, dim, index, out=None) → Tensor:其实就是用来挑选输入张量中特定位置的值的

但有一点需要注意的是这里的dim和index之间的关系:这里的dim不再是说对输入进行 *** 作,而是说index中的序列号应该按照输入中的哪个维度来排列

重新理解Pytorch中torch.gather函数祥解 - 简书 (jianshu.com)

eg1:(例子参考自:Pytorch中torch.gather函数祥解 - 简书 (jianshu.com))

这里应该这么理解,dim = 0,所以index应该按照input的行来编号,也就是说input中按照dim = 0来分的元素是[[1,2]],[[3,4]],那么index中的序号也应该是[[a,b]]的形式,其中a,b其实取决于有多少行,也就是说这里其实是按列来取值,[a,b]分别表示,第0列取第a行的数,第1列取第b行的数字。


 eg2:

 此时dim = 1,所以是从输入的dim = 1来划分,划分结果为两列,分别为[[1],[3]],[[2],[4]],所以此时index是对照列来排序,所以index应该是[[a],[b]]的形式,index中的[a]表示的是第0行取第a列的值,也就是1,而[b]表示第1行的第b列值,也就是4.

eg3:

 这里dim = 0,所以划分为两行[1,2],[3,4],则index中[0,0]中第一个0表示的是第0列取第0行的数字1,第二个0个表示第二列取第0行的值2.。


[1,0]则表示第1列取第1行的数字3,第2列取第0行的数字2。


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/571685.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存