torch.gather(input, dim, index, out=None) → Tensor:其实就是用来挑选输入张量中特定位置的值的
但有一点需要注意的是这里的dim和index之间的关系:这里的dim不再是说对输入进行 *** 作,而是说index中的序列号应该按照输入中的哪个维度来排列
重新理解Pytorch中torch.gather函数祥解 - 简书 (jianshu.com)
eg1:(例子参考自:Pytorch中torch.gather函数祥解 - 简书 (jianshu.com))
这里应该这么理解,dim = 0,所以index应该按照input的行来编号,也就是说input中按照dim = 0来分的元素是[[1,2]],[[3,4]],那么index中的序号也应该是[[a,b]]的形式,其中a,b其实取决于有多少行,也就是说这里其实是按列来取值,[a,b]分别表示,第0列取第a行的数,第1列取第b行的数字。
eg2:
此时dim = 1,所以是从输入的dim = 1来划分,划分结果为两列,分别为[[1],[3]],[[2],[4]],所以此时index是对照列来排序,所以index应该是[[a],[b]]的形式,index中的[a]表示的是第0行取第a列的值,也就是1,而[b]表示第1行的第b列值,也就是4.
eg3:
这里dim = 0,所以划分为两行[1,2],[3,4],则index中[0,0]中第一个0表示的是第0列取第0行的数字1,第二个0个表示第二列取第0行的值2.。
[1,0]则表示第1列取第1行的数字3,第2列取第0行的数字2。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)