tf.data.Dataset(
variant_tensor
)
函数说明
通常使用tf.data.Dataset.from_tensor_slices函数来创建一个Dataset对象,Dataset用于表示一个数据集,这个数据集是可以迭代的。from_tensor_slices函数的原型为:from_tensor_slices(tensors, name=None),该函数的作用是将tensor张量在第一个维度上切分并转换成Dataset对象,包含多个元素。
还有个类似的函数from_tensors,函数原型为:from_tensors(tensors, name=None),该函数的作用仅仅就是将tensor张量转换成Dataset对象,只包含一个元素。这两个方法都是静态方法。
还有一些常用的方法:
as_numpy_iterator():将Dataset对象转换成一个iter迭代器,由于Dataset是可以迭代的,但是不能查看数据,所以可以通过该方法查看具体元素的数据。
batch(batch_size, drop_remainder=False, num_parallel_calls=None, deterministic=None, name=None):将此数据集的batch_size个元素组合成一个元素。比如,组合前的数据形状为(8,8),有4个数据元素,使用batch(2)后,数据形状就会变成(2, 8, 8),只有两个数据元素。
padded_batch(batch_size, padded_shapes=None, padding_values=None, drop_remainder=False, name=None):在batch函数的基础上增加一个填充的功能,目的就是让数据对齐。
unbatch(name=None):batch函数的反向过程。
shard(num_shards, index, name=None):将Dataset进行分片。
shuffle(buffer_size, seed=None, reshuffle_each_iteration=None, name=None):随机的打乱数据集中的元素。reshuffle_each_iteration控制每个时期的打乱顺序是否应该不同。创建 epoch 的惯用方式是通过repeat转换。
repeat(count=None, name=None):重复dataset count次。
prefetch(buffer_size, name=None):buffer_size表示从Dataset中预取的元素数量。函数在处理当前元素的同时准备后面的元素。这通常会提高延迟和吞吐量,但代价是使用额外的内存来存储预取的元素。
cache(filename=‘’, name=None):第一次迭代数据集时,其元素将缓存在指定文件或内存中。随后的迭代将使用缓存的数据。
concatenate(dataset, name=None):连接两个Dataset,形状要一致。
skip(count, name=None):跳过Dataset的前count个元素。也就是获取后面的元素。
take(count, name=None):获取Dataset的前count个元素。
take_while(predicate, name=None):根据条件predicate过滤点Dataset中的数据。
filter(predicate, name=None):根据条件predicate过滤点Dataset中的数据。
map(map_func, num_parallel_calls=None, deterministic=None, name=None):对于Dataset中的每一个元素都应用map_func。
函数使用例一:from_tensor_slices()和from_tensors()区别
# from_tensor_slices()
>>> d = tf.reshape(tf.range(15), (5, 3))
>>> d
>>> data = tf.data.Dataset.from_tensor_slices(d)
>>> data
>>> for f in data:
f
>>> list(data.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8]), array([ 9, 10, 11]), array([12, 13, 14])]
# from_tensors()
>>> d = tf.reshape(tf.range(15), (5, 3))
>>> data = tf.data.Dataset.from_tensors(d)
>>> data
>>> for f in data:
f
例二:使用batch、padded_batch、unbatch、shard 合并拆分元素
# batch函数,用于处理Dataset中长度一致的元素
>>> a = tf.reshape(tf.range(24), (8, 3))
>>> a
>>> dataset = tf.data.Dataset.from_tensor_slices(a)
>>> list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8]), array([ 9, 10, 11]), array([12, 13, 14]), array([15, 16, 17]), array([18, 19, 20]), array([21, 22, 23])]
>>> dataset_a = dataset.batch(3)
>>> for data in dataset_a:
data
# 参数drop_remainder作用就是去掉最后小于batch_size的元素
>>> dataset_b = dataset.batch(3, drop_remainder=True)
>>> for data in dataset_b:
data
# padded_batch函数,用于处理长度不一的元素
>>> a = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
>>> dataset = tf.data.Dataset.from_generator(lambda :iter(a), tf.int32)
>>> list(dataset.as_numpy_iterator())
[array([1, 2, 3]), array([4, 5]), array([6, 7, 8, 9])]
>>> for data in dataset:
data
# padded_shapes指定填充的长度
>>> dataset_a = dataset.padded_batch(2, padded_shapes=4)
>>> for data in dataset_a:
data
# padding_values指定填充值
>>> dataset_b = dataset.padded_batch(2, padded_shapes=4, padding_values=-1)
>>> for data in dataset_b:
data
# padded_batch处理(train_data, train_labels)元组数据的情况
>>> a = [([1, 2, 3], [10, 11]),
([5, 6], [12, 13, 14]),
([7], [15, 16])]
>>> dataset = tf.data.Dataset.from_generator(lambda: iter(elements), (tf.int32, tf.int32))
>>> for data in dataset:
data
(, )
(, )
>>> dataset = dataset.padded_batch(2,padded_shapes=([4], [5]),padding_values=(-1, 0))
>>> for data in dataset:
data
(, >> a = tf.ones((3, 5))
>>> a
>>> dataset = tf.data.Dataset.from_tensor_slices(a)
>>> dataset_a = dataset.batch(2)
>>> for data in dataset_a:
data
>>> dataset_b = dataset_a.unbatch()
>>> for data in dataset_b:
data
# shard函数,获取Dataset的子集,按1/num_shards划分
>>> a = tf.reshape(tf.range(24), (8, 3))
>>> a
>>> dataset = tf.data.Dataset.from_tensor_slices(a)
>>> dataset_a = dataset.shard(3, 0)
>>> for data in dataset_a:
data
>>> dataset_b = dataset.shard(3, 1)
>>> for data in dataset_b:
data
例3:skip、take、take_while、filter、map函数
>>> a = tf.reshape(tf.range(12), (6, 2))
>>> a
>>> dataset = tf.data.Dataset.from_tensor_slices(a)
>>> for data in dataset:
data
# skip函数
>>> dataset_a = dataset.skip(2)
>>> for data in dataset_a:
data
# take函数
>>> dataset_b = dataset.take(4)
>>> for data in dataset_b:
data
#
>>> a = [[1], [2, 3], [4, 5, 6], [7, 8, 9, 10]]
>>> a
[[1], [2, 3], [4, 5, 6], [7, 8, 9, 10]]
>>> dataset = tf.data.Dataset.from_generator(lambda :iter(a), output_types=tf.int32)
>>> for data in dataset:
data
# take_while函数,过滤掉长度大于3的元素
>>> dataset_a = dataset.take_while(lambda x: tf.shape(x)[0]<3)
>>> for data in dataset_a:
data
# filter函数,与take_while函数功能相同
>>> dataset_a = dataset.filter(lambda x: tf.shape(x)[0]<3)
>>> for data in dataset_a:
data
# map函数,将所有元素加1
dataset_b = dataset.map(lambda x: x+1)
>>> for data in dataset_b:
data
例4:shuffle、repeat函数
>>> a = tf.reshape(tf.range(24), (8, 3))
>>> a
>>> dataset = tf.data.Dataset.from_tensor_slices(a)
>>> for data in dataset:
data
# 没有设置参数reshuffle_each_iteration,或者reshuffle_each_iteration=True
>>> dataset_a = dataset.shuffle(4)
>>> for data in dataset_a:
data
>>> dataset_b = dataset_a.repeat(1)
>>> for data in dataset_b:
data
# 参数reshuffle_each_iteration=False
>>> dataset_a = dataset.shuffle(4, reshuffle_each_iteration=False)
>>> for data in dataset_a:
data
>>> dataset_b = dataset_a.repeat(1)
>>> for data in dataset_b:
data
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)