首页 > 解决方案 > 如何使用 TFT 模型正确修复 pytorch 预测中的随机种子?

问题描述

Pytorch-forecasting 是一个用于时间序列预测的库。我想知道如何修复种子以获得我的实验的可重复性。现在我在训练开始之前使用这个功能。

import pytorch_lightning as pl
pl.seed_everything(42)

但是,它不起作用。这是由 dropout 引起的,因为当我设置 dropout = 0 时它已经起作用了。

    # configure network and trainer
pl.seed_everything(42)
trainer = pl.Trainer(
    deterministic=True,
    gpus=[0],
    gradient_clip_val=0.1,
)


tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=128,
    attention_head_size=2,
    dropout=0.3,  
    hidden_continuous_size=32,  
    output_size=7,  
    loss=QuantileLoss(),
    reduce_on_plateau_patience=4,
)

trainer.fit(
    tft,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)

那么,我该如何解决这个问题呢?

谢谢,

标签: pytorch-lightning

解决方案


推荐阅读