python - Tensorflow - 对多个图像进行批量预测
问题描述
我有一个faces
列表,其中列表的每个元素都是一个形状为 (1, 224, 224, 3) 的 numpy 数组,即人脸图像。我有一个模型,其输入形状为(None, 224, 224, 3)
,输出形状为(None, 2)
.
现在我想对faces
列表中的所有图像进行预测。当然,我可以遍历列表并逐个获得预测,但我想将所有图像作为一批处理,只使用一次调用来model.predict()
更快地获得结果。
如果我像现在这样直接传递面孔列表(最后的完整代码),我只会得到第一张图像的预测。
print(f"{len(faces)} faces found")
print(faces[0].shape)
maskPreds = model.predict(faces)
print(maskPreds)
输出:
3 faces found
(1, 224, 224, 3)
[[0.9421933 0.05780665]]
但是maskPreds
对于 3 个图像应该是这样的:
[[0.9421933 0.05780665],
[0.01584494 0.98415506],
[0.09914105 0.9008589 ]]
完整代码:
from tensorflow.keras.models import load_model
from cvlib import detect_face
import cv2
import numpy as np
def detectAllFaces(frame):
dets = detect_face(frame)
boxes = dets[0]
confidences = dets[1]
faces = []
for box, confidence in zip(boxes, confidences):
startX, startY, endX, endY = box
cv2.rectangle(frame, (startX, startY), (endX, endY), (0, 255, 0), 1)
face = frame[startY:endY, startX:endX]
face = cv2.resize(face, (224, 224))
face = np.expand_dims(face, axis=0) # convert (224,224,3) to (1,224,224,3)
faces.append(face)
return faces, frame
model = load_model("mask_detector.model")
vs = cv2.VideoCapture(0)
model.summary()
while True:
ret, frame = vs.read()
if not ret:
break
faces, frame = detectAllFaces(frame)
if len(faces):
print(f"{len(faces)} faces found")
maskPreds = model.predict(faces) # <==========
print(maskPreds)
cv2.imshow("Window", frame)
if cv2.waitKey(1) == ord('q'):
break
cv2.destroyWindow("Window")
vs.release()
注意:如果我不将每个图像从 (224, 224, 3) 转换为 (1, 224, 224, 3),则 tensorflow 会抛出错误,指出输入尺寸不匹配。
ValueError: Error when checking input: expected input_1 to have 4 dimensions, but got array with shape (224, 224, 3)
如何实现批量预测?
解决方案
在这种情况下,函数的输入model.predict()
需要作为形状为(N, 224, 224, 3)的 numpy 数组给出,其中 N 是输入图像的数量。
为此,我们可以将N个大小为(1, 224, 224, 3)的单独 numpy 数组堆叠成一个大小为(N, 224, 224, 3)的数组,然后将其传递给函数。model.predict()
maskPreds = model.predict(np.vstack(faces))
推荐阅读
- windows - 如何将子文件夹添加到 Windows 的路径中?
- javascript - 用导航反应原生问题”
- typescript - 如何使用 Typescript 从 Sharepoint 中的选择站点列中获取数据?
- python - 如何巧妙地将多个不同的值分配给多个不同的变量?
- tslint - tslint 空格规则是否适用于进口
- node.js - 散列字符串与散列 UInt8Array
- sql - 如果存在具有相同 ID 和特定类型的另一行,则省略行
- python - 在pyspark中使用date_add从日期列中减去一个int列
- elixir - 为什么 Elixir 运算符“in”在 for 循环中不起作用?
- amazon-web-services - 有没有办法完成时间以及 Amazon Elastic Transcoder 作业完成了多少处理