python - keras.Model 中 Input 的附加维度从何而来?
问题描述
当我定义一个模型时:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
input_shape = (20,20)
input = tf.keras.Input(shape=input_shape)
nn = layers.Flatten()(input)
nn = layers.Dense(10)(nn)
output = layers.Activation('sigmoid')(nn)
model = tf.keras.Model(inputs=input, outputs=output)
为什么我需要在我的实际输入中添加另一个维度:
actual_input = np.ones((1,20,20))
prediction = model.predict(actual_input)
为什么我不能这样做actual_input = np.ones((20,20))
?
编辑:
在文档中,它说了一些关于batchsize 的信息。这个batchsize 是否与我的问题有关?如果是这样,当我想用我的模型进行预测时,我为什么需要它?谢谢你的帮助。
解决方案
在Keras
( TensorFlow
) 中,无法预测单个输入。因此,即使您只有一个示例,您也需要将其添加batch_axis
到其中。
实际上,在这种情况下,批处理大小为 1,因此批处理轴。
这就是构建方式TensorFlow
和Keras
构建方式,即使对于单个预测,您也需要添加批处理轴(批处理大小为 1 == 1 个单个示例)。
您可以使用np.expand_dims(input,axis=0)
或tf.expand_dims(input,axis=0)
将您的输入转换为适合预测的格式。
推荐阅读
- reactjs - 将反应站点发布到 github.io 时出错
- python - “str”对象没有属性“isocalendar”
- c# - SQL 语句上的空值包括对行的 SELECT 查询的结果(显示您发送的 SQL 查询和该查询的结果)
- javascript - Google API 如何从结果中获取一个变量
- themes - 流利的用户界面反应团队的主题
- python - Django Rest Framework - 在 Viewset 视图列表 API 上添加分页(限制对象数),而无需 Django 模型类
- php - 如何在同一页面上为每个 MySQL 用户提供自己的 FullCalendar?
- http - 为什么 webrtc 需要异步连接发送调用配置?
- python - 使用带有 Range 的 Header 后为空 content_length
- google-cloud-platform - 运行 GCP 代码时出现问题,它找不到表,因此不让运行