python - pytorch闪电模型的输出预测
问题描述
这可能是一个非常简单的问题。我刚开始使用 PyTorch 闪电,无法弄清楚如何在训练后接收模型的输出。
我对 y_train 和 y_test 作为某种数组的预测感兴趣(稍后步骤中的 PyTorch 张量或 NumPy 数组)以使用不同的脚本在标签旁边绘制。
dataset = Dataset(train_tensor)
val_dataset = Dataset(val_tensor)
training_generator = torch.utils.data.DataLoader(dataset, **train_params)
val_generator = torch.utils.data.DataLoader(val_dataset, **val_params)
mynet = Net(feature_len)
trainer = pl.Trainer(gpus=0,max_epochs=max_epochs, logger=logger, progress_bar_refresh_rate=20, callbacks=[early_stop_callback], num_sanity_val_steps=0)
trainer.fit(mynet)
在我的闪电模块中,我具有以下功能:
def __init__(self, random_inputs):
def forward(self, x):
def train_dataloader(self):
def val_dataloader(self):
def training_step(self, batch, batch_nb):
def training_epoch_end(self, outputs):
def validation_step(self, batch, batch_nb):
def validation_epoch_end(self, outputs):
def configure_optimizers(self):
我是否需要特定的预测功能,或者是否有任何我看不到的已经实现的方式?
解决方案
我不同意这些答案:OP 的问题似乎集中在他应该如何使用经过闪电训练的模型来获得一般预测,而不是针对训练管道中的特定步骤。在这种情况下,用户不需要靠近 Trainer 对象 - 这些不打算用于一般预测,因此上面的答案鼓励反模式(每次我们都随身携带一个 trainer 对象)想要做一些预测)给任何将来阅读这些答案的人。
我们可以直接trainer
从已定义的 Lightning 模块中获取预测,而不是model = Net(...)
使用已在 Lightning 模块上实现/覆盖 - 这是必需的)。x
model(x)
forward
相反,Trainer.predict()
通常不是使用您训练的模型获得预测的预期方法。Trainer API 提供方法tune
和LightningModule 作为训练管道的一部分fit
,test
在我看来,该predict
方法是为单独数据加载器上的临时预测提供的,作为不太“标准”训练步骤的一部分。
OP 的问题(我需要一个特定的预测函数还是我没有看到任何已经实现的方法?)暗示他们不熟悉该forward()
方法在 PyTorch 中的工作方式,但询问是否已经有一种方法预测他们看不到。因此,完整的答案需要进一步解释该forward()
方法适合预测过程的位置:
之所以model(x)
有效,是因为 Lightning 模块是其子类,torch.nn.Module
它们实现了一个名为的魔术方法__call__()
,这意味着我们可以像调用函数一样调用类实例。__call__()
依次调用forward()
,这就是为什么我们需要在 Lightning 模块中覆盖该方法。
注意。因为forward
只是我们使用时调用的逻辑的一部分,所以除非你有特定的理由偏离,否则model(x)
总是建议使用model(x)
而不是预测。model.forward(x)
推荐阅读
- db2 -
在使用环境的环境中无效。 - kubernetes - 大使显示上游没有健康
- python - 如何删除相互关联的 2 个单独的重复项(PYTHON)
- ios - 如何在 iOS 应用程序中显示 Web 服务调用的进度条
- java - Collection 类的 Spring Boot 自定义序列化程序
- java - 如何在不使用android权限的情况下获取通话记录并选择要显示的号码
- wordpress - nginx-ingress 控制器 wordpress 大文件上传问题
- jquery - 使用 Jquery 更改 Fontawesome 图标
- ios - 由于 Yoga 错误,React Native iOS 构建失败
- php - PHP MySQL 导出数据到excel