如何使用TensorFlow tf.train.string_input_producer生成几个纪元数据?

如何使用TensorFlow tf.train.string_input_producer生成几个纪元数据?,第1张

如何使用TensorFlow tf.train.string_input_producer生成几个纪元数据?

正如Nicolas观察到的那样,

tf.train.string_input_producer()
API不能让您检测到何时到达时代的尽头。而是将所有纪元串联在一起,组成一个较长的批处理。因此,我们最近(在TensorFlow1.2中)添加了
tf.contrib.data
API,该API使得可以表达更复杂的管道,包括您的用例。

以下代码段显示了如何使用编写程序

tf.contrib.data

import tensorflow as tfdef input_pipeline(filenames, batch_size):    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data.    dataset = (tf.contrib.data.TextLineDataset(filenames)    .map(lambda line: tf.depre_csv(         line, record_defaults=[['1'], ['1'], ['1']], field_delim='-'))    .shuffle(buffer_size=10)  # Equivalent to min_after_dequeue=10.    .batch(batch_size))    # Return an *initializable* iterator over the dataset, which will allow us to    # re-initialize it at the beginning of each epoch.    return dataset.make_initializable_iterator()filenames=['1.txt']batch_size = 3num_epochs = 10iterator = input_pipeline(filenames, batch_size)# `a1`, `a2`, and `a3` represent the next element to be retrieved from the iterator.    a1, a2, a3 = iterator.get_next()with tf.Session() as sess:    for _ in range(num_epochs):        # Resets the iterator at the beginning of an epoch.        sess.run(iterator.initializer)        try: while True:     a, b, c = sess.run([a1, a2, a3])     print(a, b, c)        except tf.errors.OutOfRangeError: # This will be raised when you reach the end of an epoch (i.e. the # iterator has no more elements). pass        # Perform any end-of-epoch computation here.        print('Done training, epoch reached')


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5630640.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-15
下一篇 2022-12-15

发表评论

登录后才能评论

评论列表(0条)

保存