假定输入的数据为 A:=[0,1,2,3,4,5,6,7,8], buffer_sizer:=2.
A.shuffle(2)的工作思路:
1.取A的 前两个元素 {0,1} 从中随机选出 1个元素, 例如 1. 原先的数据变为 A=[0,2,3,4,5,6,7,8]
2. 重复1 至不能再重复
3. 将选出的元素按顺序写下.
上面的思路可以推论出 A 中第 i 个元素在 shuffle 后 只能出现在 第 i-1 个位置或更靠后.
dataset = tf.data.Dataset.from_tensor_slices([0, 1, 2,3,4, 5,6,7,8]) dataset1 = dataset.shuffle(2) dataset = dataset.batch(3) for elem in dataset1: print(elem) #输出为 tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(0, shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32) tf.Tensor(4, shape=(), dtype=int32) tf.Tensor(3, shape=(), dtype=int32) tf.Tensor(5, shape=(), dtype=int32) tf.Tensor(7, shape=(), dtype=int32) tf.Tensor(6, shape=(), dtype=int32) tf.Tensor(8, shape=(), dtype=int32)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)