python - Python:TypeError:val_dataloader()缺少1个必需的位置参数:'self'
问题描述
我正在使用 PyTorch Lightning 进行图像分类任务,但在实现时遇到了一个 TypeError。我已经按照 PyTorch Lightning 示例创建了数据模块和模型。我使用的模型是带批量标准化(Batch Normalization)的 VGG16。
在 `FruitsDataModule` 中,只有 `val_dataloader` 报错,而 `train_dataloader` 没有报错。这令人困惑,因为这两个函数做的是完全相同的事情,只是使用的数据不同。
相关代码如下所示。
数据模块
class FruitsDataModule(pl.LightningDataModule):
    """LightningDataModule for the fruits ImageFolder dataset.

    The folder at ``config.TRAIN_DATA_PATH`` is split 67/33 into train and
    validation subsets (``random_state=42``). The folder at
    ``config.TEST_DATA_PATH`` is used as the test set. Every image is
    converted with ``transforms.ToTensor()``.
    """

    def __init__(self):
        super().__init__()
        self.transform = transforms.Compose(
            [
                transforms.ToTensor()
            ]
        )

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` and store them on ``self``.

        Args:
            stage: ``'fit'``, ``'test'`` or ``None`` (build everything).

        The datasets must be instance attributes. The ``*_dataloader``
        methods are called later and cannot see locals created here.
        """
        # Local import keeps this block self-contained.
        from torch.utils.data import Subset

        if stage == 'fit' or stage is None:
            full_train_dataset = datasets.ImageFolder(
                root=config.TRAIN_DATA_PATH,
                transform=self.transform
            )
            # Split the *indices*, not the dataset. Handing the dataset to
            # train_test_split would index it item by item, loading and
            # transforming every image up front. Splitting range(len(...))
            # with the same test_size / random_state selects the same samples
            # in the same order, and Subset loads them lazily.
            train_idx, val_idx = train_test_split(
                list(range(len(full_train_dataset))),
                test_size=0.33,
                random_state=42
            )
            self.train_dataset = Subset(full_train_dataset, train_idx)
            self.val_dataset = Subset(full_train_dataset, val_idx)
        if stage == 'test' or stage is None:
            self.test_dataset = datasets.ImageFolder(
                root=config.TEST_DATA_PATH,
                transform=self.transform
            )

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return DataLoader(
            self.train_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=True
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation subset."""
        return DataLoader(
            self.val_dataset,
            batch_size=config.BATCH_SIZE,
        )

    def test_dataloader(self):
        """Return an unshuffled loader over the test folder."""
        return DataLoader(
            self.test_dataset,
            batch_size=config.BATCH_SIZE,
        )
模型
class VGGModel(pl.LightningModule):
    """ImageNet-pretrained VGG16-BN classifier wrapped as a LightningModule.

    Args:
        num_classes: if given, the last layer of the pretrained classifier
            head is replaced by a fresh ``nn.Linear`` with this many outputs.
            With the default (``None``) the original head is kept unchanged.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        self.model = models.vgg16_bn(pretrained=True)
        if num_classes is not None:
            # Resize the output layer so the logits match the dataset's classes.
            head = self.model.classifier[-1]
            self.model.classifier[-1] = nn.Linear(head.in_features, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def step(self, batch):
        """Run one ``(inputs, targets)`` batch.

        Returns:
            ``(loss, preds, targets)``, where ``preds`` is the argmax class
            index per sample.
        """
        x, y = batch
        logits = self(x)  # __call__ rather than .forward() so module hooks run
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def _shared_step(self, batch, stage, acc_prog_bar):
        """Compute loss/accuracy for ``batch`` and log them under ``stage/``."""
        loss, preds, targets = self.step(batch)
        # Plain accuracy, computed on the tensors' own device. sklearn's
        # roc_auc_score takes (y_true, y_score) as CPU arrays and needs
        # per-class scores for multiclass targets, so it cannot consume
        # argmax predictions.
        acc = (preds == targets).float().mean()
        self.log(f"{stage}/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log(f"{stage}/acc", acc, on_step=False, on_epoch=True, prog_bar=acc_prog_bar)
        return {"loss": loss, "preds": preds, "targets": targets}

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", acc_prog_bar=True)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val", acc_prog_bar=True)

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test", acc_prog_bar=False)

    def configure_optimizers(self):
        """Adam over all parameters with ``config.LEARNING_RATE``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=config.LEARNING_RATE)
        return optimizer
训练
model = VGGModel()
trainer = pl.Trainer(
    max_epochs=1,
    gpus=[0],
    precision=32,
    progress_bar_refresh_rate=20
)
# `datamodule` must be an *instance*. Passing the bare class makes Lightning
# call the unbound FruitsDataModule.val_dataloader() without `self`, which
# raised "val_dataloader() missing 1 required positional argument: 'self'".
trainer.fit(model, datamodule=FruitsDataModule())
错误日志
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-6-5990df7ecb15> in <module>
8 )
9
---> 10 trainer.fit(model, datamodule = FruitsDataModule)
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
497
498 # dispath `start_training` or `start_testing` or `start_predicting`
--> 499 self.dispatch()
500
501 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
544
545 else:
--> 546 self.accelerator.start_training(self)
547
548 def train_or_test_or_predict(self):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
71
72 def start_training(self, trainer):
---> 73 self.training_type_plugin.start_training(trainer)
74
75 def start_testing(self, trainer):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
112 def start_training(self, trainer: 'Trainer') -> None:
113 # double dispatch to initiate the training loop
--> 114 self._results = trainer.run_train()
115
116 def start_testing(self, trainer: 'Trainer') -> None:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
605 self.progress_bar_callback.disable()
606
--> 607 self.run_sanity_check(self.lightning_module)
608
609 # set stage for logging
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
852 # to make sure program won't crash during val
853 if should_sanity_check:
--> 854 self.reset_val_dataloader(ref_model)
855 self.num_sanity_val_batches = [
856 min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in reset_val_dataloader(self, model)
362 has_step = is_overridden('validation_step', model)
363 if has_loader and has_step:
--> 364 self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
365
366 def reset_test_dataloader(self, model) -> None:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in _reset_eval_dataloader(self, model, mode)
276 # always get the loaders first so we can count how many there are
277 loader_name = f'{mode}_dataloader'
--> 278 dataloaders = self.request_dataloader(getattr(model, loader_name))
279
280 if not isinstance(dataloaders, list):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in request_dataloader(self, dataloader_fx)
396 The dataloader
397 """
--> 398 dataloader = dataloader_fx()
399 dataloader = self._flatten_dl_only(dataloader)
400
TypeError: val_dataloader() missing 1 required positional argument: 'self'
如何解决这个问题?
解决方案
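`datamodule` 参数需要的是 `LightningDataModule` 的实例(对象),而不是类本身。传入类时,Lightning 取到的 `val_dataloader` 是未绑定的函数,调用时缺少 `self`,因此抛出上面的 TypeError。之所以只有 `val_dataloader` 报错,是因为 Trainer 在开始训练前会先运行验证集的 sanity check(见错误日志中的 `run_sanity_check`),所以它是第一个被调用的 dataloader 方法。另外,`setup()` 中的 `train_dataset`、`val_dataset`、`test_dataset` 只是局部变量,应保存为 `self.train_dataset` 等属性,并在各 dataloader 方法中使用,否则修复下面这一行之后会出现 NameError。要修复这个 TypeError,把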
trainer.fit(model, datamodule=FruitsDataModule)
改为:
trainer.fit(model, datamodule=FruitsDataModule())
推荐阅读
- powershell - 将 CSV 文件列导出到其他 CSV 文件
- java - 一个带开关盒的简单计算器
- c++ - 将函数作为参数传递给类 C++
- ionic-framework - 防止硬件后退按钮使用android关闭模式
- ios - 如何更新应用商店链接。我可以像示例一样更改应用商店网址吗
- android - Navigation Drawer 用另一个有问题地保留抽屉选项的片段替换片段
- java - FitNesse with JUnit:并行执行套件
- powerbi - 忽略堆积条形图中的空白
- javascript - 无法解决 javascript 中的承诺
- java - 哪个存储库具有 eXist 以及如何使用 gradle 将其添加到类路径中?