如何在Tensorflow中使用自定义python函数预取数据

如何在Tensorflow中使用自定义python函数预取数据,第1张

如何在Tensorflow中使用自定义python函数预取数据

这是一个常见的用例,大多数实现都使用TensorFlow的 队列
将预处理代码与训练代码分离。有一个有关如何使用队列的教程,但是主要步骤如下:

  1. 定义一个队列

    q
    它将缓冲预处理的数据。TensorFlow支持以
    tf.FIFOQueue
    排队顺序生成元素的简单方法,以及
    tf.RandomShuffleQueue
    以随机顺序生成元素的高级方法。队列元素是一个或多个张量(可以具有不同的类型和形状)的元组。所有队列都支持单元素(
    enqueue
    dequeue
    )和批处理(
    enqueue_many
    dequeue_many
    ) *** 作,但是要使用批处理 *** 作,必须在构造队列时在队列元素中指定每个张量的形状。

  2. 构建一个子图,将预处理的元素排入队列。一种方法是

    tf.placeholder()
    为张量定义一些与单个输入示例对应的 *** 作,然后将它们传递给
    q.enqueue()
    。(如果预处理一次生成一批,则应
    q.enqueue_many()
    改用。)您也可以在此子图中包括TensorFlow op。

  3. 建立执行训练的子图。这看起来像一个普通的TensorFlow图,但是会通过调用来获取其输入

    q.dequeue_many(BATCH_SIZE)

  4. 开始会话。

  5. 创建一个或多个执行预处理逻辑的线程,然后执行入队 *** 作,并输入预处理后的数据。您可能会发现

    tf.train.Coordinator
    tf.train.QueueRunner
    实用程序类对此有用。

  6. 正常运行训练图(优化器等)。

编辑: 这是一个简单的

load_and_enqueue()
功能和代码片段,可以帮助您入门:

    # Features are length-100 vectors of floats    feature_input = tf.placeholder(tf.float32, shape=[100])    # Labels are scalar integers.    label_input = tf.placeholder(tf.int32, shape=[])    # Alternatively, could do:    # feature_batch_input = tf.placeholder(tf.float32, shape=[None, 100])    # label_batch_input = tf.placeholder(tf.int32, shape=[None])    q = tf.FIFOQueue(100, [tf.float32, tf.int32], shapes=[[100], []])    enqueue_op = q.enqueue([feature_input, label_input])    # For batch input, do:    # enqueue_op = q.enqueue_many([feature_batch_input, label_batch_input])    feature_batch, label_batch = q.dequeue_many(BATCH_SIZE)    # Build rest of model taking label_batch, feature_batch as input.    # [...]    train_op = ...    sess = tf.Session()    def load_and_enqueue():      with open(...) as feature_file, open(...) as label_file:        while True:          feature_array = numpy.fromfile(feature_file, numpy.float32, 100)          if not feature_array: return          label_value = numpy.fromfile(feature_file, numpy.int32, 1)[0]          sess.run(enqueue_op, feed_dict={feature_input: feature_array,         label_input: label_value})    # Start a thread to enqueue data asynchronously, and hide I/O latency.    t = threading.Thread(target=load_and_enqueue)    t.start()    for _ in range(TRAINING_EPOCHS):      sess.run(train_op)


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5644342.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存