首页 > 解决方案 > 训练 keras 模型时,维度如何工作?

问题描述

得到:

    assert q_values.shape == (len(state_batch), self.nb_actions)
AssertionError
q_values.shape <class 'tuple'>: (1, 1, 10)
(len(state_batch), self.nb_actions) <class 'tuple'>: (1, 10)

来自 sarsa 代理的 keras-rl 库:

rl.agents.sarsa.SARSAAgent#compute_batch_q_values

    batch = self.process_state_batch(state_batch)
    q_values = self.model.predict_on_batch(batch)
    assert q_values.shape == (len(state_batch), self.nb_actions)

这是我的代码:

class MyEnv(Env):

    def __init__(self):
        self._reset()

    def _reset(self) -> None:
        self.i = 0

    def _get_obs(self) -> List[float]:
        return [1] * 20

    def reset(self) -> List[float]:
        self._reset()
        return self._get_obs()



    model = Sequential()
    model.add(Dense(units=20, activation='relu', input_shape=(1, 20)))
    model.add(Dense(units=10, activation='softmax'))
    logger.info(model.summary())

    policy = BoltzmannQPolicy()
    agent = SARSAAgent(model=model, nb_actions=10, policy=policy)

    optimizer = Adam(lr=1e-3)
    agent.compile(optimizer, metrics=['mae'])

    env = MyEnv()
    agent.fit(env, 1, verbose=2, visualize=True)

想知道是否有人可以向我解释应该如何设置尺寸以及它如何与库一起使用?我正在输入一个包含 20 个输入的列表,并希望输出为 10。

标签: pythonkeraskeras-rl

解决方案


此特定错误是由您的输入形状为 (1, 20) 引起的。如果您使用 (20,) 的输入形状,错误将消失。

换句话说SARSAAgent,需要一个输出二维张量(batch_size,nb_actions)的模型。并且您的模型正在输出 (batch_size, 1, 10) 的形状。您可以减少模型输入中的尺寸或展平输出。


推荐阅读