首页 > 解决方案 > Keras model.reset_states() 不适用于 tf.train.MonitoredTrainingSession

问题描述

我想tf.train.MonitoredTrainingSession()用于训练 Keras 中描述的模型。这个模型是一个有状态的模型,所以我想在每个 epoch 之后重置状态。

一个问题是,如果我调用model.reset_states(),它会产生以下错误。

RuntimeError:图表已完成,无法修改。

如果tf.Session()使用 代替tf.train.MonitoredTrainingSession(),则不会出现此错误。

例如,在以下示例代码中,即使它不是训练代码,也会生成相同的错误消息。

#!/usr/bin/python

import tensorflow as tf


inputs1 = tf.reshape(tf.linspace(0.0, 100.0, 10), (1, 2, 5)) 
inputs2 = tf.reshape(tf.linspace(100.0, 0.0, 10), (1, 2, 5)) 

model = tf.keras.Sequential([
    tf.keras.layers.LSTM(
    5,  
    return_sequences=True, stateful=True)
])

outputs1 = model(inputs1)
outputs2 = model(inputs2)

with tf.train.MonitoredTrainingSession() as sess:
  model.reset_states()
  print (sess.run(outputs1))
  model.reset_states()
  print (sess.run(outputs2))

我找到了两种方法来解决这个问题:

  1. tf.get_current_graph()._unsafe_unfinalize()在重置统计数据之前使用。

  2. 使用tf.Session()而不是tf.train.MonitoedTrainingSession().

但我认为这两种方法都不理想。您能否建议在这种情况下最好的解决方案是什么?

标签: tensorflowsessionkerasreset

解决方案


推荐阅读