Simple Keras function (K.function) not working

Problem description

I'm using the latest version of Keras/TensorFlow, but I get this error:

predict_fcn = K.function(model.inputs, model.outputs)
predict_fcn(input_values_test)
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-17-8d6b6ae9c940> in <module>()
----> 1 predict_fcn(input_values_test)

~/Python_Libraries/keras/keras/backend/tensorflow_backend.py in __call__(self, inputs)
2664                 return self._legacy_call(inputs)
2665 
-> 2666             return self._call(inputs)
2667         else:
2668             if py_any(is_tensor(x) for x in inputs):

~/Python_Libraries/keras/keras/backend/tensorflow_backend.py in _call(self, inputs)
2634                                 symbol_vals,
2635                                 session)
-> 2636         fetched = self._callable_fn(*array_vals)
2637         return fetched[:len(self.outputs)]
2638 

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __call__(self, *args)
1452         else:
1453           return tf_session.TF_DeprecatedSessionRunCallable(
-> 1454               self._session._session, self._handle, args, status, None)
1455 
1456     def __del__(self):

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    517             None, None,
    518             compat.as_text(c_api.TF_Message(self.status.status)),
--> 519             c_api.TF_GetCode(self.status.status))
    520     # Delete the underlying status object from memory otherwise it stays alive
    521     # as there is a reference to status from this from the traceback due to

InvalidArgumentError: ConcatOp : Expected concatenating dimensions in the range [-2, 2), but got 2
    [[Node: Merge_Embeddings/concat = ConcatV2[N=8, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:GPU:0"](0_Embedding/embedding_lookup, 1_Embedding/embedding_lookup, 2_Embedding/embedding_lookup, 3_Embedding/embedding_lookup, 4_Embedding/embedding_lookup, 5_Embedding/embedding_lookup, 6_Embedding/embedding_lookup, Embedding_Average/Mean, Merge_Embeddings/concat/axis)]]
    [[Node: Outputs/BiasAdd/_989 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1274_Outputs/BiasAdd", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

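However, simply calling model.predict(input_values_test) works perfectly. Reading the error message, it seems the concatenation behaves differently when the model is run as a function instead of as a model. Why would that be?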
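One plausible explanation, which rests on assumptions the question does not confirm: model.predict() standardizes its inputs before running the graph, and in Keras 2.x that step expands 1-D arrays of shape (N,) to (N, 1). A raw K.function passes the arrays through unchanged. If each model input is a 1-D array of integer IDs, the embedding lookups would then produce rank-2 tensors instead of rank-3 ones, and concatenating them on axis 2 would fail. That matches the "[-2, 2)" range in the error. The NumPy sketch below mimics that shape flow. The lookup table, embed, merge_embeddings, standardize and the sample batch are hypothetical stand-ins, not the asker's model.

```python
import numpy as np

# Hypothetical stand-in for the model's Embedding layers: one shared lookup
# table of shape (VOCAB, DIM). Assumption: each model input holds integer IDs.
VOCAB, DIM = 10, 4
table = np.arange(VOCAB * DIM, dtype=np.float32).reshape(VOCAB, DIM)


def embed(ids):
    # Like an embedding lookup: output shape is ids.shape + (DIM,).
    return table[ids]


def merge_embeddings(inputs):
    # Mirrors concatenating the per-input embeddings on axis=2,
    # like the Merge_Embeddings/concat node in the traceback.
    return np.concatenate([embed(x) for x in inputs], axis=2)


def standardize(inputs):
    # Assumption: mimics Keras 2.x input standardization inside predict(),
    # which turns 1-D arrays of shape (N,) into (N, 1).
    return [np.expand_dims(x, 1) if x.ndim == 1 else x for x in inputs]


batch = [np.array([1, 2, 3]), np.array([4, 5, 6])]  # two 1-D inputs, N = 3

# Raw arrays, as K.function receives them: embeddings are rank 2, so axis 2 is out of range.
try:
    merge_embeddings(batch)
except ValueError as err:  # NumPy's AxisError subclasses ValueError
    print("raw inputs:", err)

# Standardized arrays, as predict() would pass them: embeddings are rank 3.
print("standardized inputs:", merge_embeddings(standardize(batch)).shape)  # -> (3, 1, 8)
```

If this is the cause, and assuming input_values_test is a list of NumPy arrays, calling predict_fcn([np.expand_dims(x, 1) if x.ndim == 1 else x for x in input_values_test]) would likely behave the same as model.predict.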

Tags: python, tensorflow, keras
