Does the PyTorch Lightning Trainer use validation data to optimize model weights?

Problem Description

I am currently using PyTorch Forecasting, which relies heavily on PyTorch Lightning. Here I apply the PyTorch Lightning Trainer to train a Temporal Fusion Transformer model, roughly following the outline of this example. My rough training code and model definition look like this:

training = TimeSeriesDataSet(
    df_train[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="target",
    group_ids=["group"],
    max_prediction_length=90,
    min_encoder_length=365 // 2,
    max_encoder_length=365, 
    time_varying_unknown_reals=["target"], 
    time_varying_known_reals=["time_idx"]
)

validation = TimeSeriesDataSet.from_dataset(training, df_train, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 4  
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=res.suggestion(),  # res: result of an earlier learning-rate search, not shown here
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,  
    loss=QuantileLoss(),
    log_interval=10,  
    reduce_on_plateau_patience=4,
    time_varying_reals_encoder=["target"],
    time_varying_reals_decoder=["target"]
)

trainer = pl.Trainer(
    max_epochs=15,
    gpus=0,
    weights_summary="top",
    gradient_clip_val=0.1,
    limit_train_batches=30,
    callbacks=[lr_logger, early_stop_callback],  # callbacks defined in earlier setup, not shown here
    logger=logger,
)

trainer.fit(
    tft,
    train_dataloader,
    val_dataloader
)

Now my question is: does the validation data have any influence on the optimization of the model? I have been experimenting with the max_prediction_length parameter, and the model seems to perform better when I set the validation time window to a larger horizon. Does the PyTorch Lightning Trainer use the validation data to optimize the model, or am I missing something else?

Many thanks in advance!

Tags: python, pytorch, forecasting, pytorch-lightning

Solution
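No: the Lightning Trainer never backpropagates on validation data. `validation_step` runs with gradients disabled and the optimizer is only stepped inside the training loop, so validation batches cannot move the weights directly. They can influence training *indirectly*, though: an `EarlyStopping` callback and `reduce_on_plateau_patience` (as in the question's `TemporalFusionTransformer.from_dataset` call) both monitor validation metrics to decide when to stop or lower the learning rate. Also note that changing `max_prediction_length` changes the forecasting task itself, which alone can explain different scores. Below is a plain-PyTorch sketch (not Lightning's actual source, just the same structure) of what the Trainer does each epoch, showing that validation alone leaves parameters untouched:

```python
import torch
from torch import nn

# Sketch of the two loops the Lightning Trainer runs per epoch:
# the training loop backpropagates and steps the optimizer; the
# validation loop runs under torch.no_grad() and never steps it.

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x_train, y_train = torch.randn(32, 8), torch.randn(32, 1)
x_val, y_val = torch.randn(16, 8), torch.randn(16, 1)

def train_epoch():
    model.train()
    loss = nn.functional.mse_loss(model(x_train), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # weights change here, driven by training data only

def validate():
    model.eval()
    with torch.no_grad():  # gradients disabled: validation cannot move weights
        return nn.functional.mse_loss(model(x_val), y_val).item()

before = model.weight.detach().clone()
val_loss = validate()                      # validation alone...
assert torch.equal(before, model.weight)   # ...leaves parameters unchanged

train_epoch()                              # a training step does update them
assert not torch.equal(before, model.weight)
```

So if a larger validation window improves the reported metric, the likely causes are the changed task definition and the validation-metric-driven callbacks (early stopping, LR reduction), not any gradient update from the validation set.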

