python - 类型错误:__init__() 缺少 1 个必需的位置参数:LSTMCell 中的“单位”
问题描述
我正在尝试使用稳定的基线在强化学习问题中实施时间注意,但是,我一直在客户政策中收到提到的错误。我正在使用 TensorFlow 1.14 版。在我的 policy.py 中使用 LSTMCell 和来自 TensorFlow 的 RNN 类时,我还在初始化一个包装器以引起注意,但我不断收到以下错误。
Traceback (most recent call last):
File "run.py", line 60, in <module>
trainedModel = model_training(featureMatrix, config['env_name'], config['number_of_cpus'], config['total_training_timesteps'], config['policy'])
File "/code/src/util/utils.py", line 88, in model_training
trained_model = trained_model.train()
File "/code/src/util/model/model_training.py", line 103, in train
tensorboard_log=self.tensorboard_path).learn(total_timesteps=self.total_training_timesteps, callback=self.callback)
File "/venv/lib/python3.7/site-packages/stable_baselines/acktr/acktr.py", line 119, in __init__
self.setup_model()
File "/venv/lib/python3.7/site-packages/stable_baselines/acktr/acktr.py", line 148, in setup_model
1, n_batch_step, reuse=False, **self.policy_kwargs)
File "/code/src/util/policy/policy.py", line 97, in __init__
rnn = tf.keras.layers.RNN(self._build_rnn_cell())
File "/code/src/util/policy/policy.py", line 165, in _build_rnn_cell
return tf.keras.layers.StackedRNNCells([self._build_single_cell() for _ in range(3)])
File "/code/src/util/policy/policy.py", line 165, in <listcomp>
return tf.keras.layers.StackedRNNCells([self._build_single_cell() for _ in range(3)])
File "/code/src/util/policy/policy.py", line 158, in _build_single_cell
128,
File "/code/src/util/policy/attention_wrapper.py", line 123, in __init__
super(TemporalPatternAttentionCellWrapper, self).__init__(_reuse=reuse)
TypeError: __init__() missing 1 required positional argument: 'units'
我的 policy.py 如下:
class CustomPolicy(ActorCriticPolicy):
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **kwargs):
super(CustomPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=True)
with tf.variable_scope("model", reuse=reuse):
rnn = tf.keras.layers.RNN(self._build_rnn_cell())
feature_layer = rnn(self.processed_obs)
pi_layers = Sequential([
Dense(128, input_shape = (256,),
kernel_regularizer=regularizers.l2(0.01), activity_regularizer=regularizers.l1(0.01)),
Activation('relu'),
Dense(128, kernel_regularizer=regularizers.l2(0.01), activity_regularizer=regularizers.l1(0.01))
])
pi_latent = pi_layers(feature_layer)
vf_layers = Sequential([
Dense(32, input_shape = (256,),
kernel_regularizer=regularizers.l2(0.01), activity_regularizer=regularizers.l1(0.01)),
Activation('relu'),
Dense(32, kernel_regularizer=regularizers.l2(0.01), activity_regularizer=regularizers.l1(0.01))
])
vf_latent = vf_layers(feature_layer)
temp_value_fn = Dense(1, input_shape=(32,))
value_fn = temp_value_fn(vf_latent)
self._proba_distribution, self._policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent, init_scale=0.01)
self._value_fn = value_fn
self._setup_init()
def step(self, obs, state=None, mask=None, deterministic=False):
if deterministic:
action, value, neglogp = self.sess.run([self.deterministic_action, self.value_flat, self.neglogp],
{self.obs_ph: obs})
else:
action, value, neglogp = self.sess.run([self.action, self.value_flat, self.neglogp],
{self.obs_ph: obs})
return action, value, self.initial_state, neglogp
def proba_step(self, obs, state=None, mask=None):
return self.sess.run(self.policy_proba, {self.obs_ph: obs})
def value(self, obs, state=None, mask=None):
return self.sess.run(self.value_flat, {self.obs_ph: obs})
def _build_single_cell(self):
cell = tf.keras.layers.LSTMCell(256)
cell = TemporalPatternAttentionCellWrapper(
cell,
128,
)
return cell
def _build_rnn_cell(self):
return tf.keras.layers.StackedRNNCells([self._build_single_cell() for _ in range(3)])
我的注意力包装如下:
class TemporalPatternAttentionCellWrapper(tf.keras.layers.LSTMCell):
def __init__(self,
cell,
attn_length,
units=256,
attn_size=None,
attn_vec_size=None,
input_size=None,
state_is_tuple=True,
reuse=None):
"""Create a cell with attention.
Args:
cell: an RNNCell, an attention is added to it.
attn_length: integer, the size of an attention window.
attn_size: integer, the size of an attention vector. Equal to
cell.output_size by default.
attn_vec_size: integer, the number of convolutional features
calculated on attention state and a size of the hidden layer
built from base cell state. Equal attn_size to by default.
input_size: integer, the size of a hidden linear layer, built from
inputs and attention. Derived from the input tensor by default.
state_is_tuple: If True, accepted and returned states are n-tuples,
where `n = len(cells)`. By default (False), the states are all
concatenated along the column axis.
reuse: (optional) Python boolean describing whether to reuse
variables in an existing scope. If not `True`, and the existing
scope already has the given variables, an error is raised.
Raises:
TypeError: if cell is not an RNNCell.
ValueError: if cell returns a state tuple but the flag
`state_is_tuple` is `False` or if attn_length is zero or less.
"""
super(TemporalPatternAttentionCellWrapper, self).__init__(_reuse=reuse)
if nest.is_sequence(cell.state_size) and not state_is_tuple:
raise ValueError("Cell returns tuple of states, but the flag "
"state_is_tuple is not set. State size is: %s" %
str(cell.state_size))
if attn_length <= 0:
raise ValueError("attn_length should be greater than zero, got %s"
% str(attn_length))
if not state_is_tuple:
logging.warn(
"%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
if attn_size is None:
attn_size = 2880
if attn_vec_size is None:
attn_vec_size = attn_size
self._state_is_tuple = state_is_tuple
self._cell = cell
self._attn_vec_size = attn_vec_size
self._input_size = input_size
self._attn_size = attn_size
self._attn_length = attn_length
self._reuse = reuse
self._attention_mech = TemporalPatternAttentionMechanism()
@property
def state_size(self):
size = (self._cell.state_size, self._attn_size,
self._attn_size * self._attn_length)
if self._state_is_tuple:
return size
else:
return sum(list(size))
@property
def output_size(self):
return self._attn_size
def call(self, inputs, state):
"""Long short-term memory cell with attention (LSTMA)."""
print("TPA Wrapper called")
if self._state_is_tuple:
state, attns, attn_states = state
else:
states = state
state = tf.slice(states, [0, 0], [-1, self._cell.state_size])
attns = tf.slice(states, [0, self._cell.state_size],
[-1, self._attn_size])
attn_states = tf.slice(
states, [0, self._cell.state_size + self._attn_size],
[-1, self._attn_size * self._attn_length])
attn_states = tf.reshape(attn_states,
[-1, self._attn_length, self._attn_size])
input_size = self._input_size
if input_size is None:
input_size = inputs.get_shape().as_list()[1]
temp_inputs = Dense(input_size, input_shape = (5760,), use_bias=True)
inputs = temp_inputs(tf.concat([inputs, attns], 1))
lstm_output, new_state = self._cell(inputs)
if self._state_is_tuple:
new_state_cat = tf.concat(nest.flatten(new_state), 1)
else:
new_state_cat = new_state
new_attns, new_attn_states = self._attention_mech(
new_state_cat, attn_states, self._attn_size, self._attn_length,
self._attn_vec_size)
with tf.variable_scope("attn_output_projection"):
temp_output = Sequential([
Dense(self._attn_size, input_shape = (2880,),
use_bias=True),
])
output = dense(tf.concat([lstm_output, new_attns], 1))
new_attn_states = tf.concat(
[new_attn_states, tf.expand_dims(output, 1)], 1)
new_attn_states = tf.reshape(new_attn_states,
[-1, self._attn_length * self._attn_size])
new_state = (new_state, new_attns, new_attn_states)
if not self._state_is_tuple:
new_state = tf.concat(list(new_state), 1)
return output, new_state
错误发生在该行
super(TemporalPatternAttentionCellWrapper, self).__init__(_reuse=reuse) in the init function of the wrapper.
任何帮助将不胜感激,如果需要更多信息,请告诉我。
解决方案
根据 的文档,LSTMCell
它首先需要一个强制units
参数,即输出空间的维数。
当您__init__()
在错误行调用它时,您需要使用__init__(units, ...)
.
推荐阅读
- java - 如何将泛型类型转换为字符串
- c# - 如何从捕获的数据包中获取短信?
- php - 我的代码中的行计数方法没有返回受影响的行有什么问题?
- python - 如何使用 pandas 填写缺失的时间数据
- visual-c++ - 在BST中,我将一个节点复制到另一个节点并删除第一个节点并返回复制的节点,但复制的节点值是垃圾
- node.js - Termux - npm 错误!错误:EPERM:不允许操作
- discord.js - 将 DM 发送给 discord.js 中命令 arg 指定的用户
- kubernetes - 支持 SSL 中等强度密码套件 (SWEET32) - Nessus 插件 ID 42873
- java - 为什么使用接口引用进行片段间通信?
- facebook - 与 facebook 服务器(登录和信使部分)通信时使用的协议是什么?