python - IndexError:维度为 2 的张量的索引过多
问题描述
这是数据集:
class price_dataset(Dataset):
def __init__(self, transform=None):
xy = pd.read_csv('data_balanced_full.csv')
self.n_samples = xy.shape[0]
xy = xy.to_numpy()
self.x_data = torch.from_numpy(xy[:, 7:].astype(np.float32))
self.y_data = torch.from_numpy(xy[:, 6].astype(np.float32))
self.transform = transform
def __getitem__(self, index):
x = self.x_data[index]
y = self.y_data[index]
sample = {'data': x, 'label': y}
if self.transform:
sample = self.transform(sample)
return sample
# we can call len(dataset) to return the size
def __len__(self):
return self.n_samples
我正在尝试将数据集拆分为测试和训练:
dataset_normalized = price_dataset(transform=transforms.ToTensor())
train_dataset, test_dataset = train_test_split(dataset_normalized['data'], dataset_normalized['label'], test_size=0.10, random_state=0)
但我收到此错误:
IndexError: too many indices for tensor of dimension 2
解决方案
'data'
并且'label'
不是索引而是字典的键。__getitem__
该字典可访问并按如下方式调用: dataset_normalized[idx]
idx 为整数。
此外,您不能直接在字典上调用您的转换。你应该调用它sample['data']
。
我建议你仔细阅读 PyTorch 文档的这个例子,它非常好。
推荐阅读
- javascript - 从 JS / jQuery 设置 iframe src 给出 DOMException:cross-origin
- reactjs - 如何在滚动时关闭 React Bootstrap Table2 可扩展行
- sql - 编译器如何评估以下查询?
- azure - 从 Node.js 函数访问 Azure Synapse Analytics
- c++ - 如何正确跟踪 C++ 中的内存分配和释放,并重载运算符新闻和运算符删除?
- .net - 如何阻止用户使用 JWT 创建自定义 POST 请求?
- vba - VBA代码从一个工作表到另一个工作表中搜索列数据并将相应的行数据粘贴到第一个工作表上
- excel - Excel VBA 正则表达式模式给出错误'5018'
- android - cipher.doFinal 上的 AEADBadTagException
- go - 为数组值创建索引