python - Tensorflow model.fit() 似乎没有更新模型
问题描述
背景
我有一个自定义模型,可以评估Connect 4游戏中的位置。它有两个输出,一个动作值向量,logits 不是概率,表示给定动作的有利程度,以及一个表示游戏状态值的标量值。为了测试我的训练循环是否有效,我给它提供了一个仅包含一个状态和一个值的数据集,如果一切顺利,对数据集中状态的网络的评估应该收敛到数据集中的值。
def train(self, training_data, args):
"""Runs a model training step."""
# all boards and values in these lists are identical
boards, probs, values = list(zip(*training_data))
self.model.fit(np.stack(boards), [np.stack(probs), np.stack(values)], batch_size=args["batch_size"],
epochs=args["epochs"])
print(self.model.predict(boards[0].reshape(1,7,6,1))) # should output values[0]
我使用了两个损失函数,CrossEntropy
用于动作值的损失和MeanSquaredError
用于状态评估的损失。
def prob_loss(y_true, y_pred):
loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
return loss
def value_loss(y_true, y_pred):
mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
loss = mse(y_true, y_pred)
return loss
def compile_model(self):
losses = {'output_1':prob_loss, 'output_2':value_loss}
lossWeights={'output_1':0.5, 'output_2':0.5}
adam = keras.optimizers.Adam(learning_rate=self.params["alpha"])
self.model.compile(optimizer=adam, loss=losses, loss_weights=lossWeights, run_eagerly=True)
当我运行这个训练循环时,状态值的损失始终收敛到零,但令人惊讶的是,始终在状态上调用模型还远远不够。在我的例子中,状态值应该已经收敛到 -0.5,但相反,它稳定在 0.08。
我试图通过在调用中启用 run_eagerly 来弄清楚发生了什么model.compile
,并通过调试器,我发现在调用中model.fit()
,模型确实非常接近正确答案,并且准确地报告了损失。
def train(self, training_data, args):
"""Runs a model training step."""
boards, probs, values = list(zip(*training_data))
# state evaluation in this method is very accurate
self.model.fit(np.stack(boards), [np.stack(probs), np.stack(values)], batch_size=args["batch_size"],
epochs=args["epochs"])
# model call here outputs an extremely inaccurate answer
print(self.model.predict(boards[0].reshape(1,7,6,1)))
所以我的问题是,为什么 .fit() 方法中的状态评估准确,但随后的评估.predict()
不准确?
更新
我做了一些更多的挖掘,发现评估可以产生与 .fit() 方法不同的损失,特别是在存在丢失或批量标准化的情况下。我的模型确实有批量标准化,虽然我还不确定这是否是问题所在。来源:https ://github.com/keras-team/keras/issues/6977
更新
是的,批量标准化层导致了这个问题。删除它们立即提高了预测方法的准确性。
解决方案
推荐阅读
- powershell - 在 Powershell 中将 Emoji UTF8 转换为 Unicode
- firebase - BigQuery 中使用 Firebase 分析数据的每日计划
- python - VS Code 消耗大量内存。为什么?
- javascript - 如何使用 JQuery 和 JavaScript 制作 Tab 循环?
- reactjs - PropTypes oneOfType 是必需的
- java - 通过axis2 xml指定活动和故障转移jms连接
- audio - 如何解决以八度音阶执行 FM 解调的错误?
- angular - 如何在 Angular 8 中销毁 ngx-swiper-wrapper
- java - 除了迭代器对象之外,还有哪些不同的方法可以遍历向量?
- jruby - 从多个线程写入文件