首页 > 解决方案 > Python:TypeError:val_dataloader()缺少1个必需的位置参数:'self'

问题描述

我正在使用 PyTorch Lightning 进行图像分类任务,但在实现过程中遇到了一个 TypeError。我已经按照 PyTorch Lightning 示例创建了数据模块和模型。我使用的模型是带批量归一化的 VGG16。

令人困惑的是,我在 FruitsDataModule 中得到的错误只针对 val_dataloader,而不针对 train_dataloader——这两个函数除了使用的数据不同之外,做的是完全相同的事情。

相关代码如下所示。

数据模块

class FruitsDataModule(pl.LightningDataModule):
    """LightningDataModule for the fruits image-classification dataset.

    Loads `ImageFolder` datasets from the paths in `config`, splits the
    training folder into train/val subsets, and exposes the three
    standard dataloaders.
    """

    def __init__(self):
        super().__init__()
        # Only tensor conversion; add augmentation/normalization here if needed.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor()
            ]
        )
        # Declared up front so the dataloader methods always have the
        # attributes available (populated by `setup`).
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        """Create the datasets for the requested stage ('fit'/'test'/None).

        BUG FIX: the original assigned the datasets to *local* variables,
        so the `*_dataloader` methods raised NameError. They must be
        stored on `self`.
        """
        if stage == 'fit' or stage is None:
            full_train_dataset = datasets.ImageFolder(
                root=config.TRAIN_DATA_PATH,
                transform=self.transform
            )

            # NOTE(review): sklearn's train_test_split on a Dataset returns
            # two plain lists of (image, label) tuples, which DataLoader
            # accepts; torch.utils.data.random_split would keep them as
            # Dataset objects — consider switching.
            self.train_dataset, self.val_dataset = train_test_split(
                full_train_dataset,
                test_size=0.33,
                random_state=42
            )

        if stage == 'test' or stage is None:
            self.test_dataset = datasets.ImageFolder(
                root=config.TEST_DATA_PATH,
                transform=self.transform
            )

    def train_dataloader(self):
        """Shuffled dataloader over the training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=True
        )

    def val_dataloader(self):
        """Deterministic dataloader over the validation split."""
        return DataLoader(
            self.val_dataset,
            batch_size=config.BATCH_SIZE,
        )

    def test_dataloader(self):
        """Deterministic dataloader over the held-out test set."""
        return DataLoader(
            self.test_dataset,
            batch_size=config.BATCH_SIZE,
        )

模型

class VGGModel(pl.LightningModule):
    """LightningModule wrapping a pretrained VGG16-BN classifier.

    Uses cross-entropy loss and logs per-epoch loss/accuracy for the
    train/val/test stages.
    """

    def __init__(self):
        super().__init__()

        # NOTE(review): vgg16_bn ships with a 1000-class ImageNet head;
        # for a fruits dataset the final classifier layer should likely be
        # replaced with one sized to the number of fruit classes — confirm.
        self.model = models.vgg16_bn(pretrained=True)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def step(self, batch):
        """Shared forward pass: returns (loss, hard predictions, targets)."""
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    @staticmethod
    def _accuracy(preds, targets):
        """Fraction of correct predictions.

        BUG FIX: the original called sklearn's roc_auc_score(preds, targets),
        which (a) reverses the required (y_true, y_score) argument order,
        (b) expects probability scores rather than argmaxed labels, and
        (c) fails on GPU tensors. Since the logged metric is named "acc",
        compute plain accuracy natively in torch instead.
        """
        return (preds == targets).float().mean()

    def training_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        # log train metrics (epoch-level only)
        acc = self._accuracy(preds, targets)
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        # log val metrics
        acc = self._accuracy(preds, targets)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def test_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        # log test metrics
        acc = self._accuracy(preds, targets)
        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/acc", acc, on_step=False, on_epoch=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def configure_optimizers(self):
        """Plain Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=config.LEARNING_RATE)

训练

model = VGGModel()

trainer = pl.Trainer(
    max_epochs=1,
    gpus=[0],
    precision=32,
    progress_bar_refresh_rate=20
)

# BUG FIX: `datamodule` must be an *instance* of the LightningDataModule,
# not the class itself. Passing the class made Lightning call the unbound
# `val_dataloader` with no instance, producing
# "TypeError: val_dataloader() missing 1 required positional argument: 'self'".
trainer.fit(model, datamodule=FruitsDataModule())

错误日志

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-5990df7ecb15> in <module>
      8 )
      9 
---> 10 trainer.fit(model, datamodule = FruitsDataModule)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    497 
    498         # dispath `start_training` or `start_testing` or `start_predicting`
--> 499         self.dispatch()
    500 
    501         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    544 
    545         else:
--> 546             self.accelerator.start_training(self)
    547 
    548     def train_or_test_or_predict(self):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     71 
     72     def start_training(self, trainer):
---> 73         self.training_type_plugin.start_training(trainer)
     74 
     75     def start_testing(self, trainer):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    112     def start_training(self, trainer: 'Trainer') -> None:
    113         # double dispatch to initiate the training loop
--> 114         self._results = trainer.run_train()
    115 
    116     def start_testing(self, trainer: 'Trainer') -> None:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    605             self.progress_bar_callback.disable()
    606 
--> 607         self.run_sanity_check(self.lightning_module)
    608 
    609         # set stage for logging

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
    852         # to make sure program won't crash during val
    853         if should_sanity_check:
--> 854             self.reset_val_dataloader(ref_model)
    855             self.num_sanity_val_batches = [
    856                 min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in reset_val_dataloader(self, model)
    362         has_step = is_overridden('validation_step', model)
    363         if has_loader and has_step:
--> 364             self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
    365 
    366     def reset_test_dataloader(self, model) -> None:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in _reset_eval_dataloader(self, model, mode)
    276         # always get the loaders first so we can count how many there are
    277         loader_name = f'{mode}_dataloader'
--> 278         dataloaders = self.request_dataloader(getattr(model, loader_name))
    279 
    280         if not isinstance(dataloaders, list):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in request_dataloader(self, dataloader_fx)
    396             The dataloader
    397         """
--> 398         dataloader = dataloader_fx()
    399         dataloader = self._flatten_dl_only(dataloader)
    400 

TypeError: val_dataloader() missing 1 required positional argument: 'self'

如何解决这个问题?

标签: python、deep-learning、pytorch、torchvision、pytorch-lightning

解决方案


datamodule 参数应该是数据模块的实例(对象),而不是类本身。因此,这

trainer.fit(model, datamodule=FruitsDataModule)

应该

trainer.fit(model, datamodule=FruitsDataModule())

推荐阅读