首页 > 解决方案 > PyTorch 自定义 DataLoader 处理多个 CSV

问题描述

我正在尝试定义一个定制的 PyTorch DataLoader,它能够有效地从不同的巨大CSV 中读取,而无需将它们加载到内存中。问题定义如下。为简单起见,假设我有两个 CSV

1.csv:

1, 2, 3
4, 5, 6
7, 8, 9
2.csv:

10,11,12
13,14,15
16,17,18

为简单起见,我们还假设批量大小为 1。生成器应该产生两个张量:

Tensor_1: [1, 2, 3, 4, 5, 6, 7, 8, 9]
Tensor_2: [10, 11, 12, 13, 14, 15, 16, 17, 18]

这是因为对于每个有效索引,我应用的历史窗口等于 2,然后我将样本展平。

按照什么是从多个 csv 文件加载数据的最快方法中的答案,我编写了以下代码:

import numpy as np
import pandas as pd
import glob
from functools import lru_cache
from pathlib import Path
from pprint import pprint
from torch.utils.data import Dataset, DataLoader
import torch

@lru_cache()
def get_sample_count_by_file(path: Path) -> int:
    c = 0
    with path.open() as f:
        for line in f:
            c += 1
    return c

class CSVDataset:
    def __init__(self, csv_directory: str, extension: str = ".csv"):
        self.directory = Path(csv_directory)
        self.files = sorted((f, get_sample_count_by_file(f)) for f in self.directory.iterdir() if f.suffix == extension)
        self._sample_count = sum(f[-1] for f in self.files)

    def __len__(self):
        return self._sample_count

    def __getitem__(self, idx):
        current_count = 0

        history_window = 2
        my_idx=idx+2

        for file_, sample_count in self.files:
            if current_count <= my_idx < current_count + sample_count:
                break  
            current_count += sample_count

        file_idx = my_idx - current_count # the index we want to access in file_
        if file_idx < 2:
            file_idx += 2

        with file_.open() as f:
            data = []
            for i, line in enumerate(f):
                if i >= file_idx-history_window and i <= file_idx:
                    for v in line.split(","):
                        data.append(float(v))

            data = np.array(data)
            return torch.from_numpy(data)


dataset = CSVDataset("<PATH CONTAINING CSVs>")
loader = DataLoader(dataset, batch_size=1)

pprint(list(enumerate(loader)))

它非常适用于第一个文件,但是当它切换到第二个 CSV 时会出现问题(由于索引管理错误,存在一些重复)。我该如何解决这个问题?

标签: pytorchiteratorpytorch-dataloader

解决方案


如何CSVDataset仅使用一个 csv,然后使用torch.utils.data.ConcatDataset将所有单独的 csv 数据集连接成一个。只要每个索引中的索引CSVDataset是连贯的,Pytorch 就会为您处理索引。


推荐阅读