在TensorFlow中使用预训练的单词嵌入(word2vec或Glove)

在TensorFlow中使用预训练的单词嵌入(word2vec或Glove),第1张

在TensorFlow中使用预训练的单词嵌入(word2vec或Glove)

您可以通过多种方式在TensorFlow中使用预训练的嵌入。假设您将NemPy数组嵌入到

embedding
具有
vocab_size
行和
embedding_dim
列的NumPy数组中,并且想要创建一个
W
可用于调用的张量
tf.nn.embedding_lookup()

  1. 只需创建
    W
    一个
    tf.constant()
    是需要
    embedding
    为它的价值:
    W = tf.constant(embedding, name="W")

这是最简单的方法,但是由于a的值

tf.constant()
多次存储在内存中,因此内存使用效率不高。由于
embedding
可能很大,因此只应将这种方法用于玩具示例。

  1. 创建

    W
    为a,
    tf.Variable
    并通过NumPy数组对其进行初始化
    tf.placeholder()

    W = tf.Variable(tf.constant(0.0, shape=[vocab_size, embedding_dim]), trainable=False, name="W")

    embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, embedding_dim])
    embedding_init = W.assign(embedding_placeholder)

    sess = tf.Session()

    sess.run(embedding_init, feed_dict={embedding_placeholder: embedding})

这样可以避免

embedding
在图表中存储的副本,但确实需要足够的内存才能一次在内存中保留矩阵的两个副本(一个用于NumPy数组,一个用于
tf.Variable
)。请注意,我假设您想在训练期间保持嵌入矩阵不变,因此
W
是使用创建的
trainable=False

  1. 如果将嵌入训练为另一个TensorFlow模型的一部分,则可以使用

    tf.train.Saver
    从另一个模型的检查点文件中加载值。这意味着嵌入矩阵可以完全绕过Python。
    W
    按照选项2创建,然后执行以下 *** 作:

    W = tf.Variable(...)

    embedding_saver = tf.train.Saver({“name_of_variable_in_other_model”: W})

    sess = tf.Session()
    embedding_saver.restore(sess, “checkpoint_filename.ckpt”)



欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5639923.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存