首页 > 解决方案 > 连接张量列表时出现 Keras 错误

问题描述

我正在复制我认为是相同代码的两个版本,但其中一个有效,另一个无效:

from tensorflow.keras.layers import Dense, Input, Lambda, concatenate
from tensorflow.keras.models import Model


inp = Input(shape=(9,))

# Version 1 (works)
out_1 = Dense(1)(Lambda(lambda x: x[:,0:4])(inp))
out_2 = Dense(1)(Lambda(lambda x: x[:,4:9])(inp))
out = concatenate([out_1, out_2])
model = Model(inp, out)
model.compile(...)
model.fit(...) ✓


# Version 2 (doesn't work)
out = concatenate([Dense(1)(Lambda(lambda x: x[:,i:j])(inp)) for i, j in [(0, 4), (4, 9)]]) # concatenating with a list comprehension
model = Model(inp, out)
model.compile(...)
model.fit(...) ✗

错误信息是:

ValueError: Input 0 of layer dense_2 is incompatible with the layer: expected axis -1 of input shape to have value 4 but received input with shape (None, 5)

我不确定这是代码上的错误还是错误,但看起来连接在使用列表推导时混合了张量。帮助赞赏:)

澄清一下,这两种情况的compilefit功能是相同的:

import numpy as np

X = np.random.uniform(0, 1, (100, 9))
Y = np.random.uniform(0, 1, (100, 2))

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X, Y)

标签: pythontensorflowkeras

解决方案


ValueError: Input 0 of layer dense_2 is incompatible with the layer: expected axis -1 of input shape to have value 4 but received input with shape (None, 5)

我已经清楚地告诉你,你使用的连接在错误的轴上。您必须将其连接在第一个轴 (0) 而不是第二个轴上。 value 4表示它正在寻找第二个轴,即 [0:4],而您正在连接 [4:9](即 5)。尝试更改轴参数,一切都会好起来的。

> 
Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 9)]          0                                            
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, 4)            0           input_2[0][0]                    
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, 5)            0           input_2[0][0]                    
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 1)            5           lambda_6[0][0]                   
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 1)            6           lambda_7[0][0]                   
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 2)            0           dense_6[0][0]                    
                                                                 dense_7[0][0]                    
==================================================================================================
Total params: 11
Trainable params: 11
Non-trainable params: 0
__________________________________________________________________________________________________
WARNING:tensorflow:From /home/dtlam26/.local/lib/python3.7/site-packages/tensorflow_core/python/ops/math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Train on 100 samples
100/100 [==============================] - 1s 10ms/sample - loss: 4.1017 - acc: 0.0000e+00

第一个版本将连接 2 层,因此默认情况下它不会错误地通过 axis=-1 连接。


推荐阅读