python - 使用 Pytorch Lightning 时如何将指标(例如验证损失)记录到 TensorBoard?
问题描述
我正在使用 Pytorch Lightning 来训练我的模型(在 GPU 设备上,使用 DDP),TensorBoard 是 Lightning 使用的默认记录器。
我的代码设置为分别记录每个训练和验证步骤的训练和验证损失。
class MyLightningModel(pl.LightningModule):
def training_step(self, batch):
x, labels = batch
out = self(x)
loss = F.mse_loss(out, labels)
self.log("train_loss", loss)
return loss
def validation_step(self, batch):
x, labels = batch
out = self(x)
loss = F.mse_loss(out, labels)
self.log("val_loss", loss)
return loss
TensorBoard 在选项卡中正确绘制train_loss
和val_loss
图表SCALERS
。但是,在HPARAMS
左侧栏中的选项卡中,仅hp_metric
在 下可见Metrics
。
但是,在HPARAMS
左侧栏中的选项卡中,仅hp_metric
在 下可见Metrics
。
我们如何添加train_loss
和val_loss
到该Metrics
部分?这样,我们将能够使用val_loss
inPARALLEL COORDINATES VIEW
而不是hp_metric
。
使用 Pytorch 1.8.1、Pytorch Lightning 1.2.6、TensorBoard 2.4.1
解决方案
- 您可以使用
self.logger.log_hyperparams
方法在 tensorboard 中记录超参数和指标。(请参阅pytorch 闪电 tensorboard 文档) - 当
self.log
且仅当您在metric
. (参见pytorch tensorboard 文档)
示例代码(完整代码):
class BasicModule(LightningModule):
def __init__(self, lr=0.01):
super().__init__()
self.model = models.resnet18(pretrained=False)
self.criterion = nn.CrossEntropyLoss()
self.lr = lr
self.save_hyperparameters()
metric = MetricCollection({'top@1': Accuracy(top_k=1), 'top@5': Accuracy(top_k=5)})
self.train_metric = metric.clone(prefix='train/')
self.valid_metric = metric.clone(prefix='valid/')
def on_train_start(self) -> None:
# log hyperparams
self.logger.log_hyperparams(self.hparams, {'train/top@1': 0, 'train/top@5': 0, 'valid/top@1': 0, 'valid/top@5': 0})
return super().on_train_start()
def training_step(self, batch, batch_idx, optimizer_idx=None):
return self.shared_step(*batch, self.train_metric)
def validation_step(self, batch, batch_idx):
return self.shared_step(*batch, self.valid_metric)
def shared_step(self, x, y, metric):
y_hat = self.model(x)
loss = self.criterion(y_hat, y)
self.log_dict(metric(y_hat, y), prog_bar=True)
return loss
if __name__ == '__main__':
# default_hp_metric=False
logger = loggers.TensorBoardLogger('', 'lightning_logs', default_hp_metric=False)
trainer = Trainer(max_epochs=2, gpus='0,', logger=logger, precision=16)
推荐阅读
- javascript - 调用pickdate时如何添加条件?
- typescript - redux 的 typescript types 中的 CombinedState 类型的原因
- c# - 将匿名类型转换为字典
- python - 如何在 python 中使用正则表达式从多行字符串中删除特定字符
- android - react-native-firebase react-native链接问题,如何解决?
- python - Scrapy & Selenium:从文本文件加载 starturl 不起作用
- python - 规范化在 Python 中的表现如何?
- python - 如何将从(对象检测)裁剪的检测到的面部保存到其特定创建的文件夹中?
- javascript - 搜索一个或多个州的公园
- php - 如何在电报机器人使用php的评论中使用回调