pytorch - 高分辨率图像的对象检测推理在 cpu 上花费大量时间
问题描述
我已经在图像大小为 224 的帕斯卡数据集上训练了 ML 模型,但是在推断新图像时(有些是高分辨率的,有些比帕斯卡图像的分辨率略高),我得到了错误pil2tensor()
@app.route('/analyze', methods=['POST'])
async def analyze(request):
data = await request.form()
img_bytes = await (data['file'].read())
img = open_image(BytesIO(img_bytes))
t_img= PIL.Image.open(BytesIO(img_bytes)).convert('RGB')
t_img = pil2tensor(t_img, np.float32)
t_img = t_img.div_(255)
with torch.no_grad():
# test_output = learn.model.eval()(t_img.unsqueeze_(0).cuda())
test_output = learn.model.eval()(t_img.unsqueeze_(0))
对于小尺寸图像(比如一些来自谷歌的低分辨率图像),ML 模型能够在几秒钟内正确地进行推理,但对于分辨率稍高的图像,大约需要 20-40 分钟!!!
解决方案
解决了这个问题,这是正确的代码:
@app.route('/analyze', methods=['POST'])
async def analyze(request):
data = await request.form()
img_bytes = await (data['file'].read())
img = open_image(BytesIO(img_bytes))
localtime = _utc_to_local(datetime.utcnow())
current_dir = os.path.dirname(__file__)
media_path = os.path.join(current_dir, "media")
media_path_original = os.path.join(media_path, "original")
media_path_processed = os.path.join(media_path, "processed")
img_path = os.path.join(media_path_original, localtime+".jpg")
img.save(img_path)
verify_image(Path(img_path), idx=0, delete=False, max_size=600, dest=Path(media_path_processed))
processed_img_path = os.path.join(media_path_processed, localtime+".jpg")
processed_img = open_image(processed_img_path) if os.path.exists(processed_img_path) else img
processed_img.refresh()
with torch.no_grad():
test_output = learn.model.eval()(processed_img.data.unsqueeze(0))
predictions = show_preds(processed_img, test_output, 0, detect_thresh=0.4, classes=t_classes)
img_height, img_width = processed_img.size
return JSONResponse({'predictions': predictions, 'img_height': img_height, 'img_width': img_width})
推荐阅读
- python - 在不离开页面的情况下更新结帐日志
- scheme - 什么时候使用define,什么时候使用let in racket
- c# - 如何从 C# 中传递的对象中检索单个值?
- javascript - 如何获取两个日期时间之间的日期数组
- docker - 副本会减少单节点 Kubernetes 集群上的流量吗?
- javascript - 如何使用云功能将子级添加到 Firebase 实时数据库中的列表?
- json - DynamoDB JSON 响应解析垂直打印
- r - 如何删除 R 中的 na 并使低于值上升
- c# - 如何设置条件 C# if 语句和 LINQ where 语句?
- ruby-on-rails - 如何降级gemfile中的gem?