python - 从 pytorch 访问文件夹中的图像并将其提供给 CNN 预测的推理方法
问题描述
import torch, torchvision
from torch import nn
from torch import optim
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
import requests
from PIL import Image
from io import BytesIO
import copy
from sklearn.metrics import confusion_matrix
import pandas as pd
import numpy as np
import torch.utils.data as data_utils
# indices = torch.arange(1000)
# tr_1k = data_utils.Subset(tr, indices)
from google.colab import drive
drive.mount('/content/gdrive')
numb_batch = 64
T = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
print('........get MNIST Data............')
train_data = torchvision.datasets.MNIST('mnist_data', train=True, download=True, transform=T)
val_data = torchvision.datasets.MNIST('mnist_data', train=False, download=True, transform=T)
print('slice')
indices_train = torch.arange(1000)
indices_test = torch.arange(1000)
tr_1k = data_utils.Subset(train_data, indices_train)
test_1k = data_utils.Subset(val_data, indices_test)
print('........Assign MNIST data as training and validation dataset...........')
train_dl = torch.utils.data.DataLoader(tr_1k, batch_size = numb_batch)
val_dl = torch.utils.data.DataLoader(test_1k, batch_size = numb_batch)
def create_lenet():
print('.......create the convolutional neural network..........')
model = nn.Sequential(
nn.Conv2d(1, 6, 5, padding=2),
nn.ReLU(),
nn.AvgPool2d(2, stride=2),
nn.Conv2d(6, 16, 5, padding=0),
nn.ReLU(),
nn.AvgPool2d(2, stride=2),
nn.Flatten(),
nn.Linear(400, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 10)
)
return model
def validate(model, data):
print('.......validate...........')
total = 0
correct = 0
for i, (images, labels) in enumerate(data):
# images = images.cuda()
x = model(images)
value, pred = torch.max(x,1)
pred = pred.data.cpu()
total += x.size(0)
correct += torch.sum(pred == labels)
return correct*100./total
def train(numb_epoch=3, lr=1e-3, device="cpu"):
print('.......train the model................')
accuracies = []
cnn = create_lenet().to(device)
cec = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=lr)
max_accuracy = 0
for epoch in range(numb_epoch):
for i, (images, labels) in enumerate(train_dl):
images = images.to(device)
labels = labels.to(device)
optimizer.zero_grad()
pred = cnn(images)
loss = cec(pred, labels)
loss.backward()
optimizer.step()
accuracy = float(validate(cnn, val_dl))
accuracies.append(accuracy)
if accuracy > max_accuracy:
best_model = copy.deepcopy(cnn)
max_accuracy = accuracy
print("Saving Best Model with Accuracy: ", accuracy)
print('Epoch:', epoch+1, "Accuracy :", accuracy, '%')
plt.plot(accuracies)
return best_model
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
print("No Cuda Available")
device
print('......start training............')
lenet = train(40, device="cpu")
print('..........save the model...........')
torch.save(lenet.state_dict(), "new.pth")
# lenet = create_lenet().to(device)
# lenet.load_state_dict(torch.load("lenet.pth"))
# lenet.eval()
def predict_dl(model, data):
print('use the model to predict')
y_pred = []
y_true = []
for i, (images, labels) in enumerate(data):
# images = images.cuda()
x = model(images)
value, pred = torch.max(x, 1)
pred = pred.data.cpu()
y_pred.extend(list(pred.numpy()))
y_true.extend(list(labels.numpy()))
return np.array(y_pred), np.array(y_true)
y_pred, y_true = predict_dl(lenet, val_dl)
pd.DataFrame(confusion_matrix(y_true, y_pred, labels=np.arange(0,10)))
def inference(path, model, device):
r = requests.get(path)
with BytesIO(r.content) as f:
img = Image.open(f).convert(mode="L")
img = img.resize((28, 28))
x = (255 - np.expand_dims(np.array(img), -1))/255.
with torch.no_grad():
pred = model(torch.unsqueeze(T(x), axis=0).float().to(device))
return F.softmax(pred, dim=-1).cpu().numpy()
def inference2(images, model, device):
with torch.no_grad():
pred = model(torch.unsqueeze(T(images), axis=0).float().to(device))
return F.softmax(pred, dim=-1).cpu().numpy()
path = "https://previews.123rf.com/images/aroas/aroas1704/aroas170400068/79321959-handwritten-sketch-black-number-8-on-white-background.jpg"
# path2 = 'https://drive.google.com/file/d/1-t6EEaGFwsq4dsd8UDVVdv4yxCkr7iFW/view?usp=sharing'
path3 = '/content/100images/fake_images-10.png'
r = requests.get(path)
with BytesIO(r.content) as f:
img = Image.open(f).convert(mode="L")
img = img.resize((28, 28))
x = (255 - np.expand_dims(np.array(img), -1))/255.
plt.imshow(x.squeeze(-1), cmap="gray")
pred = inference(path, lenet, device=device)
pred_idx = np.argmax(pred)
print(f"Predicted: {pred_idx}, Prob: {pred[0][pred_idx]*100} %")
print('get final fake 100 images')
indices_fake = torch.arange(100)
tr_fake = data_utils.Subset(train_data, indices_fake)
train_dl = torch.utils.data.DataLoader(tr_fake, batch_size = 1)
# '/content/gdrive/MyDrive/2/'
tsr_img1 = torchvision.io.read_image('/content/gdrive/MyDrive/100images/fake_image-10.png')
# for loop to get 100 fake images and cllassify from this model
image_datasets = torchvision.datasets.folder('/content/100images/', transform=T)
dataloadersdataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=1, shuffle=True)
for i, (images,_) in enumerate(tr_fake):
print(i)
images = images.view(-1, 784).to(device)
# with BytesIO(images.content) as f:
# img = Image.open(f).convert(mode="L")
# img = img.resize((28, 28))
# x = (255 - np.expand_dims(np.array(images), -1))/255.
print(images)
data = images.reshape(-1, 1, 28, 28)
print(data)
img = data.resize((28, 28))
x = (255 - np.expand_dims(np.array(img), -1))/255.
plt.imshow(x.squeeze(-1) , cmap="gray")
pred = inference2(images, lenet, device=device)
pred_idx = np.argmax(pred)
print(f"Predicted: {pred_idx}, Prob: {pred[0][pred_idx]*100} %")
我有以下代码。在这里,我有一个训练有素的 CNN,它可以预测和分类手写数字。这是使用 mnist 数据集训练的。我想做的是使用其中的推理方法对我自己的数据集进行预测。
我有一个从 mnist 训练集中获取的包含 100 位数字的数据集。我将它们放在 google collab 根位置以及我的 google drive 中。我想获取那些图像(在文件夹中保存为 png s)并通过推理方法运行它们并预测它是什么数字。我无法这样做。我尝试了多种方法未能成功地从文件夹。我不断收到这样的文件夹不存在错误。其次,一旦获取,我需要遍历它们并获得对它们中的每一个的预测。此处的任何帮助将不胜感激。首先获取图像,然后对其进行预测。此处已对从网络上获取的单个图像进行了预测。我需要对其进行修改以预测我们提供的图像。
解决方案
推荐阅读
- python - Python类结构,获取和设置属性和继承
- jgit - JGit 抛出 MissingObjectException 并说缺少未知的提交 ID
- bash - 超时命令未超时
- python - 为什么我在 bolt-python 中同时需要 SlackInstallation 和 SlackBot 模型?
- linux - 从java中的输出中删除ANSII颜色
- coq - (已解决 - 是一个目标打印错误) Coq - 列表上的简单归纳不接受假设
- python - 从.txt中提取多行到字典
- installation - 如果应用程序已打开并且用户想要卸载它,如何添加警告消息
- firebase - 使用 Firebase 进行 Flutter RealTime Online 和 Offline
- java - 排除来自依赖项的 jar