正如Nicolas观察到的那样,
tf.train.string_input_producer()API不能让您检测到何时到达时代的尽头。而是将所有纪元串联在一起,组成一个较长的批处理。因此,我们最近(在TensorFlow1.2中)添加了
tf.contrib.dataAPI,该API使得可以表达更复杂的管道,包括您的用例。
以下代码段显示了如何使用编写程序
tf.contrib.data:
import tensorflow as tfdef input_pipeline(filenames, batch_size): # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data. dataset = (tf.contrib.data.TextLineDataset(filenames) .map(lambda line: tf.depre_csv( line, record_defaults=[['1'], ['1'], ['1']], field_delim='-')) .shuffle(buffer_size=10) # Equivalent to min_after_dequeue=10. .batch(batch_size)) # Return an *initializable* iterator over the dataset, which will allow us to # re-initialize it at the beginning of each epoch. return dataset.make_initializable_iterator()filenames=['1.txt']batch_size = 3num_epochs = 10iterator = input_pipeline(filenames, batch_size)# `a1`, `a2`, and `a3` represent the next element to be retrieved from the iterator. a1, a2, a3 = iterator.get_next()with tf.Session() as sess: for _ in range(num_epochs): # Resets the iterator at the beginning of an epoch. sess.run(iterator.initializer) try: while True: a, b, c = sess.run([a1, a2, a3]) print(a, b, c) except tf.errors.OutOfRangeError: # This will be raised when you reach the end of an epoch (i.e. the # iterator has no more elements). pass # Perform any end-of-epoch computation here. print('Done training, epoch reached')
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)