首页 > 解决方案 > 如何在 PyTorch 中从 CSV 读取数值数据?

问题描述

我是 PyTorch 的新手;尝试实现我在 TF 中开发的模型并比较结果。该模型是一个自动编码器模型。输入数据是一个 csv 文件,包括 n 个样本,每个样本具有 m 个特征(csv 文件中的一个 *m 数字矩阵)。目标(标签)位于另一个与输入文件格式相同的 csv 文件中。我一直在网上寻找,但找不到一个很好的文档来从具有多个标签的 csv 文件中读取非图像数据。知道如何在训练期间读取我的数据并对其进行迭代吗?

谢谢

标签: pytorch

解决方案


您可能正在寻找类似TabularDataset的东西吗?

类 torchtext.data.TabularDataset(路径、格式、字段、skip_header=False、csv_reader_params={}、**kwargs)

定义以 CSV、TSV 或 JSON 格式存储的列数据集。

它将获取 CSV 文件的路径并从中构建数据集。您还需要指定将成为数据字段的列的名称。

通常,针对特定类型数据的 torch.Dataset 的所有实现都位于 pytorch 之外的 torchvision、torchtext 和 torchaudio 库中。


推荐阅读