首页 > 解决方案 > 如何在使用 pytorch_geometric Data 对象时使 DataLoader 加载正确的 batch_size 数据?

问题描述

我是pytorch_geometric的新手。

这是我创建包含 5 个样本的 DataLoader 的代码,每个样本由 1 * 10 矩阵表示。

import torch
import numpy as np
from torch_geometric.data import Data, DataLoader

x = torch.tensor(np.random((5, 10))
y = torch.tensor(np.random((5, 1))
data = Data(x = x, y = y)
loader = DataLoader([data], batch_size=2)

当我将 batch_size 设置为 2 时,我希望加载器将我的 5 个样本加载 2 2 1,但现在它会一次性推出所有数据,下面的代码是打印结果。

for batch in loader:
    print(batch)

# Batch(batch=[5], x=[5, 10], y=[5, 1])

如果您熟悉,请告诉我如何才能达到我想要的结果。

顺便说一句,我是从像这样的 pytorch_geometric 官方存储库数据集中学到的。

标签: deep-learningpytorch

解决方案


推荐阅读