Model output "Tensor("activation_9/activation_9/Identity:0", shape=(?, 6), dtype=float32)" has invalid shape

Problem description

I get the following error when trying to build a DQN model:

ValueError                                Traceback (most recent call last)
<ipython-input-41-42c80ec471c2> in <module>()
      1 # TODO - Select the parameters for the Agent and the Optimizer
      2 dqn = DQNAgent(model=model, nb_actions=nb_actions,
----> 3                memory=memory)
      4 dqn.compile(Adam(lr=.00025), metrics=['mae'])

/usr/local/lib/python3.7/dist-packages/rl/agents/dqn.py in __init__(self, model, policy, test_policy, enable_double_dqn, enable_dueling_network, dueling_type, *args, **kwargs)

ValueError: Model output "Tensor("activation_14/activation_14/Identity:0", shape=(?, 6), dtype=float32)" has invalid shape. DQN expects a model that has one dimension for each action, in this case 6.
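A likely explanation, offered as an assumption rather than a confirmed diagnosis: the agent validates the model by comparing its output shape element by element with (None, nb_actions). In TensorFlow 1.x the entries of a tensor's shape are Dimension objects, and comparing an unknown Dimension with None yields None rather than True, so Python's list comparison reports a mismatch even though the printed shape (?, 6) looks correct. The sketch below imitates that behaviour without importing TensorFlow; FakeDimension and shape_mismatch are hypothetical names, and the exact form of the check inside keras-rl2 is assumed.

```python
class FakeDimension:
    """Hypothetical stand-in for a TF 1.x Dimension: equality involving an
    unknown (None) value returns None instead of a bool."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        other_value = getattr(other, "value", other)
        if self.value is None or other_value is None:
            return None  # "unknown"; Python treats this as falsy
        return self.value == other_value


def shape_mismatch(output_shape, nb_actions):
    # Assumed form of the agent's check: element-wise list comparison.
    return list(output_shape) != [None, nb_actions]


nb_actions = 6
tf1_style_shape = [FakeDimension(None), FakeDimension(6)]  # prints like (?, 6)
tf2_style_shape = [None, 6]  # TF 2.x shapes iterate as plain ints / None

print(shape_mismatch(tf1_style_shape, nb_actions))  # True: the check would raise
print(shape_mismatch(tf2_style_shape, nb_actions))  # False: the check passes
```

Because the first entries compare as "not equal", the list comparison short-circuits to a mismatch, which is consistent with an error message whose shape looks right.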

I have had some version problems with keras-rl and TensorFlow, so these are the versions I am using:

tensorflow==1.13.1
Keras==2.2.4
keras-rl2==1.0.4
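keras-rl2 is the TensorFlow 2 port of keras-rl and builds on tensorflow.keras, so keras-rl2 1.0.4 together with tensorflow==1.13.1 mixes two API generations. Note also that the code below imports everything from tensorflow.keras, so the standalone Keras==2.2.4 package is not actually used by the model. A stdlib-only check along these lines can flag the combination before any model is built; the helper names are hypothetical, and the "TensorFlow 2 or newer" rule is an assumption based on keras-rl2 being the TF2 port.

```python
try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7: pip install importlib-metadata
    import importlib_metadata as metadata


def major_version(version):
    """Return the leading integer of a version string such as '1.13.1'."""
    return int(version.split(".")[0])


def keras_rl2_compatible(tf_version):
    # Assumption: keras-rl2 targets the TensorFlow 2.x tf.keras API.
    return major_version(tf_version) >= 2


def installed_version(package):
    """Return the installed version of a package, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    tf_version = installed_version("tensorflow")
    rl2_version = installed_version("keras-rl2")
    if rl2_version and tf_version and not keras_rl2_compatible(tf_version):
        print(f"keras-rl2 {rl2_version} is paired with tensorflow {tf_version} (1.x)")
```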

The code looks like this:

from __future__ import division

from PIL import Image
import numpy as np
import gym

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, Convolution2D, Permute
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K

from rl.agents.dqn import DQNAgent
from rl.policy import LinearAnnealedPolicy, BoltzmannQPolicy, EpsGreedyQPolicy
from rl.memory import SequentialMemory
from rl.core import Processor
from rl.callbacks import FileLogger, ModelIntervalCheckpoint

INPUT_SHAPE = (84, 84)
WINDOW_LENGTH = 4

env_name = 'SpaceInvaders-v0'
env = gym.make(env_name)
nb_actions = env.action_space.n

input_shape = (WINDOW_LENGTH,) + INPUT_SHAPE
model = Sequential()
model.add(Permute((2, 3, 1), input_shape=input_shape))

model.add(Convolution2D(32, (8, 8), strides=(4, 4)))
model.add(Activation('relu'))
model.add(Convolution2D(64, (4, 4), strides=(2, 2)))
model.add(Activation('relu'))
model.add(Convolution2D(64, (3, 3), strides=(1, 1)))
model.add(Activation('relu'))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
model.summary()  # summary() prints the table itself and returns None

memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)

dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory)
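As a sanity check on the architecture itself, the output sizes of the three convolutions (Keras's default "valid" padding) can be worked out by hand with out = floor((in - kernel) / stride) + 1. The sketch below traces the 4 stacked 84x84 frames through the layers of the model above; conv_out and dqn_shapes are hypothetical helper names. The trace ends in nb_actions = 6 units (Keras adds the batch dimension, giving (None, 6)), which suggests the layer stack already matches what DQNAgent asks for.

```python
def conv_out(size, kernel, stride):
    """Spatial output size of a 'valid' (unpadded) convolution."""
    return (size - kernel) // stride + 1


def dqn_shapes(height=84, width=84, frames=4, nb_actions=6):
    """Trace per-sample shapes (batch dimension omitted) through the model."""
    shapes = [(height, width, frames)]  # after Permute((2, 3, 1)): channels last
    for filters, kernel, stride in [(32, 8, 4), (64, 4, 2), (64, 3, 1)]:
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
        shapes.append((height, width, filters))
    flat = height * width * shapes[-1][2]  # Flatten
    shapes += [(flat,), (512,), (nb_actions,)]  # Dense(512), Dense(nb_actions)
    return shapes


for shape in dqn_shapes():
    print(shape)
# (84, 84, 4)
# (20, 20, 32)
# (9, 9, 64)
# (7, 7, 64)
# (3136,)
# (512,)
# (6,)
```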

I am using tensorflow==1.13 because I was also getting the following _keras_shape error from DQNAgent:

'Tensor' object has no attribute '_keras_shape'
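Standalone Keras 2.x records a tensor's static shape in a private _keras_shape attribute, and code written against standalone Keras (as far as I can tell, the original keras-rl's DQNAgent) reads that attribute directly; tensors built with tensorflow.keras, which the code above uses, need not carry it, hence the AttributeError. The sketch below illustrates the difference with a shape reader that tolerates both kinds of tensor: it prefers _keras_shape when present, falls back to .shape, and unwraps Dimension-style entries through their .value attribute. The helper and the stand-in tensor classes are hypothetical; a common way out in practice is to keep the whole stack on one API generation (keras-rl with standalone Keras, or keras-rl2 with TensorFlow 2.x).

```python
class Dim:
    """Hypothetical stand-in for a TF 1.x Dimension (wraps an int or None)."""

    def __init__(self, value):
        self.value = value


class KerasStyleTensor:
    """Hypothetical tensor from standalone Keras: carries _keras_shape."""

    def __init__(self, shape):
        self._keras_shape = tuple(shape)
        self.shape = [Dim(d) for d in shape]


class TfKerasStyleTensor:
    """Hypothetical tensor from tf.keras under TF 1.x: only has .shape."""

    def __init__(self, shape):
        self.shape = [Dim(d) for d in shape]


def static_shape(tensor):
    """Return the shape as a tuple of plain ints/None, whichever API made it."""
    shape = getattr(tensor, "_keras_shape", None)
    if shape is None:
        shape = tensor.shape
    # Unwrap Dimension-like entries; plain ints and None have no .value.
    return tuple(getattr(d, "value", d) for d in shape)


print(static_shape(KerasStyleTensor((None, 6))))    # (None, 6)
print(static_shape(TfKerasStyleTensor((None, 6))))  # (None, 6)
```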

Can someone tell me what I am doing wrong?

Tags: python, tensorflow, keras, dqn, keras-rl

Solution

