pytorch - 如何将图像数据集转换为张量?
问题描述
我有一个看起来像这样的图像数据集:
array([[[[0.35980392, 0.26078431, 0.14313725],
[0.38137255, 0.26470588, 0.15196078],
[0.51960784, 0.3745098 , 0.26176471],
...,
[0.34313725, 0.22352941, 0.15 ],
[0.30784314, 0.2254902 , 0.15686275],
[0.28823529, 0.22843137, 0.16862745]],
[[0.38627451, 0.28235294, 0.16764706],
[0.45098039, 0.32843137, 0.21666667],
[0.62254902, 0.47254902, 0.36470588],
...,
[0.34607843, 0.22745098, 0.15490196],
[0.30686275, 0.2245098 , 0.15588235],
[0.27843137, 0.21960784, 0.16176471]],
[[0.41568627, 0.30098039, 0.18431373],
[0.51862745, 0.38529412, 0.27352941],
[0.67745098, 0.52058824, 0.40980392],
...,
[0.34901961, 0.22941176, 0.15588235],
[0.29901961, 0.21666667, 0.14901961],
[0.26078431, 0.20098039, 0.14313725]],
...,
我需要将它转换为张量,以便我可以将它传递给 CNN。我正在尝试这样做:
from torchvision import transforms as transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
如何将其transform
应用于我的数据集?谢谢你的帮助。
解决方案
您可能想要创建一个数据加载器。您将需要一个遍历数据集的类,您可以这样做:
import torch
import torchvision.transforms
class YourDataset(torch.utils.data.Dataset):
def __init__(self):
# load your dataset (how every you want, this example has the dataset stored in a json file
with open(<dataset-path>, "r") as f:
self.dataset = json.load(f)
def __getitem__(self, idx):
sample = self.dataset[idx]
data, label = sample[0], sample[1]
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
return transform(data), torch.tensor(label)
def __len__(self):
return len(self.dataset)
现在您可以创建一个数据加载器:
train_set = YourDataset()
train_dataloader = torch.utils.data.DataLoader(
train_set,
batch_size=64,
num_workers=1,
shuffle=True,
)
现在您可以在训练循环中迭代数据加载器:
for samples, labels in self.train_set:
. . .
# samples will hold N samples of your dataset where N is the batchsize
如果您需要更多解释,请查看有关此主题的 pytorchs 文档。
推荐阅读
- r - 在另一个函数中创建 R 调查设计对象
- python - 在 Elasticsearch 中搜索句点和连字符分隔的字段
- plsql - pl/sql 创建行级触发器
- visual-studio-code - VS Code 光标跳到底部
- c# - 如何使用 csvhelper 读取文件夹中的多个 csv 文件
- ithit-webdav-server - 写入通过映射驱动器中的 Windows 资源管理器打开文件时所做的文件更改
- batch-file - 归档目录超过 N 天的批处理脚本
- javascript - 反应原生相机在导航时变为空白
- jquery - Jquery:将最后一行集中在textarea上的更改
- mysql - subTables 连接查询,任何数据库中间件建议