python - 无法在训练前提供预训练词嵌入
问题描述
我想在训练之前加载预训练词嵌入,而不是在每个 train_steps 加载它。我按照这篇文章中的步骤进行操作。但它会显示错误:
您必须使用 dtype float 和 shape [2000002,300] 为占位符张量“word_embedding_placeholder”提供一个值
这是大致的代码:
embeddings_var = tf.Variable(tf.random_uniform([vocabulary_size, embedding_dim], -1.0, 1.0), trainable=False)
embedding_placeholder = tf.placeholder(tf.float32, [vocabulary_size, embedding_dim], name='word_embedding_placeholder')
embedding_init = embeddings_var.assign(embedding_placeholder) # assign exist word embeddings
batch_embedded = tf.nn.embedding_lookup(embedding_init, batch_ph)
sess = tf.Session()
train_steps = round(len(X_train) / BATCH_SIZE)
train_iterator, train_next_element = get_dataset_iterator(X_train, y_train, BATCH_SIZE, training_epochs)
sess.run(init_g)
sess.run(train_iterator.initializer)
_ = sess.run(embedding_init, feed_dict={embedding_placeholder: w2v})
for epoch in range(0, training_epochs):
# Training steps
for i in range(train_steps):
X_train_input, y_train_input = sess.run(train_next_element)
seq_len = np.array([list(word_idx).index(PADDING_INDEX) if PADDING_INDEX in word_idx else len(word_idx) for word_idx in X_train_input]) # actual lengths of sequences
train_loss, train_acc, _ = sess.run([loss, accuracy, optimizer],
feed_dict={batch_ph: X_train_input,
target_ph: y_train_input,
seq_len_ph: seq_len,
keep_prob_ph: KEEP_PROB})
当我将训练中的 feed_dict 更改为:
train_loss, train_acc, _ = sess.run([loss, accuracy, optimizer],
feed_dict={batch_ph: X_train_input,
target_ph: y_train_input,
seq_len_ph: seq_len,
keep_prob_ph: KEEP_PROB,
embedding_placeholder: w2v})
它有效,但并不优雅。有人遇到这个问题吗?
目标:我只想在训练前加载一次预训练嵌入。而不是每次都重新计算 embedding_init。
解决方案
大概您在网络中的某个地方使用了 batch_embedded,这意味着它被用于您的损失。这意味着每当您在循环内对 loss 执行 sess.run 时,您都在重新计算 batch_embedded,因此重新计算 embedding_init,您需要 embedding_placeholder。相反,您可以按如下方式初始化变量:
embeddings_var = tf.get_variable("embeddings_var", shape=[vocabulary_size, embedding_dim], initializer=tf.constant_initializer(w2v), trainable=False)
推荐阅读
- python - Django没有名为'main'的模块
- javascript - 如何测量 Qml 应用程序启动时间?/开机持续时间?
- java - 重构如果抛出一些异常则返回 false 的方法
- java - 如何将存储在移动设备中的录制视频的 Uri 传递给另一个活动
- gitlab - GitLab 中的“无法在 struct ObjectMeta 中为 json 字段“数据”找到 api 字段”的错误是什么?
- javascript - 在 AUI 脚本中序列化 serialize() 参数
- vue.js - vue中CKEditor 4中的html5audio插件
- python-3.x - 如何处理在 Django 中使用表单上传的文件?
- android - Kotlin Flow“首次离线”方法
- c++ - 什么是 Cmake 文件以及为什么我们在拥有 VIsual Studio 时使用它