首页 > 解决方案 > 如何将 Pytorch 模型转换为 ONNX 模型?

问题描述

我有一个模型,其中包含两部分 encoderdecoder. 它们每个都是 PyTorch 模型。

我推断它们如下。

features = encoder(input)
output = decoder(features)

我想将它们转换成一个 ONNX 模型而不是两个 ONNX 模型。我该怎么做?

标签: pythondeep-learningpytorchonnx

解决方案


通常,您应该构建一个继承nn.Module组合多个模型的新类。在这种情况下,您有一个简单的前馈网络,因此我们可以使用nn.Sequential便利类。

例如

model = torch.nn.Sequential([encoder, decoder])

# inference example
model.eval()
output = model(input)

# save model to onnx file ...

推荐阅读