python - 使用自定义 RNN 层的自定义模型中缺少状态参数
问题描述
我正在 TensorFlow 2.1 中编写自定义层,并在自定义模型中使用它。在下面的示例中,我从 TensorFlow 官网( https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN )复制了 MinimalRNNCell 的代码,并尝试在我的模型中使用这个层。
但是,在拟合(fit)模型时,我收到一条错误消息,指出 RNN 单元(cell)的 `call` 方法需要 `states` 参数,而我没有提供它。
我应该如何修改模型,才能提供这个参数?
我的代码:
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras import Model
import numpy as np
class MinimalRNNCell(Layer):
    """Minimal RNN cell, adapted from the ``tf.keras.layers.RNN`` docs example.

    A cell handles a single timestep: ``call`` takes that step's ``inputs``
    and the previous ``states`` and returns ``(output, new_states)``. It is
    meant to be wrapped in ``tf.keras.layers.RNN``, which iterates over the
    time axis and supplies ``states``; calling the cell directly on a sequence
    leaves ``states`` unset and raises a ``TypeError``.

    Args:
        units: width of the output and of the single state tensor.
    """

    def __init__(self, units, **kwargs):
        # Initialise the Keras Layer machinery before assigning attributes.
        super(MinimalRNNCell, self).__init__(**kwargs)
        self.units = units
        # Tells the RNN wrapper the size of the (single) state tensor.
        self.state_size = units

    def build(self, input_shape):
        # When run inside tf.keras.layers.RNN, input_shape is the per-timestep
        # shape, so its last axis is the feature count.
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        # ``states`` is a list holding the previous output as its only state.
        prev_output = states[0]
        # ``K`` was never imported in this snippet; reach the Keras backend
        # through the already-imported ``tf`` module instead.
        h = tf.keras.backend.dot(inputs, self.kernel)
        output = h + tf.keras.backend.dot(prev_output, self.recurrent_kernel)
        # Layer-based cells return the new states as a list.
        return output, [output]
class RNNXModel(Model):
    """Subclassed model from the question -- this is the FAILING version.

    It stores the bare cell and calls it directly on the whole input batch,
    so the cell's ``call(inputs, states)`` never receives ``states`` and Keras
    raises ``TypeError: tf__call() missing 1 required positional argument:
    'states'``. Wrapping the cell in ``tf.keras.layers.RNN`` (which supplies
    ``states`` at every timestep) fixes it.

    Args:
        size: number of units passed to ``MinimalRNNCell``.
    """
    def __init__(self, size):
        super(RNNXModel, self).__init__()
        # Bare cell, not wrapped in an RNN layer -- the root cause of the error.
        self.minimalrnn=MinimalRNNCell(size)
    def call(self, inputs):
        # Calls the cell with only ``inputs``; ``states`` is missing here.
        out=self.minimalrnn(inputs)
        return out
# Toy inputs, shape (2, 3, 3); intended as (batch, timesteps, features).
x=np.array([[[1,2,3],[4,5,6],[7,8,9]],[[10,11,12],[13,14,15],[16,17,18]]])
# Toy targets, shape (2, 3): one 3-value vector per sequence.
y=np.array([[1,2,3],[10,11,12]])
model=RNNXModel(3)
model.compile(optimizer='sgd', loss='mse')
# With the model above, this raises the TypeError shown in the traceback:
# Keras builds the model by calling it on the inputs, and the bare cell is
# invoked without ``states``.
model.fit(x,y,epochs=10, batch_size=1)
我得到的错误:
Traceback (most recent call last):
File "/home/.../test.py", line 64, in <module>
model.fit(x,y,epochs=10, batch_size=1)
File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 819, in fit
use_multiprocessing=use_multiprocessing)
File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 235, in fit
use_multiprocessing=use_multiprocessing)
File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 593, in _process_training_inputs
use_multiprocessing=use_multiprocessing)
File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 646, in _process_inputs
x, y, sample_weight=sample_weights)
File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 2346, in _standardize_user_data
all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 2572, in _build_model_with_inputs
self._set_inputs(cast_inputs)
File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 2659, in _set_inputs
outputs = self(inputs, **kwargs)
File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py", line 773, in __call__
outputs = call_fn(cast_inputs, *args, **kwargs)
File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/autograph/impl/api.py", line 237, in wrapper
raise e.ag_error_metadata.to_exception(e)
TypeError: in converted code:
/home/.../test.py:36 call *
out=self.minimalrnn(inputs)
/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py:773 __call__
outputs = call_fn(cast_inputs, *args, **kwargs)
TypeError: tf__call() missing 1 required positional argument: 'states'
解决方案
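感谢 Susmit Agrawal 的帮助,我写出了下面的代码,它可以正常运行。关键在于:RNN 单元(cell)的 `call` 每次只处理一个时间步,必须用 `tf.keras.layers.RNN` 包装,由 RNN 层遍历各个时间步并把 `states` 传给单元;单元本身则改为继承 `AbstractRNNCell`: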
class MinimalRNNCell(tf.keras.layers.AbstractRNNCell):
    """Minimal RNN cell built on ``tf.keras.layers.AbstractRNNCell``.

    For one timestep it computes
    ``output = inputs @ kernel + prev_output @ recurrent_kernel`` and uses
    that output as the new state. Wrap it in ``tf.keras.layers.RNN`` to run it
    over a sequence; the RNN layer supplies the ``states`` argument.

    Args:
        units: width of the output and of the single state tensor.
    """

    def __init__(self, units, **kwargs):
        super(MinimalRNNCell, self).__init__(**kwargs)
        self.units = units

    @property
    def state_size(self):
        # One state tensor of width ``units``.
        return self.units

    @property
    def output_size(self):
        # Required by the AbstractRNNCell contract. The output is also the
        # new state, so it has the same width.
        return self.units

    def build(self, input_shape):
        # Inside tf.keras.layers.RNN, input_shape is the per-timestep shape,
        # so its last axis is the feature count.
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        prev_output = states[0]
        # Neither ``K`` nor ``AbstractRNNCell`` is imported in this snippet,
        # so both are reached through the already-imported ``tf`` module.
        h = tf.keras.backend.dot(inputs, self.kernel)
        output = h + tf.keras.backend.dot(prev_output, self.recurrent_kernel)
        # AbstractRNNCell contract: return (output, new_state).
        return output, output
class RNNXModel(Model):
    """Subclassed model that runs ``MinimalRNNCell`` over an input sequence.

    The cell only handles a single timestep, so it is wrapped in
    ``tf.keras.layers.RNN``, which loops over the time axis and passes the
    ``states`` argument to the cell. With the RNN layer's default
    ``return_sequences=False`` the model returns only the last timestep's
    output.

    Args:
        size: number of units in the recurrent cell.
    """

    def __init__(self, size):
        super(RNNXModel, self).__init__()
        # ``RNN`` is never imported in this snippet; reference it through ``tf``.
        self.minimalrnn = tf.keras.layers.RNN(MinimalRNNCell(size))

    def call(self, inputs):
        return self.minimalrnn(inputs)
# Inputs, shape (2, 3, 3): two sequences of three timesteps, three features each.
x = np.array(
    [
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
    ]
)
# Targets, shape (2, 3): one 3-value vector per sequence, matching the width
# of the RNN layer's last-timestep output (3 units).
y = np.array([[1, 2, 3], [10, 11, 12]])

model = RNNXModel(size=3)
model.compile(loss='mse', optimizer='sgd')
model.fit(x=x, y=y, batch_size=1, epochs=10)
推荐阅读
- javascript - Angular 8 NgRx - 错误:检测到不可序列化的操作
- javascript - 将 localStorage 输入数据传递到下一页
- rapids - 替换 C 列中的值,其中 A 列中的值为 x
- javascript - 从对象中的其他数组创建数组
- sql - 使用 SUM 作为检索表之一时,如何保留 LEFT OUTER JOIN 表中的所有值,为什么 GROUP BY 似乎可以解决问题?
- php - 如何编写将子范围转换为索引对的 A 到 Z 词索引算法?
- python - 将灰度图像分量颜色分割成黑色、白色和灰色
- python - 使用 pyspark 从 zipfile 读取 csv
- javascript - 使用谷歌应用在脚本编辑器中拆分
- html - 将 HTML (WTForm) 转换为静态 HTML 电子邮件的简单方法?