tensorflow - 恢复检查点失败:检查点中找不到密钥
问题描述
我能够成功训练 RNN 并看到 Tensorboard 中出现准确度/损失。问题是当我尝试从检查点文件加载模型时,我收到以下错误:
Key fully_connected/Variable not found in checkpoint
[[node save/RestoreV2 (defined at train.py:87) = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
这是导致问题的代码(我省略了我认为不相关的部分):
tf.reset_default_graph()
with tf.name_scope('input_data'):
input_data = tf.placeholder(tf.int32, [batchSize, maxSeqLength])
with tf.name_scope('labels'):
labels = tf.placeholder(tf.float32, [batchSize, numClasses])
with tf.name_scope('embeddings'):
data = tf.nn.embedding_lookup(wordVectors, input_data)
with tf.name_scope('lstm_layer'):
lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits)
lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.75)
with tf.name_scope('rnn_feed_forward'):
value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32)
with tf.name_scope('fully_connected'):
weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))
bias = tf.Variable(tf.constant(0.1, shape=[numClasses]))
with tf.name_scope('predictions'):
value = tf.transpose(value, [1, 0, 2])
last = tf.gather(value, int(value.get_shape()[0]) - 1)
prediction = (tf.matmul(last, weight) + bias)
with tf.name_scope('accuracy'):
correctPred = tf.equal(tf.argmax(prediction,1), tf.argmax(labels,1))
accuracy = tf.reduce_mean(tf.cast(correctPred, tf.float32))
with tf.name_scope('cost'):
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=labels))
with tf.name_scope('train'):
optimizer = tf.train.AdamOptimizer().minimize(loss)
merged = tf.summary.merge_all()
saver = tf.train.Saver() # Saving and loading
# Train the model
print('Training has begun.')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.summary.scalar('Loss', loss)
tf.summary.scalar('Accuracy', accuracy)
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(logdir, sess.graph)
for i in range(iterations):
nextBatch, nextBatchLabels = get_train_batch();
sess.run(optimizer, { input_data: nextBatch, labels: nextBatchLabels })
if (i % 50 == 0):
print('Entering iteration ' + str(i))
summary = sess.run(merged, {input_data: nextBatch, labels: nextBatchLabels})
writer.add_summary(summary, i)
if (i % 10000 == 0 and i != 0):
save_path = saver.save(sess, modelsDir, global_step=i)
print('Saved to %s' % save_path)
writer.close()
我的情况是,当我将优化器添加到会话中时,sess.run(optimizer...
我有效地将所有变量及其依赖变量添加到图表中。
关键“fully_connected”虽然是 name_scope,但我对它是如何被随机排除的有点困惑。
细节
该chkp.print_tensors_in_checkpoint_file("./models/pretrained_lstm.ckpt-10000", tensor_name='', all_tensors=True)
命令的输出为我提供了一堆名称不是很有用的变量:
Variable
Variable_1
Variable_1/Adam
Variable_1/Adam_1
etc.
现在我想知道这是否与我没有明确命名变量有关?现在正在尝试。
问题
有没有更多经验的人发现我做错了什么?你能启发我吗?
作为一个开放式问题,除了 Tensorboard(它没有帮助我解决这个问题,因为它实际上并没有读取检查点文件)之外,您会推荐哪些工具来检查会话和图表?
解决方案
没有错。Adam 优化器结合了 AdaGrad(自适应梯度)和 RMSProp(均方根传播)技术。后者跟踪每个参数的学习率,这些学习率根据当前梯度的移动平均值进行调整。
在恢复模型的情况下,重要的是该算法使用梯度的 EMA 和平方梯度,因此控制衰减率的内部变量 beta1 和 beta2 添加到层中。
您不能通过从字典中排除它们来恢复这些特定变量,您可以将其传递给saver.restore
您可以创建此字典
vars_to_restore = [i[0] for i in tf.train.list_variables(file.ckpt)]
restore_dict = {variable.op.name: variable for variable in tf.global_variables() if variable.op.name in vars_to_restore}
然后你只需要初始化亚当变量
tf.variables_initializer(optimizer.variables())
您可以使用这个简单的函数来检查检查点和当前图中的变量/范围名称。
推荐阅读
- java - Java 8 Streams 减少删除重复项,保留最新条目
- vue.js - Vuetify - 能够选择父节点而不打开它们
- python - 如何使此函数从 Python 中的数组行返回一对组合?
- php - 是否有反转“array_column”的内置/短 PHP 函数?
- tfs - TFS 构建步骤 - 执行 POST Web 请求
- reactjs - window.stripe 的事件监听器
- excel - 在单元格中添加减号
- sql - 在 SQL 中将 2 个结果合二为一
- bash - 在文件名的文件扩展名之后删除未知/非特定字符串
- android - Android Studio 工具栏标题更新错误,从一个片段返回堆栈到另一个片段