Classifying images with a pretrained PyTorch vgg16 model and its classes

Problem description

I wrote an image classifier using PyTorch's pretrained vgg16 model.

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
import urllib.request  # "import urllib" alone does not guarantee urllib.request is loaded
from skimage.transform import resize
from skimage import io
import yaml

# Downloading imagenet 1000 classes list
file = urllib.request.urlopen("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
classes = ''
for f in file:
    classes = classes + f.decode("utf-8")
classes = yaml.safe_load(classes)  # PyYAML >= 6 requires a Loader for yaml.load; safe_load avoids that

# Downloading pretrained vgg16 model
model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16', pretrained=True)

print(model)

for param in model.parameters():
    param.requires_grad = False


url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/dog.jpg", "dog.jpg")

image=io.imread(url)

plt.imshow(image)
plt.show()

# resize to 224x224x3
img = resize(image,(224,224,3))

plt.imshow(img)
plt.show()
# Normalizing input for vgg16
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img1 = mean*img+std
img1 = np.clip(img1,0,1)

img1 = torch.from_numpy(img1).unsqueeze(0)
img1 = img1.permute(0,3,2,1) # batch_size x channels x height x width

model.eval()
pred = model(img1.float())
print(classes[torch.argmax(pred).numpy().tolist()])

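The code runs without errors, but it outputs the wrong class. I'm not sure where I went wrong, but if I had to guess, it's either the ImageNet YAML class list or the normalization of the input image. Can anyone tell me where I made a mistake?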

Tags: pytorch, classification, torch, vgg-net, torchvision

Solution


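There are a few problems with the image preprocessing. First, normalization is computed as (value - mean) / std, not value * mean + std. Second, the values should not be clipped to [0, 1]; normalization deliberately moves them out of [0, 1]. Third, the image as a NumPy array has shape [height, width, 3], and when you permute the dimensions with permute(0, 3, 2, 1) you swap the height and width dimensions, creating a tensor of shape [batch_size, channels, width, height].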

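img = resize(image, (224, 224, 3))

# Normalizing input for vgg16: (value - mean) / std, per channel
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img1 = (img - mean) / std  # no clipping afterwards

img1 = torch.from_numpy(img1).unsqueeze(0)
img1 = img1.permute(0, 3, 1, 2)  # batch_size x channels x height x width

pred = model(img1.float())  # the NumPy math yields float64; the model expects float32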
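A minimal, self-contained sketch of the normalization and permutation points above, using a small synthetic array (an assumption standing in for the resized dog image, so nothing needs downloading). It shows that the correct normalization pushes values outside [0, 1], and that permute(0, 3, 2, 1) produces the spatial transpose of what permute(0, 3, 1, 2) produces.

```python
import numpy as np
import torch

# Hypothetical stand-in for the resized image: H=4, W=6, 3 channels, values in [0, 1].
img = np.linspace(0.0, 1.0, 4 * 6 * 3).reshape(4, 6, 3)

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Correct normalization: (value - mean) / std, broadcast over the channel axis.
normalized = (img - mean) / std
print(normalized.min() < 0, normalized.max() > 1)  # True True: values leave [0, 1]

batch = torch.from_numpy(normalized).unsqueeze(0)  # (1, 4, 6, 3) = (N, H, W, C)
correct = batch.permute(0, 3, 1, 2)                # (1, 3, 4, 6) = (N, C, H, W)
swapped = batch.permute(0, 3, 2, 1)                # (1, 3, 6, 4) = (N, C, W, H)
print(tuple(correct.shape), tuple(swapped.shape))

# The wrongly permuted tensor is the correct one with height and width exchanged.
print(torch.equal(swapped, correct.transpose(2, 3)))  # True
```

With a square 224x224 input both orders give the same shape, which is why the original code runs without an error: the model simply receives the image mirrored along its diagonal.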

You can also do the same preprocessing with torchvision.transforms:

from torchvision import transforms

preprocess = transforms.Compose([
    transforms.ToTensor(),  # HxWxC ndarray -> CxHxW tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

img = resize(image, (224, 224, 3))
img1 = preprocess(img)
img1 = img1.unsqueeze(0).float()  # add the batch dimension; cast float64 -> float32 for the model

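If you load the image with PIL, you can also resize it by adding transforms.Resize((224, 224)) to the preprocessing pipeline, or you can even add transforms.ToPILImage() to convert the image to a PIL image first (in older torchvision releases, transforms.Resize only accepts PIL images).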

