首页 > 解决方案 > 如何修复 get_updates 函数中的“形状必须相等等级”?

问题描述

我正在按照 bstriner 的要点(https://gist.github.com/bstriner/e1e011652b297d13b3ac3f99fd11b2bc#gistcomment-2310228)在 Keras 中为 REINFORCE RL 算法设置自定义损失和自定义拟合方法,但我得到了上述错误.

有什么提示吗?

我正在尝试将 Maxim Lapan 的 Pytorch REINFORCE RL 代码调整为 Keras。然而,y_pred 和 y_true 自定义损失 Keras 要求并没有让我达到这个完美的 PyTorch/Keras 端口......

Class NN():
...
    def custom_loss(self, y_pred, y_true):
        log_prob = self.log_softmax(y_pred)
        log_prob_qvals = self.batch_qvals *     log_prob[range(len(self.batch_states)), y_true]
        loss = K.mean(log_prob_qvals * -1, axis=0)
        return loss

    def custom_fit(self, x, ytrue):
        print('FFP-BP = 1 gradient update')
        updates = self.optim.get_updates(
        loss=self.custom_loss, params=K.variable(self.net.trainable_weights))
        return K.function(input=[x, ytrue], outputs=[self.custom_loss], updates=updates)
...

当我用 Pytorch 的代码按 F5 时,我的预期结果是匹配(几乎)相同的结果:

    logits_v = net(states_v)
    log_prob_v = F.log_softmax(logits_v, dim=1)
    log_prob_actions_v = batch_qvals_v * \
            log_prob_v[range(len(batch_states)), batch_actions_t]
    loss_v = -log_prob_actions_v.mean()
    loss_v.backward()
    optimizer.step()

编辑 1:在检查时,在 Python 调试时,我可以看到正确计算了损失的值。所以,我怀疑这与参数有关...

标签: pythonkerasreinforcement-learning

解决方案


推荐阅读