您可以使用
advancedindexing-
In [17]: aOut[17]: array([[[ 1, 2], [ 3, 4], [ 8, 6]], [[ 7, 8], [ 9, 8], [ 3, 12]]])In [18]: idx = a.argmax(axis=-1)In [19]: m,n = a.shape[:2]In [20]: a[np.arange(m)[:,None],np.arange(n),idx]Out[20]: array([[ 2, 4, 8], [ 8, 9, 12]])
对于任何数量的维度的通用ndarray情况,如中所述
comments by@hpaulj,我们可以使用
np.ix_,就像这样-
shp = np.array(a.shape)dim_idx = list(np.ix_(*[np.arange(i) for i in shp[:-1]]))dim_idx.append(idx)out = a[dim_idx]
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)