首页 > 解决方案 > 是否建议使用相同的 torch Dataset 类进行训练和预测?

问题描述

我最近开始使用 PyTorch,我喜欢它的面向对象风格。但是,我想知道在预测模型时最好和建议的工作流程是什么。我想使用我编写的自定义数据集类,用于训练和验证我的模型。这个类是一个地图风格的数据集,因此我实现__getitem__了返回图像和目标的方法:

class CustomDataset:

    def __init__(self, ...):
        ...

    def __getitem__(self, image_id):
        ....
        return (
            torch.tensor(image, dtype=torch.float),
            torch.tensor(target, dtype=torch.long),
       )

但是,当我使用此类进行预测时,我没有任何返回目标。我目前的解决方法是

def __getitem__(self, image_id):
    ....
    if predict:
        return (
            torch.tensor(image, dtype=torch.float),
            np.nan,
       )
   else:
        return (
             torch.tensor(image, dtype=torch.float),
             torch.tensor(target, dtype=torch.long),
       )

但是,我想知道是否有更好的方法来做到这一点。同时,由于感觉有点不自然,我开始想知道使用同一个类进行训练和预测是否是可取的(应该是,但我的解决方案的笨拙让我想知道)。当然,我根本无法返回元组,只能返回第一个元素,但这仍然需要 if-else。

标签: pythonpytorchtorchpytorch-dataloader

解决方案


PyTorch 的DataSet类非常简单。所以,不要想太多。它只不过是用于访问数据的包装器。

您不必返回元组,甚至不必返回张量。你可以返回任何你想要的数据。通常,它将采用以下样式之一:

  • 对于无监督数据:Sample(Sample, None)
  • 对于监督数据:(Sample, Label)
  • 对于具有多个目标的监督数据,例如对象检测:(Sample, [Label1, Label2, ...])(Sample, Label1, Label2, ...)

训练/测试使用相同的 DataSet 类也很常见。

(sample, None) 因此,在您的情况下,只需像在 torchvision 中所做的那样返回样本或元组并相应地调整您的管道。我不建议使用np.nan它,因为它会使简单的无检查 ( np.nan == None) 失败。另外,我鼓励你继承自torch.data.Dataset.

但是,如果您的管道迫使您使用元组或有其他限制,我建议您重新表述您的问题。


推荐阅读