tf.data.Dataset函数

tf.data.Dataset函数,第1张

函数原型
tf.data.Dataset(
    variant_tensor
)
函数说明

通常使用tf.data.Dataset.from_tensor_slices函数来创建一个Dataset对象,Dataset用于表示一个数据集,这个数据集是可以迭代的。from_tensor_slices函数的原型为:from_tensor_slices(tensors, name=None),该函数的作用是将tensor张量在第一个维度上切分并转换成Dataset对象,包含多个元素。

还有个类似的函数from_tensors,函数原型为:from_tensors(tensors, name=None),该函数的作用仅仅就是将tensor张量转换成Dataset对象,只包含一个元素。这两个方法都是静态方法。

还有一些常用的方法:
as_numpy_iterator():将Dataset对象转换成一个iter迭代器,由于Dataset是可以迭代的,但是不能查看数据,所以可以通过该方法查看具体元素的数据。

batch(batch_size, drop_remainder=False, num_parallel_calls=None, deterministic=None, name=None):将此数据集的batch_size个元素组合成一个元素。比如,组合前的数据形状为(8,8),有4个数据元素,使用batch(2)后,数据形状就会变成(2, 8, 8),只有两个数据元素。

padded_batch(batch_size, padded_shapes=None, padding_values=None, drop_remainder=False, name=None):在batch函数的基础上增加一个填充的功能,目的就是让数据对齐。

unbatch(name=None):batch函数的反向过程。

shard(num_shards, index, name=None):将Dataset进行分片。

shuffle(buffer_size, seed=None, reshuffle_each_iteration=None, name=None):随机的打乱数据集中的元素。reshuffle_each_iteration控制每个时期的打乱顺序是否应该不同。创建 epoch 的惯用方式是通过repeat转换。

repeat(count=None, name=None):重复dataset count次。

prefetch(buffer_size, name=None):buffer_size表示从Dataset中预取的元素数量。函数在处理当前元素的同时准备后面的元素。这通常会提高延迟和吞吐量,但代价是使用额外的内存来存储预取的元素。

cache(filename=‘’, name=None):第一次迭代数据集时,其元素将缓存在指定文件或内存中。随后的迭代将使用缓存的数据。

concatenate(dataset, name=None):连接两个Dataset,形状要一致。

skip(count, name=None):跳过Dataset的前count个元素。也就是获取后面的元素。

take(count, name=None):获取Dataset的前count个元素。

take_while(predicate, name=None):根据条件predicate过滤点Dataset中的数据。

filter(predicate, name=None):根据条件predicate过滤点Dataset中的数据。

map(map_func, num_parallel_calls=None, deterministic=None, name=None):对于Dataset中的每一个元素都应用map_func。

函数使用

例一:from_tensor_slices()和from_tensors()区别

# from_tensor_slices()
>>> d = tf.reshape(tf.range(15), (5, 3))
>>> d

>>> data = tf.data.Dataset.from_tensor_slices(d)
>>> data

>>> for f in data:
	f






>>> list(data.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8]), array([ 9, 10, 11]), array([12, 13, 14])]

# from_tensors()
>>> d = tf.reshape(tf.range(15), (5, 3))
>>> data = tf.data.Dataset.from_tensors(d)
>>> data

>>> for f in data:
	f



例二:使用batch、padded_batch、unbatch、shard 合并拆分元素

# batch函数,用于处理Dataset中长度一致的元素
>>> a = tf.reshape(tf.range(24), (8, 3))
>>> a

>>> dataset = tf.data.Dataset.from_tensor_slices(a)
>>> list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8]), array([ 9, 10, 11]), array([12, 13, 14]), array([15, 16, 17]), array([18, 19, 20]), array([21, 22, 23])]
>>> dataset_a = dataset.batch(3)
>>> for data in dataset_a:
	data




# 参数drop_remainder作用就是去掉最后小于batch_size的元素
>>> dataset_b = dataset.batch(3, drop_remainder=True)
>>> for data in dataset_b:
	data





