首页 > 解决方案 > pyTorch:创建数据集

问题描述

我得到了一个大熊猫数据框,其中包含基于时间的测量数据(传感器值,时间信息不需要成为网络的输入)。这个数据框被放入一个张量中,然后用 torch.Dataloader 加载。

数据加载器非常慢,GPU 利用率约为 3%。

正如您在下面的代码中看到的,数据集被放入张量,然后放入torch.utils.data.DataLoader。然后将此加载器与 enumerate(Dataloader) 一起使用。

在我的研究过程中,我发现了 torch.utils.data.TensorDataset,但是当我尝试将张量放入这种类型时,我得到了错误:TypeError: Variable data has to be a tensor, but got list

对此有什么建议吗?

我也从torch找到了采样器。但是什么时候使用这些,因为我正在获取数据加载器的数据,只是非常慢。

    dataset = big_dataframe_flt.values
    (looks like: array([[ 0.17114914, -0.67040386, -0.72875149, ..., -0.51023438,
     0.49735906, -0.74075046],
   [ 0.17114914, -0.67088608, -0.72631001, ..., -0.53046875,
     0.49741296, -0.74127526],)

    dataset = torch.tensor(dataset).float()
    dataset = torch.utils.data.TensorDataset(dataset)

    data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=100, shuffle=True, num_workers=4, pin_memory=True)

    for epoch in range(num_epochs):
        model.train()
        for batch_idx, data in enumerate(data_loader):
            data = Variable(data).to(device)
            recon_batch, mu, var = model(data)

            # Backprop and optimize
            loss = loss_function(recon_batch, data, mu, var)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


data = Variable(data).to(device)
TypeError: Variable data has to be a tensor, but got list

标签: pytorch

解决方案


您看到的错误是因为TensorDataset返回张量元组作为输出。如果你print(dataset[0])会看到(tensor([ 0.1711, -0.6704, -0.7288, -0.5102, 0.4974, -0.7408]),)而不是tensor([ 0.1711, -0.6704, -0.7288, -0.5102, 0.4974, -0.7408])。因此,您必须data = Variable(data[0]).to(device)在循环中编写或解构元组for batch_idx, (data, ) in enumerate(data_loader):。这应该允许您使用DataLoader.

话虽如此,如果您的数据是低维的,则DataLoader代码效率不高,因为您必须执行大量采样操作,并为每个采样操作支付 Python 解释器的开销 - 换句话说,此代码不会从矢量化中受益匪浅. 您可以通过增加工作人员的数量来部分解决这个问题,但最终性能最高的版本可能是编写自己的采样器,它使用长度为 100 的整数数组索引数据张量,而不是使用整数对其进行 100 次采样然后连接结果。根据您的 GPU 代码的成本和数据的维度,您可能会也可能无法使用简单的DataLoader解决方案使 GPU 饱和。


推荐阅读