python - 带有 2D numpy 数组输入的 Tensorflow Keras Conv2D 错误
问题描述
我想使用 2D numpy 数组作为输入来训练 CNN,但我收到此错误:ValueError: Error when checking input: expected conv2d_input to have 4 dimensions, but got array with shape (21, 21)
.
我的输入确实是一个 21x21 的 numpy 浮点数组。网络的第一层被定义为Conv2D(32, (3, 3), input_shape=(21, 21, 1))
匹配输入数组的形状。
我发现了一些类似的问题,但与二维输入数组无关,它们主要处理图像。根据文档,Conv2D 期望输入包含 的 4D 张量(samples, channels, rows, cols)
,但我找不到任何解释这些值含义的文档。与图像输入有关的类似问题建议使用 重塑输入数组np.ndarray.reshape()
,但在尝试这样做时,我收到输入错误。
如何在这样的输入数组上训练 CNN?应该input_shape
是不同大小的元组吗?
解决方案
您当前的 numpy 数组具有维度(21, 21)
。但是,TensorFlow 期望输入张量具有格式(batch_size, height, width, channels)
或BHWC的维度,这意味着您需要将 numpy 输入数组转换为4维(从当前的2维)。一种方法如下:
input = np.expand_dims(input, axis=0)
input = np.expand_dims(input, axis=-1)
现在,numpyinput
数组具有维度:(1, 21, 21, 1)
可以传递给 TF Conv2D 操作。
希望这可以帮助!:)
推荐阅读
- reactjs - 如何将 console.log 消息从 Firefox/Chrome webbrowser 控制台重定向到终端
- excel - 更改下拉按钮的组合框颜色
- c++ - 如何拆分嵌入在 C++ 分隔符中的字符串?
- python - 如何将 Xaxis(次)的格式更改为其他格式?Python - 熊猫 - Matplotlib
- amazon-ses - 代表用户发送电子邮件
- node.js - 有没有办法在 Node.js 中获取已加载模块的文件路径?
- javascript - Aws4 签署 S3 PUT 请求
- powershell - 将(年份)保存在重命名文件夹中
- c# - 有没有一种更简洁的方法可以像字典一样访问列表?
- java - Apache Spark 在完全分布式模式下对 Executors 采取行动