python - ValueError:采样器选项与 shuffle pytorch 互斥
问题描述
我正在使用 pytorch 和 mtcnn 进行人脸识别项目,在训练了我的训练数据集之后,现在我想对测试数据集进行预测
这是我训练有素的代码
optimizer = optim.Adam(resnet.parameters(), lr=0.001)
scheduler = MultiStepLR(optimizer, [5, 10])
trans = transforms.Compose([
np.float32,
transforms.ToTensor(),
fixed_image_standardization
])
dataset = datasets.ImageFolder(data_dir, transform=trans)
img_inds = np.arange(len(dataset))
np.random.shuffle(img_inds)
train_inds = img_inds[:int(0.8 * len(img_inds))]
val_inds = img_inds[int(0.8 * len(img_inds)):]
train_loader = DataLoader(
dataset,
num_workers=workers,
batch_size=batch_size,
sampler=SubsetRandomSampler(train_inds)
)
val_loader = DataLoader(
dataset,
shuffle=True,
num_workers=workers,
batch_size=batch_size,
sampler=SubsetRandomSampler(val_inds)
)
如果删除sampler=SubsetRandomSampler(val_inds)
并val_inds
改为放置,则会出现此错误
val_inds ^ SyntaxError:位置参数跟随关键字参数
我想在 pytorch 中进行预测(从测试数据集中随机选择)?这就是为什么我应该使用shuffle=True
我遵循这个 repo facenet-pytorch
解决方案
TLDR;shuffle=True
在这种情况下删除,因为SubsetRandomSampler
已经对数据进行了洗牌。
什么torch.utils.data.SubsetRandomSampler
(如有疑问,请查阅文档)将获取索引列表并返回它们的排列。
在您的情况下,您indices
对应于training
(这些是训练 Dataset 中元素的索引)和validation
.
让我们假设那些看起来像这样:
train_indices = [0, 2, 3, 4, 5, 6, 9, 10, 12, 13, 15]
val_indices = [1, 7, 8, 11, 14]
在每次传递期间SubsetRandomSampler
,将从这些列表中随机返回一个数字,并且在所有这些列表都返回后将再次随机化(__iter__
将再次调用)。
所以SubsetRandomSampler
可能会为val_indices
(类似地train_indices
)返回类似的东西:
val_indices = [1, 8, 11, 7, 14] # Epoch 1
val_indices = [11, 7, 8, 14, 1] # Epoch 2
val_indices = [7, 1, 14, 8, 11] # Epoch 3
现在,这些数字中的每一个都是您原始dataset
. 请注意validation
以这种方式洗牌,因此train
不使用shuffle=True
. 这些索引不重叠,因此数据被正确拆分。
附加信息
shuffle
torch.utils.data.RandomSampler
如果shuffle=True
指定,请在后台使用,请参阅源代码。这又等同于使用torch.utils.data.SubsetRandomSampler
所有指定的索引(np.arange(len(datatest))
)。- 你不必预先洗牌
np.random.shuffle(img_inds)
,因为无论如何索引都会在每次通过时被洗牌 - 不要使用
numpy
iftorch
提供相同的功能。有torch.arange
,几乎不需要混合这两个库。
推理
单张图片
只需通过您的网络将其传递给获取输出,例如:
module.eval()
with torch.no_grad():
output = module(dataset[5380])
第一行将模型置于评估模式(更改某些层的行为),上下文管理器关闭梯度(因为预测不需要它)。这些几乎总是在“检查神经网络输出”时使用。
检查验证数据集
沿着这些思路,请注意适用于单个图像的相同想法:
module.eval()
total_batches = 0
batch_accuracy = 0
for images, labels in val_loader:
total_batches += 1
with torch.no_grad():
output = module(images)
# In case it outputs logits without activation
# If it outputs activation you may have to use argmax or > 0.5 for binary case
# Item gets float from torch.tensor
batch_accuracy += torch.mean(labels == (output > 0.0)).item()
print("Overall accuracy: {}".format(batch_accuracy / total_batches))
其他案例
请查看一些初学者指南或教程并理解这些概念,因为 StackOverflow 不是重新做这项工作的地方(而是具体和小问题),谢谢。
推荐阅读
- regex - 如何在 Elasticsearch 中高效搜索动态定义的正则表达式?
- c# - ASP.NET 增加特定 API 调用的上传最大大小
- ruby-on-rails - 使用 Ruby eval 方法真的很危险吗?如果是,还有什么替代方法?(导轨)
- php - 单击“添加到播放列表按钮”后,从 PHP 数据库中获取特定行以添加到新的空数据库中
- session - Laravel 会话不会改变
- vue.js - 使用 vue-cli 找不到模块“./src/data”
- go - 使用 SDK 上传后从 S3 中删除对象
- python - OpenCV - 获取所有 Blob 像素
- mybb - 如何四舍五入 MyBB 查看计数格式
- angular - JSONP 方法不能在 localhost 上使用 Angular 8