python-3.x - 将 Torch Hub 的 SSD 推断图像保存在输出目录中
问题描述
我正在使用Pytorch SSD,它在来自 Torch HUB 的 COCO 数据集上加载预训练模型。以 API 格式修改代码以获取一些图像并检测其中的对象。
尝试/output
使用 matplotlib 的.savefig()
方法保存每个推理图像输出但出现错误:
import torch
import matplotlib.patches as patches
from matplotlib import pyplot as plt
class ObjectDetector:
def __init__(self):
self.precision = 'fp32'
self.detect_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math=self.precision)
self.utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils')
def process(self):
self.fetch_images()
self.create_model()
self.display_detections()
def fetch_images(self):
"""To get the images from a website"""
self.images = ['http://images.cocodataset.org/val2017/000000397133.jpg','http://images.cocodataset.org/val2017/000000037777.jpg','http://images.cocodataset.org/val2017/000000252219.jpg']
return self.images
def create_model(self):
self.detect_model.to('cuda')
self.detect_model.eval()
self.inputs = [self.utils.prepare_input(uri) for uri in self.images]
tensor = self.utils.prepare_tensor(self.inputs, self.precision == 'fp16')
with torch.no_grad():
detections_batch = self.detect_model(tensor)
results_per_input = self.utils.decode_results(detections_batch)
self.best_results_per_input = [self.utils.pick_best(results, 0.40) for results in results_per_input]
self.classes_to_labels = self.utils.get_coco_object_dictionary()
return self.best_results_per_input, self.classes_to_labels
def display_detections(self):
output_dir = "../data/vision/ssd/output"
for image_idx in range(len(self.best_results_per_input)):
fig, ax = plt.subplots(figsize=(20, 10))
# Show original, denormalized image...
image = self.inputs[image_idx] / 2 + 0.5
ax.imshow(image)
# ...with detections
bboxes, classes, confidences = self.best_results_per_input[image_idx]
for idx in range(len(bboxes)):
left, bot, right, top = bboxes[idx]
x, y, w, h = [val * 300 for val in [left, bot, right - left, top - bot]]
rect = patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rect)
ax.text(x, y, "{} {:.0f}%".format(self.classes_to_labels[classes[idx] - 1], confidences[idx] * 100), bbox=dict(facecolor='white', alpha=0.5))
plt.savefig(output_dir + str(image) + '.jpg')
plt.show()
if __name__== '__main__':
det = ObjectDetector()
det.process()
del det
上面的代码抛出以下错误:
---------------------------------------------------------------------------
FileNotFoundError Traceback (most recent call last)
<ipython-input-9-acbe775772c1> in <module>
63 if __name__== '__main__':
64 det = ObjectDetector()
---> 65 det.process()
66
67 del det
<ipython-input-9-acbe775772c1> in process(self)
14 self.fetch_images()
15 self.create_model()
---> 16 self.display_detections()
17
18
<ipython-input-9-acbe775772c1> in display_detections(self)
57 ax.text(x, y, "{} {:.0f}%".format(self.classes_to_labels[classes[idx] - 1], confidences[idx] * 100), bbox=dict(facecolor='white', alpha=0.5))
58
---> 59 plt.savefig(output_dir + str(image) + '.jpg')
60 plt.show()
61
~/venv38/lib/python3.8/site-packages/matplotlib/pyplot.py in savefig(*args, **kwargs)
721 def savefig(*args, **kwargs):
722 fig = gcf()
--> 723 res = fig.savefig(*args, **kwargs)
724 fig.canvas.draw_idle() # need this if 'transparent=True' to reset colors
725 return res
~/venv38/lib/python3.8/site-packages/matplotlib/figure.py in savefig(self, fname, transparent, **kwargs)
2201 self.patch.set_visible(frameon)
2202
-> 2203 self.canvas.print_figure(fname, **kwargs)
2204
2205 if frameon:
~/venv38/lib/python3.8/site-packages/matplotlib/backend_bases.py in print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, **kwargs)
2117
2118 try:
-> 2119 result = print_method(
2120 filename,
2121 dpi=dpi,
~/venv38/lib/python3.8/site-packages/matplotlib/cbook/deprecation.py in wrapper(*args, **kwargs)
356 f"%(removal)s. If any parameter follows {name!r}, they "
357 f"should be pass as keyword, not positionally.")
--> 358 return func(*args, **kwargs)
359
360 return wrapper
~/venv38/lib/python3.8/site-packages/matplotlib/backends/backend_agg.py in print_jpg(self, filename_or_obj, dryrun, pil_kwargs, *args, **kwargs)
597 pil_kwargs.setdefault("quality", rcParams["savefig.jpeg_quality"])
598 pil_kwargs.setdefault("dpi", (self.figure.dpi, self.figure.dpi))
--> 599 return background.save(
600 filename_or_obj, format='jpeg', **pil_kwargs)
601
~/venv38/lib/python3.8/site-packages/PIL/Image.py in save(self, fp, format, **params)
2153 fp = builtins.open(filename, "r+b")
2154 else:
-> 2155 fp = builtins.open(filename, "w+b")
2156
2157 try:
FileNotFoundError: [Errno 2] No such file or directory: '../data/vision/ssd/output[[[0.1050852 0.07895297 0.08367175]\n [0.31462591 0.31466424 0.32513717]\n [0.28277484 0.25506944 0.23508735]\n ...\n [0.42182888 0.27386384 0.07784647]\n [0.67421166 0.57844825 0.39889071]\n [0.554919 0.33316082 0.09618731]]\n\n [[0.05228582 0.03646781 0.0400054 ]\n [0.06949542 0.06235639 0.05692344]\n [0.25959795 0.18080175 0.18654409]\n ...\n [0.60428691 0.30419598 0.06168084]\n [0.62523846 0.37480789 0.15464491]\n [0.40595506 0.21335363 0.0789785 ]]\n\n [[0.10904118 0.11286539 0.09207947]\n [0.0804173 0.04945466 0.03713621]\n [0.24569849 0.12457102 0.1002835 ]\n ...\n [0.8473525 0.49805938 0.01584464]\n [0.62128949 0.34659926 0.04259144]\n [0.60784509 0.39757653 0.1146472 ]]\n\n ...\n\n [[0.54990582 0.37598903 0.20369267]\n [0.5526588 0.38010985 0.19625383]\n [0.56226779 0.38371096 0.20185737]\n ...\n [0.29863339 0.2165191 0.14226269]\n [0.30894688 0.23059896 0.16393229]\n [0.31879315 0.21973148 0.16671452]]\n\n [[0.54124921 0.37518263 0.19985079]\n [0.54947818 0.38385507 0.19607851]\n [0.54889008 0.37478852 0.18892228]\n ...\n [0.29478525 0.22002212 0.15326277]\n [0.31478406 0.23243116 0.16062237]\n [0.30818757 0.21890863 0.14786195]]\n\n [[0.53892612 0.37097071 0.1888549 ]\n [0.54983966 0.38421659 0.19571689]\n [0.55770917 0.38090676 0.18950984]\n ...\n [0.316164 0.24439232 0.16849774]\n [0.32127783 0.23892493 0.16441515]\n [0.30470566 0.21542674 0.14437993]]].jpg'
解决方案
plt.savefig()
可以这样保存。此外,可以import uuid
使用生成随机唯一字符串。
def display_detections(self):
output_dir = "../data/vision/ssd/output"
for image_idx in range(len(self.best_results_per_input)):
fig, ax = plt.subplots(figsize=(20, 10))
# Show original, denormalized image...
image = self.inputs[image_idx] / 2 + 0.5
ax.imshow(image)
# ...with detections
bboxes, classes, confidences = self.best_results_per_input[image_idx]
for idx in range(len(bboxes)):
left, bot, right, top = bboxes[idx]
x, y, w, h = [val * 300 for val in [left, bot, right - left, top - bot]]
rect = patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rect)
ax.text(x, y, "{} {:.0f}%".format(self.classes_to_labels[classes[idx] - 1], confidences[idx] * 100), bbox=dict(facecolor='white', alpha=0.5))
plt.axis('off')
plt.savefig(output_dir + "/" + "Image" + "_" + str(uuid.uuid4()))
plt.show()
推荐阅读
- java - 点击后从 RecyclerView 获取对象
- python - 为用户创建设置选项的优雅解决方案
- three.js - 如何将gltf素材更改为卡通素材?
- slack-api - 继续为新的松弛应用程序获取 users.info 的 invalid_auth(甚至从清单中复制!)
- javascript - 如何使用条件返回不同的代码块
- java - --rerun-tasks 和 --refresh-dependencies 之间的区别
- python - 当我将“django_social_share”添加到我安装的应用程序时,我得到“内部服务器错误”
- python-3.x - python将Dataframe列文本传输到枕头图像
- go - 返回一个空字符串会导致 GO 出现恐慌
- sql - 逐字段比较 2 个表数据并仅提取不匹配的行 Oracle