Keras Sequential API: problem calling model.predict with multiple inputs

Problem description

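I'm having trouble calling model.predict on a model defined with the Sequential API. The same model built with the functional API seems to work fine. Maybe this has nothing to do with Sequential vs. functional and I just have a bug somewhere.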

Using the Keras Sequential API with multiple inputs, I have this toy model:

import numpy as np
from tensorflow import keras

left_branch = keras.models.Sequential()
left_branch.add(keras.layers.Dense(32, input_dim=784))

right_branch = keras.models.Sequential()
right_branch.add(keras.layers.Dense(32, input_dim=784))

merged = keras.layers.Concatenate([left_branch, right_branch])

final_model = keras.models.Sequential()
final_model.add(merged)
final_model.add(keras.layers.Dense(10, activation='softmax'))
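A note on the toy model above: in tf.keras a Sequential model is meant for a plain stack of layers with a single input, and the first positional parameter of keras.layers.Concatenate is axis, so Concatenate([left_branch, right_branch]) does not connect the two branch models at all. Below is a minimal sketch of one way to keep the two Sequential branches, assuming TensorFlow 2.x and its bundled keras: call each branch on its own Input tensor and merge the results with the functional API. The names in_left, in_right and model are illustrative.

```python
import numpy as np
from tensorflow import keras

# Two single-input Sequential branches, each mapping 784 -> 32 features.
# No input shape is given; each branch is built when first called on a tensor.
left_branch = keras.Sequential([keras.layers.Dense(32)])
right_branch = keras.Sequential([keras.layers.Dense(32)])

# One Input tensor per branch. Calling a Sequential model on a tensor
# behaves like calling a layer and returns the branch's output tensor.
in_left = keras.Input(shape=(784,))
in_right = keras.Input(shape=(784,))
merged = keras.layers.Concatenate()([left_branch(in_left), right_branch(in_right)])
out = keras.layers.Dense(10, activation="softmax")(merged)

# The functional Model is what accepts a list of two inputs.
model = keras.Model(inputs=[in_left, in_right], outputs=out)

x1 = np.random.random_sample(size=(1, 784))
x2 = np.random.random_sample(size=(1, 784))
probs = model.predict([x1, x2], verbose=0)
print(probs.shape)  # (1, 10)
```

Because the branches are still ordinary Sequential models, they keep their own weights and can be inspected or reused separately, while the surrounding functional model handles the two inputs.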

When I call model.predict like this:

x1 = np.random.random_sample(size = [1,784])
x2 = np.random.random_sample(size = [1,784])
final_model.predict([x1,x2])

I get the error AttributeError: 'list' object has no attribute 'shape', even though both list items are NumPy arrays. The full error is below.

When I write the same model with the functional API, as below, the error does not occur:

x1 = keras.layers.Input(shape=(784,))
left_branch = keras.layers.Dense(32)(x1)

x2 = keras.layers.Input(shape=(784,))
right_branch = keras.layers.Dense(32)(x2)

merged = keras.layers.Concatenate()([left_branch, right_branch])  # merge the two Dense outputs, not the raw inputs
out = keras.layers.Dense(10, activation='softmax')(merged)

final_model = keras.models.Model(inputs=[x1, x2], outputs=out)
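For reference, the functional model is meant to compute the following, with each input a 1 × 784 row vector (Keras Dense layers compute x W + b, and Dense(32) applies no activation by default). W_i and b_i are illustrative symbols for the layer weights and biases, not Keras attribute names.

```latex
\begin{aligned}
h_1 &= x_1 W_1 + b_1, & W_1 &\in \mathbb{R}^{784 \times 32},\\
h_2 &= x_2 W_2 + b_2, & W_2 &\in \mathbb{R}^{784 \times 32},\\
\hat{y} &= \operatorname{softmax}\big([\,h_1,\ h_2\,]\, W_3 + b_3\big), & W_3 &\in \mathbb{R}^{64 \times 10},
\end{aligned}
```

This assumes Concatenate receives the two Dense outputs h_1 and h_2, giving a 1 × 64 vector. Concatenating the raw inputs x1 and x2 instead would feed a 1568-wide vector to the final layer and leave both 32-unit Dense layers unused.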

Here is the full error output for the Sequential model:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-41-4b45e1bec6dd> in <module>()
      1 x1 = np.random.random_sample(size = [1,784])
      2 x2 = np.random.random_sample(size = [1,784])
----> 3 final_model.predict([x1,x2])

/Library/Python/2.7/site-packages/tensorflow/python/keras/engine/training.pyc in predict(self, x, batch_size, verbose, steps)
   1750     # Validate and standardize user data.
   1751     x, _, _ = self._standardize_user_data(
-> 1752         x, check_steps=True, steps_name='steps', steps=steps)
   1753 
   1754     if context.executing_eagerly():

/Library/Python/2.7/site-packages/tensorflow/python/keras/engine/training.pyc in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split)
    991       x, y = next_element
    992     x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
--> 993                                                      class_weight, batch_size)
    994     return x, y, sample_weights
    995 

/Library/Python/2.7/site-packages/tensorflow/python/keras/engine/training.pyc in _standardize_weights(self, x, y, sample_weight, class_weight, batch_size)
   1027       if not self.inputs:
   1028         is_build_called = True
-> 1029         self._set_inputs(x)
   1030 
   1031     if y is not None:

/Library/Python/2.7/site-packages/tensorflow/python/training/checkpointable/base.pyc in _method_wrapper(self, *args, **kwargs)
    424     self._setattr_tracking = False  # pylint: disable=protected-access
    425     try:
--> 426       method(self, *args, **kwargs)
    427     finally:
    428       self._setattr_tracking = previous_value  # pylint: disable=protected-access

/Library/Python/2.7/site-packages/tensorflow/python/keras/engine/training.pyc in _set_inputs(self, inputs, training)
   1219         self.build(input_shape=input_shape)
   1220       else:
-> 1221         input_shape = (None,) + inputs.shape[1:]
   1222         self.build(input_shape=input_shape)
   1223     if context.executing_eagerly():

AttributeError: 'list' object has no attribute 'shape'
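The last two frames point at the cause. Because the Sequential model has no inputs defined yet (if not self.inputs:), Keras calls self._set_inputs(x), and on the else branch shown it builds the input shape from inputs.shape[1:], which assumes x is a single array. The list [x1, x2] has no shape attribute even though each of its elements does. The sketch below reproduces just that expression with NumPy; infer_input_shape is a hypothetical helper that mirrors line 1221, not a Keras function.

```python
import numpy as np


def infer_input_shape(inputs):
    # Hypothetical helper mirroring line 1221 of the traceback:
    # it assumes `inputs` is one array that has a .shape attribute.
    return (None,) + inputs.shape[1:]


x1 = np.random.random_sample(size=(1, 784))
x2 = np.random.random_sample(size=(1, 784))

print(infer_input_shape(x1))  # (None, 784): a single array works

try:
    infer_input_shape([x1, x2])  # a plain Python list has no .shape
except AttributeError as err:
    print(err)  # 'list' object has no attribute 'shape'
```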

Tags: python, tensorflow, keras

Solution

