首页 > 解决方案 > PyTorch Lightning 是否在整个时期内平均指标?

问题描述

我正在查看PyTorch-Lightning官方文档https://pytorch-lightning.readthedocs.io/en/0.9.0/lightning-module.html上提供的示例。

这里的损失和度量是在具体批次上计算的。但是当记录一个特定批次的准确性时,它可能相当小且不具有代表性,而是对所有时期的平均值不感兴趣。我是否理解正确,有一些代码对所有批次执行平均,通过时代?

 import pytorch_lightning as pl
 from pytorch_lightning.metrics import functional as FM

 class ClassificationTask(pl.LightningModule):

 def __init__(self, model):
     super().__init__()
     self.model = model

 def training_step(self, batch, batch_idx):
     x, y = batch
     y_hat = self.model(x)
     loss = F.cross_entropy(y_hat, y)
     return pl.TrainResult(loss)

 def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    acc = FM.accuracy(y_hat, y)
    result = pl.EvalResult(checkpoint_on=loss)
    result.log_dict({'val_acc': acc, 'val_loss': loss})
    return result

 def test_step(self, batch, batch_idx):
    result = self.validation_step(batch, batch_idx)
    result.rename_keys({'val_acc': 'test_acc', 'val_loss': 'test_loss'})
    return result

 def configure_optimizers(self):
     return torch.optim.Adam(self.model.parameters(), lr=0.02)

标签: pytorchmetricspytorch-lightning

解决方案


如果您想对整个时期的指标进行平均,您需要告诉LightningModule您已经子类化了这样做。有几种不同的方法可以做到这一点,例如:

  1. result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 如文档中所示调用,on_epoch=True以便在整个时期内平均训练损失。IE:
 def training_step(self, batch, batch_idx):
     x, y = batch
     y_hat = self.model(x)
     loss = F.cross_entropy(y_hat, y)
     result = pl.TrainResult(loss)
     result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
     return result
  1. 或者,您可以在其自身上调用该log方法:(可选地通过加速器来减少)。LightningModuleself.log("train_loss", loss, on_epoch=True, sync_dist=True)sync_dist=True

您需要做类似的事情validation_step来获取聚合的 val-set 指标或在validation_epoch_end方法中自己实现聚合。


推荐阅读