如何正确组合TensorFlow的数据集API和Keras?

如何正确组合TensorFlow的数据集API和Keras?,第1张

如何正确组合TensorFlow的数据集API和Keras?

确实有一种更有效的使用方法,Dataset而不必将张量转换为numpy数组。但是,官方文档上还没有(尚未?)。在发行说明中,它是Keras 2.0.7中引入的功能。您可能需要安装keras> = 2.0.7才能使用它。

x = np.arange(4).reshape(-1, 1).astype('float32')ds_x = Dataset.from_tensor_slices(x).repeat().batch(4)it_x = ds_x.make_one_shot_iterator()y = np.arange(5, 9).reshape(-1, 1).astype('float32')ds_y = Dataset.from_tensor_slices(y).repeat().batch(4)it_y = ds_y.make_one_shot_iterator()input_vals = Input(tensor=it_x.get_next())output = Dense(1, activation='relu')(input_vals)model = Model(inputs=input_vals, outputs=output)model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()])model.fit(steps_per_epoch=1, epochs=5, verbose=2)

几个区别:

  1. tensor
    参数提供给该Input层。Keras将从该张量读取值,并将其用作拟合模型的输入。
  2. target_tensors
    论据提供给
    Model.compile()
  3. 请记住将x和y都转换为float32。在正常使用情况下,Keras将为您完成此转换。但是现在您必须自己做。
  4. 批处理大小是在构建时指定的Dataset。使用
    steps_per_epoch
    和epochs控制何时停止模型拟合。
    总之,使用
    Input(tensor=...)
    model.compile(target_tensors=...)
    并且
    model.fit(x=None, y=None, ...)
    如果你的数据是从张量读取。


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5617298.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-15
下一篇 2022-12-15

发表评论

登录后才能评论

评论列表(0条)

保存