首页 > 解决方案 > 特征列预训练嵌入

问题描述

如何使用预训练嵌入与tf.feature_column.embedding_column.

我使用pre_trained嵌入tf.feature_column.embedding_column. 但它不起作用。错误是

错误是:

ValueError:如果指定,初始化程序必须是可调用的。column_name 的嵌入:itemx

这是我的代码:

weight, vocab_size, emb_size = _create_pretrained_emb_from_txt(FLAGS.vocab, 
FLAGS.pre_emb)

W = tf.Variable(tf.constant(0.0, shape=[vocab_size, emb_size]),
                trainable=False, name="W")
embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, emb_size])
embedding_init = W.assign(embedding_placeholder)

sess = tf.Session()
sess.run(embedding_init, feed_dict={embedding_placeholder: weight})

itemx_vocab = tf.feature_column.categorical_column_with_vocabulary_file(
    key='itemx',
    vocabulary_file=FLAGS.vocabx)

itemx_emb = tf.feature_column.embedding_column(itemx_vocab,
                                               dimension=emb_size,
                                               initializer=W,
                                               trainable=False)

我已经尝试过初始化程序 = lambda w:W。像这样:

itemx_emb = tf.feature_column.embedding_column(itemx_vocab,
                                               dimension=emb_size,
                                               initializer=lambda w:W,
                                               trainable=False)

它报告错误:

TypeError: () 得到了一个意外的关键字参数“dtype”

标签: pythontensorflowtensorflow-estimator

解决方案


我也在这里提出一个问题https://github.com/tensorflow/tensorflow/issues/20663

最后我找到了解决它的正确方法。虽然。我不清楚为什么上面的答案无效!如果你知道这个问题,谢谢给我一些建议!!

好的~~~~这是当前的解决方案。实际上从这里Feature Columns Embedding 查找

代码:

itemx_vocab = tf.feature_column.categorical_column_with_vocabulary_file(
    key='itemx',
    vocabulary_file=FLAGS.vocabx)

embedding_initializer_x = tf.contrib.framework.load_embedding_initializer(
    ckpt_path='model.ckpt',
    embedding_tensor_name='w_in',
    new_vocab_size=itemx_vocab.vocabulary_size,
    embedding_dim=emb_size,
    old_vocab_file='FLAGS.vocab_emb',
    new_vocab_file=FLAGS.vocabx
)
itemx_emb = tf.feature_column.embedding_column(itemx_vocab,
                                               dimension=128,
                                               initializer=embedding_initializer_x,
                                               trainable=False)

推荐阅读