python - 转换为估计器时的 LSTM InvalidArgumentError Tensorflow 2.0/Keras
问题描述
我正在尝试构建一个LSTM
接收一系列单词并将其转换为嵌入向量的网络。我已经将每个单词序列转换为词汇向量。
我使用的批量大小是 32,每个词汇向量的大小是 50。这是我迄今为止创建模型并将其转换为估计器的 Keras 功能 API 代码
input_layer = keras.layers.Input(shape=(50,), name='search')
embedding_layer = keras.layers.Embedding(input_dim=32, output_dim=256, input_length=50)(input_layer)
lstm_layer = keras.layers.LSTM(units=256)(embedding_layer)
model = keras.models.Model(inputs=input_layer, outputs=lstm_layer)
model.compile(loss='mean_squared_error', optimizer='adam')
estimator = keras.estimator.model_to_estimator(keras_model=model)
但是这段代码给出了错误
tensorflow.python.framework.errors_impl.InvalidArgumentError: Node 'Adam/gradients/lstm/StatefulPartitionedCall_grad/StatefulPartitionedCall': Connecting to invalid output 5 of source node lstm/StatefulPartitionedCall which has 5 outputs
当我运行时model.summary()
,这是输出
Layer (type) Output Shape Param #
=================================================================
search (InputLayer) [(None, 50)] 0
_________________________________________________________________
embedding (Embedding) (None, 50, 256) 8192
_________________________________________________________________
lstm (LSTM) (None, 256) 525312
=================================================================
Total params: 533,504
Trainable params: 533,504
Non-trainable params: 0
_________________________________________________________________
我认为这是我所期望的。我尝试用LSTM
相同形状的 Dense 和 Flatten 层替换该层,并且代码工作正常
解决方案
我自己来回答这个问题……截至 7 月 24 日,tf.keras.layers.LSTM 似乎存在问题,如此处所示。我将模型更改为以下
input_layer = keras.layers.Input(shape=(50,), name='search')
embedding_layer = keras.layers.Embedding(input_dim=32, output_dim=256,
input_length=50)(input_layer)
lstm_layer = keras.layers.RNN(cell=keras.layers.LSTMCell(units=256))(embedding_layer)
model = keras.models.Model(inputs=input_layer, outputs=lstm_layer)
model.compile(loss='mean_squared_error', optimizer='adam')
estimator = keras.estimator.model_to_estimator(keras_model=model)
推荐阅读
- html - Web.Config 重定向到维护页面?
- build - 使用 arm64v8/python:3.7-slim-buster 基础映像在 Dockerfile 中运行“apt-get update”时出错
- javascript - 以不同的时间间隔解析不同的JSON,写一个大的JSON
- javascript - 改进链式正则表达式以替换元音
- java - 为什么杰克逊的二传手只是反序列化?
- python - Python 的 yfinance 和 yahoo_fin 最近是否停止工作?
- excel - 使用 VBA 自动填充具有不同值的列
- .net-core - 我的单例类 .NET Core 中的依赖注入
- php - Magento 2 Block 模板在本地工作,不在服务器上
- c# - 通过.net上的字符串数组elasticsearch查询