python-不同shape的列表数组的赋值 *** 作,最终会导致被赋值的list被截断,shape改变;什么是广播机制、broadcast机制?原理?什么时候会触发?

python-不同shape的列表数组的赋值 *** 作,最终会导致被赋值的list被截断,shape改变;什么是广播机制、broadcast机制?原理?什么时候会触发?,第1张

一、问题背景

最近,在博客园发现一篇利用预训练网络训练CNN分类模型的文章,在代码中发现了一个问题。

如果sample_count的数目不等于batch_size 的整数倍,那么遍历到数据末尾时,必然会导致feature_batch的长度小于batch_size,也即下面图中划线的部分中等号前后的len不一样。

所以,我在这里滋生了一个疑惑,当python的内置数据类型list,或者numpy中的ndarray类型数据,在赋值时也出现这种情况,会怎么输出?

二、实验过程

2.1 python内置数据类型list
list1 = [1,2,3,4,5]
list2 = [6,7]

list1[1:5]=list2

print(list1)
# [1, 6, 7] 正常输出;规律是从最小索引开始赋值,能赋多少赋多少,后面没办法赋值的元素自动去除

2.2 ndarray类型数据(无法触发广播)
import numpy as np

arr1 = np.array([[1,2,3,4,5]])

arr2 = np.array([6,7])

arr1[1:5]=arr2

print(arr1)  # ValueError: could not broadcast input array from shape (2,) into shape (4,)

2.3 ndarray类型数据(可触发广播)
import numpy as np

arr1 = np.array([[1,2,3,4,5],[6,7,8,9,10]])

arr2 = np.array([[1,2,3,6,7]])

arr1[0:2]=arr2

print(arr1)  

# [[1 2 3 6 7]
#  [1 2 3 6 7]]
# 正常输出,说明触发broadcast机制(本文的第三部分有讲到)

而明显,本文的第一部分中引用的代码,明显属于ndarray数组类型的数据。

能不能实现广播机制呢?其实是不能的,因为从后向前数,第0维上的shape不相等,并且features_batch的shape一旦不是(1,4,4,512),就无法触发广播机制,所以我觉得作者的代码存在问题。

(sorry,对这个作者说句抱歉!大概是因为generator这个函数只能输出batch_size形状的数据,所以“大概”是不存在这个问题的,但是问题在于文件夹内剩余的没有batch_size的数据呢?完了又有新的疑问......可能会遗弃剩余的吧,不然会出错呢!) 

2.4 多维list不等shape赋值规律(永远无法触发广播机制)
list1 = [[1,2,3,4,5],[6,7,8,9,10],[11,12,22,323,121]]
list2 = [[[1,2,3]]]

list1[0:2]=list2

print(list1)  # [[[1, 2, 3]], [11, 12, 22, 323, 121]]

可以看出规律:list1[0:2]切片相当于取出了[1,2,3,4,5],[6,7,8,9,10]这个部分, 而list2在赋值表达式中相当于[[1,2,3]],所以把后者替换到前者,就可以得到结果[[[1, 2, 3]], [11, 12, 22, 323, 121]]了。

很明显,list不存在广播机制,不然这种赋值 *** 作是不允许存在的。

因为广播机制的前提是输入输出的容器中各个维度上的shape一定要相等。

但是上面的list1被赋值之后,第1个维度下的两个列表shape分别是1*3和5;说明list1不存在这个机制。

也正是因为numpy.ndarray的广播机制存在,所以在将list转成array的时候,一般是不建议将各维度下不等shape的list转成数组的。

如下所示:

np.array([[[1, 2, 3]], [11, 12, 22, 323, 121]])

# 警告:C:\Users\Administrator\AppData\Local\Temp\ipykernel_1072844013573.py:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
# 输出:np.array([[[1, 2, 3]], [11, 12, 22, 323, 121]])

2.5 ndarray类型数据(没必要触发广播)

因为等号左右两边shape相同,所以没必要广播。

import numpy as np

arr1 = np.array([[1,2,3,4,5],[6,7,8,9,10],[213,214,324,235,44]])

arr2 = np.array([[1,2,3,6,7],[1231,2142,2412,221,214]])

arr1[0:2]=arr2

print(arr1[0:2].shape)
print(arr2.shape)
print(arr1)  

# (2, 5)
# (2, 5)
# [[   1    2    3    6    7]
#  [1231 2142 2412  221  214]
#  [ 213  214  324  235   44]]

三、什么是广播机制?原理?什么时候会触发?

3.1 当两者中有一个标量时,也即shape=(n, ),必触发
A = np.zeros((2,5,3,4))
B = np.ones((1))

print((A+B).shape)  # (2, 5, 3, 4)

3.2 当两者的维度数目相同,从后向前数各个维度上的shape相同或者有一方为1时,触发
A = np.zeros((2,5,3,4))
B = np.ones((1,5,1,4))

print((A+B).shape)  # (2, 5, 3, 4)

3.3 当两者的维度数目不同,一方少于另一方,从后向前数都有的维度上的shape相同或者有一方为1时,也能触发
A = np.zeros((2,5,3,4))
B = np.ones((8,1,5,1,4))

print((A+B).shape)  # (8, 2, 5, 3, 4)

3.4 当两者的维度数目相同,从后向前数,有一个维度上的shape不同并且不为1时,出错
A = np.zeros((2,5,3,4))
B = np.ones((2,5,3,2))

print((A+B).shape)  # ValueError: operands could not be broadcast together with shapes (2,5,3,4) (2,5,3,2) 

原因在于,从后向前数,前者的第一个数是4,后者是2,两者不同,并且其中没有一个是1,所以无法广播,因此出错。 

A = np.zeros((2,5,3,4))
B = np.ones((2,5,3,1))

print((A+B).shape)  # (2, 5, 3, 4)

 从后向前数,前者的第一个数是4,后者是1,两者不同,但是后者是1,所以可广播,因此没出事。 

3.5 最后对广播机制进行一个总结

从3.3这个实验结果,最容易得到一个广泛适用的结论。

  • 从后向前数,逐一对比shape上的各个数值
  • 如果数值相等,或者即使不相等但是有一方为1,就可以继续向前对比
  • 一直对比到有一方的数值被用完,才结束
  • 如果各个数值的对比结果符合条件,就符合“广播兼容性”,可以广播!

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/714976.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-25
下一篇 2022-04-25

发表评论

登录后才能评论

评论列表(0条)

保存