python - Keras 错误:[0, 0] 中的预期大小 [1],但得到了 1
问题描述
我正在尝试seq2seq
在 Keras 的更大模型中构建解码器,但是当我运行 fit 函数时,我不断收到以下错误。否则模型构建得很好。
InvalidArgumentError: Expected size[1] in [0, 0], but got 1
[[Node: lambda_2/Slice = Slice[Index=DT_INT32, T=DT_FLOAT,
_device="/job:localhost/replica:0/task:0/device:CPU:0"](lambda_1/Slice,
metrics/acc/Const, lambda_2/Slice/size)]]
lambda_x/Slice
似乎是指循环中的 lambda 函数。
我的模型有 4 个 shape 输入(N, 11), (N, 3), (N, 11), (N, 3)
并输出 shape 的 softmax 分布(N, 11, 1163)
。
下面是我的解码器代码,这是使用拆分器层的地方:
def _decoder_serial_input(self, encoder_states, state_h, state_c):
"""
Compute one-by-one input to decoder, taking output from previous time-step as input
:param encoder_states: All the encoder states
:param state_h: starting hidden state
:param state_c: starting cell state
:return: Concatenated output which is shape = (N, Timestep, Input dims)
"""
all_outputs = []
states = [state_h, state_c]
inputs = self.decoder_inputs # Shape = N x num_timestep
repeat = RepeatVector(1, name="decoder_style")
conc_1 = Concatenate(axis=-1, name="concatenate_decoder")
conc_att = Concatenate(axis=-1, name="concatenate_attention")
for t in range(self.max_timestep):
# This slices the input. -1 is to accept everything in that dimension
inputs = Lambda(lambda x: K.slice(x, start=[0, t], size=[-1, 1]))(inputs)
embedding_output = self.embedding_decoder(inputs)
style_labels = repeat(self.decoder_style_label)
concat = conc_1([embedding_output, style_labels]) # Join to style label
decoder_output_forward, state_h, state_c = self.decoder(concat, initial_state=states)
if self.attention:
context, _ = self._one_step_attention(encoder_states, state_h) # Size of latent dims
decoder_output_forward = conc_att([context, decoder_output_forward])
outputs = self.decoder_softmax_output(decoder_output_forward) # Shape = (N, 1, input dims)
all_outputs.append(outputs)
states = [state_h, state_c]
return Concatenate(axis=1, name="conc_dec_forward")(all_outputs)
有谁知道我为什么会收到这个错误?谢谢。
解决方案
我解决了这个问题。问题是我将Lambda
图层的输出设置inputs
为错误的变量。这将输入张量的形状更改为 lambda 层。在第一次迭代中,它是(N, 11)
,但在循环的后续迭代中,它变成了(N, 1)
,这导致了错误。
推荐阅读
- mysql - MySQL - 按分组结果分组
- amazon-web-services - 从 S3 存储桶下载数百万个文件
- javascript - Vue.js 应用程序有时仅在移动设备上按住按钮一会儿时才会响应
- python - 列表理解从元组列表中提取多个字段
- highcharts - Highcharts 向下钻取 如何保持向下钻取大小相同
- c++ - 将图像添加到资产文件夹 UWP C++
- android - 生成的 apk 和调试模式的 Gson 不同行为
- content-management-system - directus cms如何加入
- r - 从 Shinyalert 回调重新运行响应式
- javascript - Javascript 对象内部自引用