首页 > 解决方案 > Tensorflow model.fit() 似乎没有更新模型

问题描述

背景

我有一个自定义模型,可以评估Connect 4游戏中的位置。它有两个输出,一个动作值向量,logits 不是概率,表示给定动作的有利程度,以及一个表示游戏状态值的标量值。为了测试我的训练循环是否有效,我给它提供了一个仅包含一个状态和一个值的数据集,如果一切顺利,对数据集中状态的网络的评估应该收敛到数据集中的值。

def train(self, training_data, args):
    """Runs a model training step."""
    # all boards and values in these lists are identical
    boards, probs, values = list(zip(*training_data))
    self.model.fit(np.stack(boards), [np.stack(probs), np.stack(values)], batch_size=args["batch_size"], 
    epochs=args["epochs"])

    print(self.model.predict(boards[0].reshape(1,7,6,1))) # should output values[0]

我使用了两个损失函数,CrossEntropy用于动作值的损失和MeanSquaredError用于状态评估的损失。

def prob_loss(y_true, y_pred):
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
    return loss

def value_loss(y_true, y_pred):
    mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    loss = mse(y_true, y_pred)
    return loss

def compile_model(self):
    losses = {'output_1':prob_loss, 'output_2':value_loss}
    lossWeights={'output_1':0.5, 'output_2':0.5}
    adam = keras.optimizers.Adam(learning_rate=self.params["alpha"])
    self.model.compile(optimizer=adam, loss=losses, loss_weights=lossWeights, run_eagerly=True)

当我运行这个训练循环时,状态值的损失始终收敛到零,但令人惊讶的是,始终在状态上调用模型还远远不够。在我的例子中,状态值应该已经收敛到 -0.5,但相反,它稳定在 0.08。

我试图通过在调用中启用 run_eagerly 来弄清楚发生了什么model.compile,并通过调试器,我发现在调用中model.fit(),模型确实非常接近正确答案,并且准确地报告了损失。

def train(self, training_data, args):
        """Runs a model training step."""
        boards, probs, values = list(zip(*training_data))

        # state evaluation in this method is very accurate
        self.model.fit(np.stack(boards), [np.stack(probs), np.stack(values)], batch_size=args["batch_size"], 
        epochs=args["epochs"])
        
        # model call here outputs an extremely inaccurate answer
        print(self.model.predict(boards[0].reshape(1,7,6,1)))

所以我的问题是,为什么 .fit() 方法中的状态评估准确,但随后的评估.predict()不准确?

更新

我做了一些更多的挖掘,发现评估可以产生与 .fit() 方法不同的损失,特别是在存在丢失或批量标准化的情况下。我的模型确实有批量标准化,虽然我还不确定这是否是问题所在。来源:https ://github.com/keras-team/keras/issues/6977

更新

是的,批量标准化层导致了这个问题。删除它们立即提高了预测方法的准确性。

标签: pythontensorflow

解决方案


推荐阅读