tensorflow - TimeDistributed:reshape 的输入是一个有 265000 个值的张量,但请求的形状需要 800 的倍数
问题描述
我正在根据本教程创建一个 LSTM 模型,如下所示:
problem = self.hparams.problem
encoders = problem.feature_encoders
inputs_vocab_size = len(encoders['inputs'].subwords)
targets_vocab_size = len(encoders['targets'].subwords)
hidden_size = self.hparams.model.hidden_size
max_inputs_length = self.hparams.model.max_input_length
max_output_length = self.hparams.model.max_target_length
inputs = keras.Input(shape=(max_inputs_length,))
x = inputs
x = layers.Embedding(inputs_vocab_size, hidden_size, input_length=max_inputs_length, mask_zero=True)(x)
x = layers.LSTM(hidden_size)(x)
x = layers.RepeatVector(max_output_length)(x)
x = layers.LSTM(hidden_size, return_sequences=True)(x)
# Output modality
outputs = layers.TimeDistributed(layers.Dense(targets_vocab_size, activation='softmax'))(x)
self.keras_model = keras.Model(inputs=inputs, outputs=outputs)
self.keras_model.summary()
在训练期间,模型的损失计算如下:
def loss(self, logits, targets):
labels = tf.one_hot(targets, self.vocab_size)
loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
return tf.reduce_mean(loss)
其中logits
是模型输出,targets
是训练样本。
但是,我在执行过程中遇到以下异常:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 265000 values, but the requested shape requires a multiple of 800
[[{{node model_fn/model/time_distributed/Reshape_1}}]]
显然我的TimeDistributed
图层有问题,但我不太明白问题出在哪里。与教程相比,-valued 张量来自哪里,265000
我在做什么不同?
模型摘要
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 100)] 0
_________________________________________________________________
embedding (Embedding) (None, 2, 64) 512
_________________________________________________________________
lstm (LSTM) (None, 64) 33024
_________________________________________________________________
repeat_vector (RepeatVector) (None, 100, 64) 0
_________________________________________________________________
lstm_1 (LSTM) (None, 100, 64) 33024
_________________________________________________________________
time_distributed (TimeDistri (None, 100, 8) 520
=================================================================
Total params: 67,080
Trainable params: 67,080
Non-trainable params: 0
_________________________________________________________________
解决方案
推荐阅读
- django - 添加或更改相关名称参数
- firebase - 根据引用文档的字段查询
- python - 如何转换现有的多模块 python 脚本以在 AWS Lambda 中使用?
- java - 一次从 2 个进程接收来自麦克风的输入
- apache-camel - 使用带有代理和 sslcontext 的 Camel-http4 时出现 HTTP 401 错误
- android - 如何使用 navGraph 范围初始化 viewModel
- python - 将变量从无重新分配到值python
- sql-server - sp_xp_cmdshell_proxy_account 为多个需要在没有系统管理员的情况下运行 xp_cmdshell 的用户提供权限
- machine-learning - 我们如何在没有 bpe 的情况下使用 fairseq 的翻译功能
- amazon-web-services - AWS 控制台到 github