# padded_batch函数,用于处理长度不一的元素
>>> a = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
>>> dataset = tf.data.Dataset.from_generator(lambda :iter(a), tf.int32)
>>> list(dataset.as_numpy_iterator())
[array([1, 2, 3]), array([4, 5]), array([6, 7, 8, 9])]
>>> for data in dataset:
	data




# padded_shapes指定填充的长度
>>> dataset_a = dataset.padded_batch(2, padded_shapes=4)
>>> for data in dataset_a:
	data

	


# padding_values指定填充值
>>> dataset_b = dataset.padded_batch(2, padded_shapes=4, padding_values=-1)
>>> for data in dataset_b:
	data

	



# padded_batch处理(train_data, train_labels)元组数据的情况
>>> a = [([1, 2, 3], [10, 11]),
     ([5, 6], [12, 13, 14]),
     ([7], [15, 16])]
>>> dataset = tf.data.Dataset.from_generator(lambda: iter(elements), (tf.int32, tf.int32))
>>> for data in dataset:
	data

(, )
(, )
>>> dataset = dataset.padded_batch(2,padded_shapes=([4], [5]),padding_values=(-1, 0))
>>> for data in dataset:
	data

(, >> a = tf.ones((3, 5))
>>> a

>>> dataset = tf.data.Dataset.from_tensor_slices(a)
>>> dataset_a = dataset.batch(2)
>>> for data in dataset_a:
	data

	


>>> dataset_b = dataset_a.unbatch()
>>> for data in dataset_b:
	data





# shard函数,获取Dataset的子集,按1/num_shards划分
>>> a = tf.reshape(tf.range(24), (8, 3))
>>> a

>>> dataset = tf.data.Dataset.from_tensor_slices(a)
>>> dataset_a = dataset.shard(3, 0)
>>> for data in dataset_a:
	data




>>> dataset_b = dataset.shard(3, 1)
>>> for data in dataset_b:
	data




例3:skip、take、take_while、filter、map函数


>>> a = tf.reshape(tf.range(12), (6, 2))
>>> a

>>> dataset = tf.data.Dataset.from_tensor_slices(a)
>>> for data in dataset:
	data







# skip函数
>>> dataset_a = dataset.skip(2)
>>> for data in dataset_a:
	data





# take函数
>>> dataset_b = dataset.take(4)
>>> for data in dataset_b:
	data





# 
>>> a = [[1], [2, 3], [4, 5, 6], [7, 8, 9, 10]]
>>> a
[[1], [2, 3], [4, 5, 6], [7, 8, 9, 10]]
>>> dataset = tf.data.Dataset.from_generator(lambda :iter(a), output_types=tf.int32)
>>> for data in dataset:
	data





# take_while函数,过滤掉长度大于3的元素
>>> dataset_a = dataset.take_while(lambda x: tf.shape(x)[0]<3)
>>> for data in dataset_a:
	data



# filter函数,与take_while函数功能相同
>>> dataset_a = dataset.filter(lambda x: tf.shape(x)[0]<3)
>>> for data in dataset_a:
	data



# map函数,将所有元素加1
dataset_b = dataset.map(lambda x: x+1)
>>> for data in dataset_b:
	data





例4:shuffle、repeat函数

>>> a = tf.reshape(tf.range(24), (8, 3))
>>> a

>>> dataset = tf.data.Dataset.from_tensor_slices(a)
>>> for data in dataset:
	data










# 没有设置参数reshuffle_each_iteration,或者reshuffle_each_iteration=True
>>> dataset_a = dataset.shuffle(4)
>>> for data in dataset_a:
	data









>>> dataset_b = dataset_a.repeat(1)
>>> for data in dataset_b:
	data









# 参数reshuffle_each_iteration=False
>>> dataset_a = dataset.shuffle(4, reshuffle_each_iteration=False)
>>> for data in dataset_a:
	data









>>> dataset_b = dataset_a.repeat(1)
>>> for data in dataset_b:
	data









欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/904477.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-15
下一篇 2022-05-15

发表评论

登录后才能评论

评论列表(0条)

保存