pytorch - 如何在 pytorch ignite 中使用 LBFGS 优化器?
问题描述
我最近开始使用 Ignite,我发现它非常有趣。我想使用torch.optim
模块中的 LBFGS 算法作为优化器来训练模型。
这是我的代码:
from ignite.engine import Events, Engine, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import RootMeanSquaredError, Loss
from ignite.handlers import EarlyStopping
D_in, H, D_out = 5, 10, 1
model = simpleNN(D_in, H, D_out) # a simple MLP with 1 Hidden Layer
model.double()
train_loader, val_loader = get_data_loaders(i)
optimizer = torch.optim.LBFGS(model.parameters(), lr=1)
loss_func = torch.nn.MSELoss()
#Ignite
trainer = create_supervised_trainer(model, optimizer, loss_func)
evaluator = create_supervised_evaluator(model, metrics={'RMSE': RootMeanSquaredError(),'LOSS': Loss(loss_func)})
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
print("Epoch[{}] Loss: {:.5f}".format(engine.state.epoch, len(train_loader), engine.state.output))
def score_function(engine):
val_loss = engine.state.metrics['RMSE']
print("VAL_LOSS: {:.5f}".format(val_loss))
return -val_loss
handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)
trainer.run(train_loader, max_epochs=100)
引发的错误是:
TypeError: step() missing 1 required positional argument: 'closure'
我知道为 LBFGS 的实现定义一个闭包是必需的,所以我的问题是如何使用 ignite 来做到这一点?还是有另一种方法可以做到这一点?
解决方案
这样做的方法是这样的:
from ignite.engine import Engine
model = ...
optimizer = torch.optim.LBFGS(model.parameters(), lr=1)
criterion =
def update_fn(engine, batch):
model.train()
x, y = batch
# pass to device if needed as here: https://github.com/pytorch/ignite/blob/40d815930d7801b21acfecfa21cd2641a5a50249/ignite/engine/__init__.py#L45
def closure():
y_pred = model(x)
loss = criterion(y_pred, y)
optimizer.zero_grad()
loss.backward()
return loss
optimizer.step(closure)
trainer = Engine(update_fn)
# everything else is the same
推荐阅读
- r - 如何在 geom_violin 中复制 vioplot?
- python - 合并时保留 csv 文件的副本
- entity-framework-core - Entity Framework Core:有没有办法重用子查询而不是不使用 DRY 的 Include?
- javascript - 是否可以通过 jQuery 全局更改 CSS 变量?
- kubernetes - k8 pod 优先级和测试
- python - 由 bs4 和正则表达式元素创建的 pandas 对象被打印为 python 列表
- java - Mockito/PowerMockito 检查退货
- python - 如何将雪人形状的轮廓分成两个圆圈
- java - Android MVVM 存储库和 ViewModel 问题
- javascript - 有人可以向我解释这段代码吗?关于有条件地禁用输入