首页 > 解决方案 > Pytorch 闪电:“CIFAR10DataModule”对象没有属性“train_loader”

问题描述

你能告诉我为什么我无法导入 CUFAR10DataModule() 吗?

起初,我在 GoogleColab 上运行代码,

from pl_bolts.datamodules import CIFAR10DataModule
dm = CIFAR10DataModule()

然后,执行代码以进行确认

from torch.optim import Adam
optimizer = Adam(finetune_layer.parameters(), lr=1e-4)

for epoch in range(10):
  for batch in dm.train_loader:
    x, y = batch
    with torch.no_grad():
      features = backbone(x)

    preds = finetune_layer(features)
    loss = cross_entropy(preds, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss.item())

但是,AttributeError: 'CIFAR10DataModule' object has no attribute 'train_loader'运行代码后返回了该消息。

dm当运行代码以确认

for batch in dm.train_dataloader:
  x, y = batch
  print(x.shape, y.shape)
  break

错误说TypeError: 'method' object is not iterable

代码与示例看起来相同,但我想知道为什么会产生这样的错误?

标签: pytorchgoogle-colaboratorypytorch-lightning

解决方案


您的代码有两个问题:

首先,获取底层 PyTorch 数据加载器的方式dm.train_dataloader()不是dm.train_loader. 它是一个函数,而不是一个属性

for batch in dm.train_dataloader():
    x, y = batch
    ...

其次,由于您尝试使用 aLightningDataModule而不使用 a Trainer,因此您需要手动调用

dm.prepare_data()
dm.setup()

.. 为了使数据加载器可以通过.train_dataloader().


推荐阅读