首页 > 解决方案 > 将 Torch Hub 的 SSD 推断图像保存在输出目录中

问题描述

我正在使用Pytorch SSD,它在来自 Torch HUB 的 COCO 数据集上加载预训练模型。以 API 格式修改代码以获取一些图像并检测其中的对象。

尝试/output使用 matplotlib 的.savefig()方法保存每个推理图像输出但出现错误:

import torch
import matplotlib.patches as patches
from matplotlib import pyplot as plt


class ObjectDetector:

    def __init__(self):
        self.precision = 'fp32'
        self.detect_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math=self.precision)
        self.utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils')

    def process(self):
        self.fetch_images()
        self.create_model()
        self.display_detections()
        

    def fetch_images(self):
        """To get the images from a website"""
        self.images = ['http://images.cocodataset.org/val2017/000000397133.jpg','http://images.cocodataset.org/val2017/000000037777.jpg','http://images.cocodataset.org/val2017/000000252219.jpg']

        return self.images

    def create_model(self):
        self.detect_model.to('cuda')
        self.detect_model.eval()

        self.inputs = [self.utils.prepare_input(uri) for uri in self.images]
        tensor = self.utils.prepare_tensor(self.inputs, self.precision == 'fp16')

        with torch.no_grad():
            detections_batch = self.detect_model(tensor)

        results_per_input = self.utils.decode_results(detections_batch)
        self.best_results_per_input = [self.utils.pick_best(results, 0.40) for results in results_per_input]
        self.classes_to_labels = self.utils.get_coco_object_dictionary()

        return self.best_results_per_input, self.classes_to_labels

    def display_detections(self):
        
        output_dir = "../data/vision/ssd/output"
        
        for image_idx in range(len(self.best_results_per_input)):
            fig, ax = plt.subplots(figsize=(20, 10))
            # Show original, denormalized image...
            image = self.inputs[image_idx] / 2 + 0.5
            ax.imshow(image)
            # ...with detections
            bboxes, classes, confidences = self.best_results_per_input[image_idx]
            for idx in range(len(bboxes)):
                left, bot, right, top = bboxes[idx]
                x, y, w, h = [val * 300 for val in [left, bot, right - left, top - bot]]
                rect = patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
                ax.add_patch(rect)
                ax.text(x, y, "{} {:.0f}%".format(self.classes_to_labels[classes[idx] - 1], confidences[idx] * 100), bbox=dict(facecolor='white', alpha=0.5))
        
            plt.savefig(output_dir + str(image) + '.jpg')
        plt.show()


if __name__== '__main__':
    det = ObjectDetector()
    det.process()
    del det

上面的代码抛出以下错误:

---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
<ipython-input-9-acbe775772c1> in <module>
     63 if __name__== '__main__':
     64     det = ObjectDetector()
---> 65     det.process()
     66 
     67     del det

<ipython-input-9-acbe775772c1> in process(self)
     14         self.fetch_images()
     15         self.create_model()
---> 16         self.display_detections()
     17 
     18 

<ipython-input-9-acbe775772c1> in display_detections(self)
     57                 ax.text(x, y, "{} {:.0f}%".format(self.classes_to_labels[classes[idx] - 1], confidences[idx] * 100), bbox=dict(facecolor='white', alpha=0.5))
     58 
---> 59             plt.savefig(output_dir + str(image) + '.jpg')
     60         plt.show()
     61 

~/venv38/lib/python3.8/site-packages/matplotlib/pyplot.py in savefig(*args, **kwargs)
    721 def savefig(*args, **kwargs):
    722     fig = gcf()
--> 723     res = fig.savefig(*args, **kwargs)
    724     fig.canvas.draw_idle()   # need this if 'transparent=True' to reset colors
    725     return res

~/venv38/lib/python3.8/site-packages/matplotlib/figure.py in savefig(self, fname, transparent, **kwargs)
   2201             self.patch.set_visible(frameon)
   2202 
-> 2203         self.canvas.print_figure(fname, **kwargs)
   2204 
   2205         if frameon:

~/venv38/lib/python3.8/site-packages/matplotlib/backend_bases.py in print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, **kwargs)
   2117 
   2118             try:
-> 2119                 result = print_method(
   2120                     filename,
   2121                     dpi=dpi,

~/venv38/lib/python3.8/site-packages/matplotlib/cbook/deprecation.py in wrapper(*args, **kwargs)
    356                 f"%(removal)s.  If any parameter follows {name!r}, they "
    357                 f"should be pass as keyword, not positionally.")
