在python很多与numpy相关的库和深度学习框架中,经常会涉及到axis这个参数,也有很多博主根据自己的实验结果,给出了一个规律,但因为我们找到的规律太过复杂,不方记忆,也不方便使用。所以本文以np.stack()
函数为例,去讲解axis这个参数的官方解释。
numpy.stack(array, axis=0)
, stack堆叠的意思
为了方便理解,我们先注意该函数的返回值:
return
: 返回一个array数组,返回数组的维度比堆叠的单个数组维度多一维!且多的那一个维度为堆叠array的个数!例如要堆叠的数组是二维,shape为(5,4),要堆叠的数组个数为3个,那么返回的结果一定是三维,且结果一定是(3,5,4),(5,3,4),(5,4,3)中的一个。可以注意到结果为3插在不同的位置。这个位置也正是我们下面要说的axis参数的真正含义!
Parameters
:
array:一个array的集合,要求集合中的每一个数组的shape必须是相同的
axis: 该参数规定了在结果返回的array中,多的那一个维度(堆叠数组的数量),在结果数组中的哪一个维度。axis =0 表示在第一个维度,axis=-1表示在最后一个维度。
代码实例# 定义三个array
a = np.zeros([5,4],dtype=np.uint8)
b = np.zeros([5,4],dtype=np.uint8)
c = np.zeros([5,4],dtype=np.uint8)
x = np.stack((a,b,c),axis=0)
x.shape
>>> output
(3, 5, 4)
np.stack((a,b,c),axis=1).shape
>>> output
(5, 3, 4)
np.stack((a,b,c),axis=-1).shape
>>> output
(5, 4, 3)
记忆说明
博主之前也一直纠结于axis与返回结果之间的具体关系(怎么堆叠),但是根据写代码的过程中,我发现按照上述shape变换的规律就可以确定参数,以满足我们大部分的要求。而不需要我们去了解具体的堆叠方法。
相关网址numpy 官方网址np.stack()
其他博主根据结果找到的规律,如果有小伙伴依旧想找到其中的规律,可以看这个博主的文章。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)