首页 > 解决方案 > 在 python 上使用 ms_ssim 比较两个图像

问题描述

我想比较两个图像,但 ms_ssim 想要 4D 张量

https://pypi.org/project/pytorch-msssim/

我试过了

from PIL import Image
from tqdm import tqdm
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
import torchvision
import numpy as np

topil=torchvision.transforms.ToPILImage()
totensor=torchvision.transforms.ToTensor()

def ssimcompare(path1:str,path2:str)->float:
    image1 = Image.open(path1)
    image2 = Image.open(path2) 
    #it1=np.expand_dims(totensor(topil(np.array(image1))), axis=0)
    #it2=np.expand_dims(totensor(topil(np.array(image2))), axis=0)
    #it1=totensor(np.expand_dims(np.array(image1), axis=0))
    #it2=totensor(np.expand_dims(np.array(image2), axis=0))
    it1=totensor(np.array(image1))
    it2=totensor(np.array(image2))
    valor=ms_ssim( it1 , it2, data_range=255, size_average=False )
    return valor

但我得到不同的错误

ValueError: Input images must be 4-d tensors.
TypeError: pic should be Tensor or ndarray. Got <class 'PIL.JpegImagePlugin.JpegImageFile'>.
AttributeError: 'numpy.ndarray' object has no attribute 'type'

标签: pythoncomparepytorchtensorssim

解决方案


问题是所有这些函数(和类)都需要批量图像作为输入。但是,由于图像是 3D,所以批次是 4D。

当您只有一个图像张量时,您可以将其“解压”成一个单项批次

it = it.unsqueeze(0)

但是,我不推荐这个pytorch_msssim包。您应该考虑使用piqapiqIQA 包,因为它们有据可查并且实现速度更快。

例如,

pip install piqa

然后,您的功能变为

from piqa import ssim

def msssim_compare(path1: str, path2: str) -> float:
    image1 = Image.open(path1)
    image2 = Image.open(path2)
    it1 = totensor(image1).unsqueeze(0)
    it2 = totensor(image2).unsqueeze(0)

    return ssim.msssim(it1, it2).squeeze(0)

推荐阅读