首页 > 解决方案 > 数据集中每类的图像数量,PyTorch

问题描述

我正在研究一个图像数据集,其中图像被分为 10 个类别(CIFAR10 数据集)。我正在使用 PyTorch。请,我想知道如何通过循环遍历数据集来确定每个类的图像数量。提前感谢您的回复。

标签: pytorch

解决方案


你可以做两次。

  • 首先创建一个字典 img_dict,其中包含 CIFAR10 数据集的所有类。将所有值初始化为 0。
  • 下一个循环遍历数据集并继续根据 img_dict 中的类键递增值
dataset = CIFAR10(root='data/', download=True, transform=ToTensor())
dataset_size = len(dataset)
classes = dataset.classes
num_classes = len(dataset.classes)
img_dict = {}
for i in range(num_classes):
    img_dict[classes[i]] = 0

for i in range(dataset_size):
    img, label = dataset[i]
    img_dict[classes[label]] += 1

img_dict

您将获得如下输出:

每类图像数量:
1


推荐阅读