首页 > 解决方案 > IndexError:维度为 2 的张量的索引过多

问题描述

这是数据集:

class price_dataset(Dataset):
    def __init__(self, transform=None):
        xy = pd.read_csv('data_balanced_full.csv')

        self.n_samples = xy.shape[0]

        xy = xy.to_numpy()
        self.x_data = torch.from_numpy(xy[:, 7:].astype(np.float32))
        self.y_data = torch.from_numpy(xy[:, 6].astype(np.float32))
        self.transform = transform

    def __getitem__(self, index):
        x = self.x_data[index]
        y = self.y_data[index]

        sample = {'data': x, 'label': y}
        if self.transform:
            sample = self.transform(sample)

        return sample

    # we can call len(dataset) to return the size
    def __len__(self):
        return self.n_samples

我正在尝试将数据集拆分为测试和训练:

dataset_normalized = price_dataset(transform=transforms.ToTensor())
train_dataset, test_dataset = train_test_split(dataset_normalized['data'], dataset_normalized['label'], test_size=0.10, random_state=0)

但我收到此错误:

IndexError: too many indices for tensor of dimension 2

标签: pythonpytorch

解决方案


'data'并且'label'不是索引而是字典的键。__getitem__该字典可访问并按如下方式调用: dataset_normalized[idx]idx 为整数。

此外,您不能直接在字典上调用您的转换。你应该调用它sample['data']

我建议你仔细阅读 PyTorch 文档的这个例子,它非常好。


推荐阅读