tensorflow - 没有onnx,如何手动将pytorch模型转成tensorflow模型?
问题描述
由于 ONNX 支持的模型有限,我尝试通过直接分配参数来进行这种转换,但获得的 tensorflow 模型未能显示出预期的精度。详细说明如下:
- 源模型是在 MNIST 数据集上训练的 Lenet。
- 我首先通过 model.named_parameters() 提取每个模块及其参数并将它们保存到一个字典中,其中键是模块的名称,值是参数
- 然后,我构建并启动了一个具有相同架构的 tensorflow 模型
- 最后,我将pytroch模型的每一层的参数分配给tensorflow模型
然而,获得的 tensorflow 模型的准确率只有 20% 左右。因此,我的问题是是否可以通过这种方法转换 pytorch 模型?如果是,导致不良结果的可能问题是什么?如果不是,请解释原因。
PS:假设分配程序是正确的。
解决方案
正如 jodag 的评论所提到的,Tensorflow 和 PyTorch 中的运算符表示之间存在许多差异,这可能会导致您的工作流程出现差异。
我们建议使用以下方法:
- 使用PyTorch 中的 ONNX导出器将模型导出为 ONNX 格式。
import torch.onnx
# Argument: model is the PyTorch model
# Argument: dummy_input is a torch tensor
torch.onnx.export(model, dummy_input, "LeNet_model.onnx")
- 使用onnx-tensorflow 后端将 ONNX 模型转换为 Tensorflow。
import onnx
from onnx_tf.backend import prepare
onnx_model = onnx.load("LeNet_model.onnx") # load onnx model
tf_rep = prepare(onnx_model) # prepare tf representation
tf_rep.export_graph("LeNet_model.pb") # export the model
推荐阅读
- c# - asp:Button stopped working without changing the code
- azure-stream-analytics - 缺少字段的 Azure 流分析默认字段值
- python - Python pccolor plot 有两个和两个 x 轴描述
- r - 如何按组测试一行中的值是否与所有先前行相比是唯一的,并计算不同值的数量
- c# - 将字典与可能重复的键组合到另一个包含最大值的字典中
- go - 在 Couchbase 中的字符串中搜索字符串
- android - Docker:将主机中的目录安装到容器中会导致错误
- sharepoint - MS Graph API ODATA 查询 orderby 不适用于 SharePoint
- android - 图像视图未显示在屏幕叠加层中
- java - Jackson 全局设置将数组反序列化为自定义列表实现