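image - How do I load and split an image dataset for training?
Problem description
I have a simple classification problem with 2 classes of grapefruit images ("healthy" and "sick"). I have put the images in folders: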
root
    healthy
    sick
I am now trying to load the images and split them into a training set and a validation set. I thought I should use ImageFolder, for example:
imgs = ImageFolder(path, transform=transform)
'path' contains the path to the root folder. The transform is defined beforehand. imgs yields the images as tensors.
from torchvision import transforms
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  --> which one should I choose?
])
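On the choice between the two Normalize calls: Normalize applies (x - mean) / std to each channel, so the two lines differ only in the statistics they use. The sketch below reproduces that arithmetic in plain Python (the helper `normalize_value` is hypothetical, not part of torchvision) to show that mean 0.5 and std 0.5 map the [0, 1] range produced by ToTensor onto [-1, 1]. The ImageNet statistics in the commented-out line are the ones conventionally paired with models pretrained on ImageNet; for a network trained from scratch, either choice is a reasonable starting point.

```python
def normalize_value(value, mean, std):
    """Same per-channel arithmetic as transforms.Normalize: (x - mean) / std."""
    return (value - mean) / std

# ToTensor scales pixel values to [0, 1]; check the two ends of that range.
low = normalize_value(0.0, mean=0.5, std=0.5)
high = normalize_value(1.0, mean=0.5, std=0.5)
print(low, high)  # -1.0 1.0

# With the ImageNet red-channel statistics the range is no longer symmetric.
imagenet_low = normalize_value(0.0, mean=0.485, std=0.229)
imagenet_high = normalize_value(1.0, mean=0.485, std=0.229)
print(imagenet_low, imagenet_high)
```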
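My dataset is very small, so if possible I need to enlarge it artificially with random flips and other transforms.
Now I want to split the set into 80% training and 20% validation. Should I do something like this?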
split = int(0.8 * len(imgs))
index_list = list(range(len(imgs)))
train_idx, valid_idx = index_list[:split], index_list[split:]
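One caveat with slicing the index list directly: ImageFolder lists its samples class by class, so the last 20% of the indices would come almost entirely from one class. A minimal sketch of a shuffled split, using only the standard library (the function name `split_indices` and the fixed seed are assumptions for illustration):

```python
import random

def split_indices(n_samples, train_fraction=0.8, seed=0):
    """Shuffle the indices 0..n_samples-1 and cut them into train/validation lists."""
    indices = list(range(n_samples))
    # A seeded generator keeps the split reproducible between runs.
    random.Random(seed).shuffle(indices)
    split = int(train_fraction * n_samples)
    return indices[:split], indices[split:]

# 100 stands in for len(imgs).
train_idx, valid_idx = split_indices(100)
print(len(train_idx), len(valid_idx))  # 80 20
```

torch.utils.data.random_split is an alternative that returns the two subsets directly instead of index lists.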
What is the next step? Something like this?
## create training and validation sampler objects
tr_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
val_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)
## create iterator objects for train and valid datasets
trainloader = torch.utils.data.DataLoader(cifar, batch_size=256, sampler=tr_sampler)
validloader = torch.utils.data.DataLoader(cifar, batch_size=256, sampler=val_sampler)
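A further point to consider when both samplers draw from one dataset object: they share its transform, so the random flips would also be applied to validation images. One possible way around this, sketched below, is a small wrapper that gives each split its own transform. `TransformedSubset` is a hypothetical helper, and a TensorDataset of random tensors stands in for the image folder so the sketch runs without any files. With real data, the base dataset would be created without a transform and the torchvision pipelines passed to the wrapper instead.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset

class TransformedSubset(Dataset):
    """Hypothetical helper: a view of `base` limited to `indices`, with its own transform."""

    def __init__(self, base, indices, transform=None):
        self.base = base
        self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        image, label = self.base[self.indices[i]]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Stand-in data: 10 fake 3x8x8 "images", five per class.
base = TensorDataset(torch.rand(10, 3, 8, 8), torch.tensor([0] * 5 + [1] * 5))

# A fixed horizontal flip stands in for the random augmentation pipeline.
flip = lambda image: torch.flip(image, dims=[-1])

train_ds = TransformedSubset(base, [0, 1, 2, 3, 5, 6, 7, 8], transform=flip)
valid_ds = TransformedSubset(base, [4, 9])  # no augmentation for validation

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
valid_loader = DataLoader(valid_ds, batch_size=4, shuffle=False)

valid_images, valid_labels = next(iter(valid_loader))
print(valid_images.shape, valid_labels.tolist())
```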
Any help is appreciated. Thanks!
Solution
After trying, it seems that the loading phase is OK. But how does the solver know which images are 'healthy' and which are 'sick'?
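On the label question: ImageFolder takes the label from the subfolder an image sits in. It sorts the subfolder names, numbers them from 0, and exposes the result as the dataset's `class_to_idx` attribute. The sketch below mimics that rule with the standard library only, on a temporary directory laid out like the one in the question; `folder_class_to_idx` is a hypothetical helper, not torchvision code.

```python
import os
import tempfile

def folder_class_to_idx(root):
    """Sorted subfolder names numbered from 0, the rule ImageFolder follows."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: index for index, name in enumerate(classes)}

with tempfile.TemporaryDirectory() as root:
    # Created out of order on purpose: the sort decides the numbering.
    for name in ("sick", "healthy"):
        os.mkdir(os.path.join(root, name))
    mapping = folder_class_to_idx(root)

print(mapping)  # {'healthy': 0, 'sick': 1}
```

On the real dataset, printing imgs.class_to_idx shows the mapping, and each item of the dataset is an (image, label) pair.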
Then, how can I write the training loop? I tried this, but it fails:
## Training
import torch.nn as nn
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(raisin.parameters(), lr=0.001, momentum=0.9)
print("ok criteres")
epochs = 5
for epoch in range(epochs):  # loop over the dataset multiple times
    print(epoch)
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = raisin(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        if i % 20 == 19:  # print every 20 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 20))
            running_loss = 0.0
The problem seems to come from the labels, which do not look right. I would expect them to be either 0 or 1, but I get values from 0 to 9. What is wrong?
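A likely cause, judging from the loader code earlier in the question: both DataLoader calls are built on a dataset named `cifar`, not on `imgs`. If `cifar` is a CIFAR-10 dataset, its ten classes would explain labels from 0 to 9, and passing `imgs` to the loaders should give 0 and 1. A quick way to check what a loader actually yields is sketched below; `unique_labels` is a hypothetical helper, and the two TensorDatasets are stand-ins with made-up data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def unique_labels(loader):
    """Collect every distinct label value that the loader yields."""
    seen = set()
    for _, labels in loader:
        seen.update(labels.tolist())
    return sorted(seen)

# Stand-in for the two-class folder dataset.
two_class = TensorDataset(torch.rand(8, 3, 4, 4), torch.tensor([0, 1] * 4))
# Stand-in for a ten-class dataset such as CIFAR-10.
ten_class = TensorDataset(torch.rand(20, 3, 4, 4), torch.arange(20) % 10)

two_class_labels = unique_labels(DataLoader(two_class, batch_size=4))
ten_class_labels = unique_labels(DataLoader(ten_class, batch_size=4))
print(two_class_labels)  # [0, 1]
print(ten_class_labels)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```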