pytorch - PyTorch Lightning 是否在整个时期内平均指标?
问题描述
我正在查看PyTorch-Lightning
官方文档https://pytorch-lightning.readthedocs.io/en/0.9.0/lightning-module.html上提供的示例。
这里的损失和度量是在具体批次上计算的。但是当记录一个特定批次的准确性时,它可能相当小且不具有代表性,而是对所有时期的平均值不感兴趣。我是否理解正确,有一些代码对所有批次执行平均,通过时代?
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM
class ClassificationTask(pl.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
return pl.TrainResult(loss)
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
acc = FM.accuracy(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log_dict({'val_acc': acc, 'val_loss': loss})
return result
def test_step(self, batch, batch_idx):
result = self.validation_step(batch, batch_idx)
result.rename_keys({'val_acc': 'test_acc', 'val_loss': 'test_loss'})
return result
def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(), lr=0.02)
解决方案
如果您想对整个时期的指标进行平均,您需要告诉LightningModule
您已经子类化了这样做。有几种不同的方法可以做到这一点,例如:
result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
如文档中所示调用,on_epoch=True
以便在整个时期内平均训练损失。IE:
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
result = pl.TrainResult(loss)
result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return result
- 或者,您可以在其自身上调用该
log
方法:(可选地通过加速器来减少)。LightningModule
self.log("train_loss", loss, on_epoch=True, sync_dist=True)
sync_dist=True
您需要做类似的事情validation_step
来获取聚合的 val-set 指标或在validation_epoch_end
方法中自己实现聚合。
推荐阅读
- r - 我如何从“第 1 周”中删除“周”这个词,以便我只保留数字 1
- macos - 无法使用 macports 模块在 mac 上安装 jq
- python - 无法从 django 中的表中编辑数据
- excel - 从活动工作簿创建一种临时工作簿,然后保存临时工作簿。wb 没有改变行为。wb.. Excel VBA
- r - 如何按月计算数据集中所有年份的百分比偏差
- angular - 角度通用 SSR 动态绝对基准路径
- c# - 将列表对象中的项目附加并返回到对象列表中的字符串数组中
- heroku - 为什么我在 Heroku 上不断收到 H12 错误代码 AKA 503 状态代码?
- javascript - 当 intersectionRatio 为 1 时,Intersection Observation 继续
- python - C# 相当于 Python 的 ThreadPool