python - 我在 Keras 中正确理解了 batch_size 吗?
问题描述
我正在使用 Keras 的内置功能inception_resnet_v2
来训练 CNN 来识别图像。在训练模型时,我有一个 numpy 数据数组作为输入,输入形状为 (1000, 299, 299, 3),
model.fit(x=X, y=Y, batch_size=16, ...) # Output shape `Y` is (1000, 6), for 6 classes
起初,当试图预测时,我传入了一个形状(299、299、3)的图像,但得到了错误
ValueError:检查输入时出错:预期 input_1 有 4 个维度,但得到了形状为 (299、299、3) 的数组
我用以下方法重塑了我的输入:
x = np.reshape(x, ((1, 299, 299, 3)))
现在,当我预测时,
y = model.predict(x, batch_size=1, verbose=0)
我没有收到错误。
我想确保我batch_size
在训练和预测中都能正确理解。我的假设是:
1) 使用model.fit
,Kerasbatch_size
从输入数组中获取元素(在这种情况下,它一次可以处理我的 1000 个示例 16 个样本)
2) 使用model.predict
,我应该将输入重塑为单个 3D 数组,并且我应该明确设置batch_size
为 1。
这些是正确的假设吗?
此外,向模型提供训练数据是否会更好(甚至可能),这样就不需要在预测之前进行这种重塑?谢谢你帮助我学习这个。
解决方案
不,你的想法错了。batch_size
指定一次通过网络“转发”多少数据示例(通常使用 GPU)。
默认情况下,此值设置为32
insidemodel.predict
方法,但您可以指定其他方式(就像您对 所做的那样batch_size=1
)。由于这个默认值,你得到一个错误:
ValueError:检查输入时出错:预期 input_1 有 4 个维度,但得到了形状为 (299、299、3) 的数组
您不应该以这种方式重塑您的输入,而是为它提供正确的批量大小。
说,对于默认情况,您将传递一个 shape 数组(32, 299, 299, 3)
,类似于不同的batch_size
,例如,使用batch_size=64
此函数需要您传递 shape 的输入(64, 299, 299, 3
。
编辑:
看来你需要将你的单个样本重新塑造成一批。我建议您使用np.expand_dims
以提高代码的可读性和可移植性,如下所示:
y = model.predict(np.expand_dims(x, axis=0), batch_size=1)
推荐阅读
- sql - 从日期和时间计算时间
- c++ - 将 Gtk::ListStore 添加到 Gtk::TreeView 并将行添加到 ListStore (gtkmm)
- python - 按多个条件过滤数据框并在熊猫中创建新列
- c# - IBM MQ 服务器对消息进行编码
- mariadb - Mariadb insert error code 4025 约束失败问题
- angular - 在服务(ngOnInit 或构造函数内部)中获取数据的确切位置?
- javascript - SPARX 企业架构师。使用脚本为带有格式化文本的图表添加注释
- python-3.x - 使用python根据学生的百分比识别学生的详细信息
- angular - 单元测试中可观察的角度未覆盖线:
- sql - Oracle(SQL 语句)- 将 CASE 与 WHERE CLAUSE 结合使用