How to use the LBFGS optimizer in PyTorch Ignite?

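Problem description

I recently started using Ignite and find it very interesting. I want to train a model using the LBFGS algorithm from the torch.optim module as the optimizer.

Here is my code: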

    import torch
    from ignite.engine import Events, Engine, create_supervised_trainer, create_supervised_evaluator
    from ignite.metrics import RootMeanSquaredError, Loss
    from ignite.handlers import EarlyStopping

    D_in, H, D_out = 5, 10, 1
    model = simpleNN(D_in, H, D_out)  # a simple MLP with 1 hidden layer
    model.double()
    train_loader, val_loader = get_data_loaders(i)

    optimizer = torch.optim.LBFGS(model.parameters(), lr=1)
    loss_func = torch.nn.MSELoss()

    # Ignite
    trainer = create_supervised_trainer(model, optimizer, loss_func)
    evaluator = create_supervised_evaluator(model, metrics={'RMSE': RootMeanSquaredError(), 'LOSS': Loss(loss_func)})

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        print("Epoch[{}] Loss: {:.5f}".format(engine.state.epoch, engine.state.output))

    def score_function(engine):
        val_loss = engine.state.metrics['RMSE']
        print("VAL_LOSS: {:.5f}".format(val_loss))
        return -val_loss

    handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, handler)

    trainer.run(train_loader, max_epochs=100)

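The error raised is: TypeError: step() missing 1 required positional argument: 'closure'

I know that the LBFGS implementation requires a closure to be defined, so my question is: how can I do that with Ignite? Or is there another way to do this?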
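
For context, the closure requirement mentioned above can be sketched in plain PyTorch, without Ignite. This is a minimal illustration, not the asker's code: the toy data, the single linear layer, and the five outer steps are all assumptions made for the example. The point is only that LBFGS's step() is handed a function that recomputes the loss and gradients, because the optimizer may evaluate it more than once per step.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy regression data (y = 2x + 1), in double precision
x = torch.linspace(-1, 1, 32, dtype=torch.double).unsqueeze(1)
y = 2 * x + 1

model = torch.nn.Linear(1, 1).double()  # stand-in for the real model
criterion = torch.nn.MSELoss()
optimizer = torch.optim.LBFGS(model.parameters(), lr=1)


def closure():
    # Clear old gradients, recompute the loss, and backpropagate.
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    return loss


initial_loss = criterion(model(x), y).item()
for _ in range(5):
    # LBFGS calls closure() itself, possibly several times per step
    optimizer.step(closure)
final_loss = criterion(model(x), y).item()

print("loss before: {:.5f}, after: {:.5f}".format(initial_loss, final_loss))
```

create_supervised_trainer calls optimizer.step() with no arguments, which is why the TypeError above appears when LBFGS is passed to it.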

Tags: pytorch, pytorch-ignite

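Solution

One way to do this is to write a custom update function that passes a closure to optimizer.step and build the trainer from it with Engine: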

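    import torch
    from ignite.engine import Engine

    model = ...
    optimizer = torch.optim.LBFGS(model.parameters(), lr=1)
    criterion = ...

    def update_fn(engine, batch):
        model.train()
        x, y = batch
        # pass to device if needed as here: https://github.com/pytorch/ignite/blob/40d815930d7801b21acfecfa21cd2641a5a50249/ignite/engine/__init__.py#L45

        # LBFGS may evaluate the loss several times per step, so the forward
        # and backward passes live in a closure that the optimizer can call.
        def closure():
            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward()
            return loss

        # LBFGS.step returns the loss from the first closure evaluation;
        # returning it makes it available as engine.state.output.
        loss = optimizer.step(closure)
        return loss.item()

    trainer = Engine(update_fn)

    # everything else is the same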
