python - Keras model.predict 分类标签的错误
问题描述
我正在尝试查看预测结果并使用 model.predict 函数打印它们,但出现错误:
ValueError: Error when checking model : the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 2 array(s), but instead got the following list of 1 arrays: [array([[array([ 0,....
我有多个输入,两者都是嵌入的。当我嵌入一个输入时,此代码以前可以工作。
for i in range(100):
prediction_result = model.predict(np.array([test_text[i], test_posts[i]]))
predicted_label = labels_name[np.argmax(prediction_result)]
print(text_data.iloc[i][:100], "")
print('Actual label:' + tags_test.iloc[i])
print("Predicted label: " + predicted_label + "\n")
test_text 和 test_posts 是 pad_sequences 的结果。它们在数组中,test_text 的形状为 100,test_posts 的形状为 1。labels_name 是标签的名称。我在第二行有错误;
prediction_result = model.predict(np.array([test_text[i], test_posts[i]]))
错误:
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps)
1815 x = _standardize_input_data(x, self._feed_input_names,
1816 self._feed_input_shapes,
-> 1817 check_batch_axis=False)
1818 if self.stateful:
1819 if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0:
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
84 'Expected to see ' + str(len(names)) + ' array(s), '
85 'but instead got the following list of ' +
---> 86 str(len(data)) + ' arrays: ' + str(data)[:200] + '...')
87 elif len(names) > 1:
88 raise ValueError(
ValueError: Error when checking model : the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 2 array(s), but instead got the following list of 1 arrays: [array([[array([ 0, 0, ...
它看起来像一个简单的解决方案,但我找不到它。谢谢您的帮助。
解决方案
该模型需要两个数组,而您传递的是一个 numpy 数组。
prediction_result = model.predict([test_text.values[i].reshape(-1,100), test_posts.values[i].reshape(-1,1)])
删除调用 numpy.array 方法,你的错误就会消失。
更新:
没有必要使用for loop
.
prediction_result = model.predict([test_text.values.reshape(-1,100), test_posts.values.reshape(-1,1)])
这可以做你想做的事。prediction_result 现在的形状为(number rows in test_text,number of outputs)
推荐阅读
- android - Android 10 及更高版本中的 Wi-Fi 断开和重新连接
- android - 包 com.mikepenz.community_material_typeface_library 不存在
- tableau-desktop - 连接到 Tableau 中的存储过程时出错
- ios - iOS:使网站只能嵌入到我的应用程序中
- android - Android:如何从依赖图中删除 org.bouncycastle:bcpkix-jdk15on:1.56 并使用 org.bouncycastle:bcpkix-jdk15on:1.65
- javascript - 在 mac 上更新软件后,我无法运行 sudo npm i -g expo cli 命令
- python - 在 discord.py 中踢/禁止人
- spss-modeler - SPSS Modeler 实时预测服务
- c# - 在单独的类中删除带有按钮的文本框
- reactive-programming - 通过netty服务器进行keycloak身份验证的spring webflux