pytorch - PyTorch 自定义 DataLoader 处理多个 CSV
问题描述
我正在尝试定义一个定制的 PyTorch DataLoader,它能够有效地从不同的巨大CSV 中读取,而无需将它们加载到内存中。问题定义如下。为简单起见,假设我有两个 CSV
1.csv:
1, 2, 3
4, 5, 6
7, 8, 9
2.csv:
10,11,12
13,14,15
16,17,18
为简单起见,我们还假设批量大小为 1。生成器应该产生两个张量:
Tensor_1: [1, 2, 3, 4, 5, 6, 7, 8, 9]
Tensor_2: [10, 11, 12, 13, 14, 15, 16, 17, 18]
这是因为对于每个有效索引,我应用的历史窗口等于 2,然后我将样本展平。
按照什么是从多个 csv 文件加载数据的最快方法中的答案,我编写了以下代码:
import numpy as np
import pandas as pd
import glob
from functools import lru_cache
from pathlib import Path
from pprint import pprint
from torch.utils.data import Dataset, DataLoader
import torch
@lru_cache()
def get_sample_count_by_file(path: Path) -> int:
c = 0
with path.open() as f:
for line in f:
c += 1
return c
class CSVDataset:
def __init__(self, csv_directory: str, extension: str = ".csv"):
self.directory = Path(csv_directory)
self.files = sorted((f, get_sample_count_by_file(f)) for f in self.directory.iterdir() if f.suffix == extension)
self._sample_count = sum(f[-1] for f in self.files)
def __len__(self):
return self._sample_count
def __getitem__(self, idx):
current_count = 0
history_window = 2
my_idx=idx+2
for file_, sample_count in self.files:
if current_count <= my_idx < current_count + sample_count:
break
current_count += sample_count
file_idx = my_idx - current_count # the index we want to access in file_
if file_idx < 2:
file_idx += 2
with file_.open() as f:
data = []
for i, line in enumerate(f):
if i >= file_idx-history_window and i <= file_idx:
for v in line.split(","):
data.append(float(v))
data = np.array(data)
return torch.from_numpy(data)
dataset = CSVDataset("<PATH CONTAINING CSVs>")
loader = DataLoader(dataset, batch_size=1)
pprint(list(enumerate(loader)))
它非常适用于第一个文件,但是当它切换到第二个 CSV 时会出现问题(由于索引管理错误,存在一些重复)。我该如何解决这个问题?
解决方案
如何CSVDataset
仅使用一个 csv,然后使用torch.utils.data.ConcatDataset
将所有单独的 csv 数据集连接成一个。只要每个索引中的索引CSVDataset
是连贯的,Pytorch 就会为您处理索引。
推荐阅读
- python - Python websockets lib客户端持久连接(带类实现)
- javascript - React - 如何将图像复制到剪贴板?
- ios - 如何在 SwiftUI 中获取 MKMapView 方向
- javascript - 如何让本地通知插件在我的科尔多瓦应用程序上工作?
- sql-server - SQL Server 将 2 个 Unicode 字符解释为相同
- c# - 统一播放多个音频片段
- microsoft-graph-toolkit - 同一页面中带有 Sharepoint Provider 的多个 Web 部件错误
- python - 在 Python 中替代 Sum 函数以获得更好的 LP 时间性能
- c# - 将本地数据库中的日期加载到列表中,错误
- javascript - v-for 不在 DIV 标签上循环,但适用于 TR 标签