python - 我可以继续上一课的训练吗?
问题描述
我正在编写文本摘要代码。我想使用我在上一课中获得的最佳模型继续上一课的训练。我正在使用 TensorFlow 版本 1.15.0 和 google colaboratory 来执行笔记本。
learning_rate_decay = 0.95
min_learning_rate = 0.0005
display_step = 20 # Check training loss after every 20 batches
stop_early = 0
stop = 3 # If the update loss does not decrease in 3 consecutive update checks, stop training
per_epoch = 3 # Make 3 update checks per epoch
update_check = (len(sorted_texts)//batch_size//per_epoch)-1
update_loss = 0
batch_loss = 0
summary_update_loss = [] # Record the update losses for saving improvements in the model
checkpoint = "./best_model.ckpt"
with tf.Session(graph=train_graph) as sess:
sess.run(tf.global_variables_initializer())
# If we want to continue training a previous session
loader = tf.train.import_meta_graph(checkpoint + '.meta')
loader.restore(sess, checkpoint)
for epoch_i in range(1, epochs+1):
update_loss = 0
batch_loss = 0
for batch_i, (summaries_batch, texts_batch, summaries_lengths, texts_lengths) in enumerate(
get_batches(sorted_summaries, sorted_texts, batch_size)):
start_time = time.time()
_, loss = sess.run(
[train_op, cost],
{input_data: texts_batch,
targets: summaries_batch,
lr: learning_rate,
summary_length: summaries_lengths,
text_length: texts_lengths,
keep_prob: keep_probability})
batch_loss += loss
update_loss += loss
end_time = time.time()
batch_time = end_time - start_time
if batch_i % display_step == 0 and batch_i > 0:
print('Epoch {:>3}/{} Batch {:>4}/{} - Loss: {:>6.3f}, Seconds: {:>4.2f}'
.format(epoch_i,
epochs,
batch_i,
len(sorted_texts) // batch_size,
batch_loss / display_step,
batch_time*display_step))
batch_loss = 0
if batch_i % update_check == 0 and batch_i > 0:
print("Average loss for this update:", round(update_loss/update_check,3))
summary_update_loss.append(update_loss)
# If the update loss is at a new minimum, save the model
if update_loss <= min(summary_update_loss):
print('New Record!')
stop_early = 0
saver = tf.train.Saver()
saver.save(sess, checkpoint)
else:
print("No Improvement.")
stop_early += 1
if stop_early == stop:
break
update_loss = 0
# Reduce learning rate, but not below its minimum value
learning_rate *= learning_rate_decay
if learning_rate < min_learning_rate:
learning_rate = min_learning_rate
if stop_early == stop:
print("Stopping Training.")
break
在执行这段代码时,我在输出中收到此错误:
INFO:tensorflow:Restoring parameters from ./best_model.ckpt
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1119 subfeed_t = self.graph.as_graph_element(
-> 1120 subfeed, allow_tensor=True, allow_operation=False)
1121 except Exception as e:
4 frames
ValueError: Tensor Tensor("input:0", shape=(?, ?), dtype=int32) is not an element of this graph.
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1121 except Exception as e:
1122 raise TypeError('Cannot interpret feed_dict key as Tensor: ' +
-> 1123 e.args[0])
1124
1125 if isinstance(subfeed_val, ops.Tensor):
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("input:0", shape=(?, ?), dtype=int32) is not an element of this graph.
解决方案
推荐阅读
- python - 没有得到 django 密码重置 password_reset_confirm 预期错误
- html - 如何使转换后的排版响应图像
- react-native - 导航到屏幕时的数据更新
- javascript - 将一组图像添加到弹出框
- php - JWT Key -> exp 声明向我显示一个异常
- glsl - GLSL 线型 - 改变颜色
- javascript - 计算一个元素在多个其他数组 JavaScript 中的出现次数
- powershell - PowerShell Cmdlet 开发:关于在 Cmdlet 之间通过管道传输的 IEnumerable 的最佳实践
- html - 在不使用absolute或flex的情况下将div定位到父级的右下角
- alpha-vantage - 全球报价中的 Alpha Vantage 货币