--> 358         return func(*args, **kwargs)
    359 
    360     return wrapper

~/venv38/lib/python3.8/site-packages/matplotlib/backends/backend_agg.py in print_jpg(self, filename_or_obj, dryrun, pil_kwargs, *args, **kwargs)
    597             pil_kwargs.setdefault("quality", rcParams["savefig.jpeg_quality"])
    598             pil_kwargs.setdefault("dpi", (self.figure.dpi, self.figure.dpi))
--> 599             return background.save(
    600                 filename_or_obj, format='jpeg', **pil_kwargs)
    601 

~/venv38/lib/python3.8/site-packages/PIL/Image.py in save(self, fp, format, **params)
   2153                 fp = builtins.open(filename, "r+b")
   2154             else:
-> 2155                 fp = builtins.open(filename, "w+b")
   2156 
   2157         try:

FileNotFoundError: [Errno 2] No such file or directory: '../data/vision/ssd/output[[[0.1050852  0.07895297 0.08367175]\n  [0.31462591 0.31466424 0.32513717]\n  [0.28277484 0.25506944 0.23508735]\n  ...\n  [0.42182888 0.27386384 0.07784647]\n  [0.67421166 0.57844825 0.39889071]\n  [0.554919   0.33316082 0.09618731]]\n\n [[0.05228582 0.03646781 0.0400054 ]\n  [0.06949542 0.06235639 0.05692344]\n  [0.25959795 0.18080175 0.18654409]\n  ...\n  [0.60428691 0.30419598 0.06168084]\n  [0.62523846 0.37480789 0.15464491]\n  [0.40595506 0.21335363 0.0789785 ]]\n\n [[0.10904118 0.11286539 0.09207947]\n  [0.0804173  0.04945466 0.03713621]\n  [0.24569849 0.12457102 0.1002835 ]\n  ...\n  [0.8473525  0.49805938 0.01584464]\n  [0.62128949 0.34659926 0.04259144]\n  [0.60784509 0.39757653 0.1146472 ]]\n\n ...\n\n [[0.54990582 0.37598903 0.20369267]\n  [0.5526588  0.38010985 0.19625383]\n  [0.56226779 0.38371096 0.20185737]\n  ...\n  [0.29863339 0.2165191  0.14226269]\n  [0.30894688 0.23059896 0.16393229]\n  [0.31879315 0.21973148 0.16671452]]\n\n [[0.54124921 0.37518263 0.19985079]\n  [0.54947818 0.38385507 0.19607851]\n  [0.54889008 0.37478852 0.18892228]\n  ...\n  [0.29478525 0.22002212 0.15326277]\n  [0.31478406 0.23243116 0.16062237]\n  [0.30818757 0.21890863 0.14786195]]\n\n [[0.53892612 0.37097071 0.1888549 ]\n  [0.54983966 0.38421659 0.19571689]\n  [0.55770917 0.38090676 0.18950984]\n  ...\n  [0.316164   0.24439232 0.16849774]\n  [0.32127783 0.23892493 0.16441515]\n  [0.30470566 0.21542674 0.14437993]]].jpg'

标签: python-3.xmatplotlibimage-processingcomputer-visionpytorch

解决方案


plt.savefig()可以这样保存。此外,可以import uuid使用生成随机唯一字符串。

def display_detections(self):
    
    output_dir = "../data/vision/ssd/output"
    
    for image_idx in range(len(self.best_results_per_input)):
        fig, ax = plt.subplots(figsize=(20, 10))
        # Show original, denormalized image...
        image = self.inputs[image_idx] / 2 + 0.5
        ax.imshow(image)
        # ...with detections
        bboxes, classes, confidences = self.best_results_per_input[image_idx]
        for idx in range(len(bboxes)):
            left, bot, right, top = bboxes[idx]
            x, y, w, h = [val * 300 for val in [left, bot, right - left, top - bot]]
            rect = patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            ax.text(x, y, "{} {:.0f}%".format(self.classes_to_labels[classes[idx] - 1], confidences[idx] * 100), bbox=dict(facecolor='white', alpha=0.5))
        plt.axis('off')
        plt.savefig(output_dir + "/" + "Image" + "_" + str(uuid.uuid4()))
    plt.show()

推荐阅读