pytorch - 使用 Dev Pytorch 1.0 将 Pytorch 模型加载到 C++ 中
问题描述
Pytorch 1.0 具有将模型转换为 Torch 脚本程序(以某种方式序列化)的功能,以使其能够在 C++ 中执行而无需 Python 依赖。
详细信息在本教程中。 https://pytorch.org/tutorials/advanced/cpp_export.html
这是这样做的:
import torch
import torchvision
# An instance of your model.
model = A UNET MODEL FROM FASTAI which has hooks as required by UNET
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
在我的用例中,我使用 UNET 模型进行语义分割。但是,我使用这种方法跟踪模型,我得到以下错误。
Forward or backward hooks can't be compiled
UNET 模型使用挂钩来保存在网络中的后续层使用的中间特征。有办法解决吗?或者这仍然是这种新方法的一个限制,它不能与使用此类钩子的模型一起使用。
解决方案
如果您可以使用 Pytorch hub 中的 UNET 模型。它将与 TorchScript 一起使用。
import torch
# downloading the model from torchhub
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
in_channels=3, out_channels=1, init_features=32, pretrained=True)
# downloading the sample
import urllib
url, filename = ("https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png", "TCGA_CS_4944.png")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)
# reading the sample and some prerequisites for transformation
import numpy as np
from PIL import Image
from torchvision import transforms
input_image = Image.open(filename)
m, s = np.mean(input_image, axis=(0, 1)), np.std(input_image, axis=(0, 1))
preprocess = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=m, std=s),])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)
# creating the trace
traced_module = torch.jit.trace(model,input_batch)
# running the trace
traced_module(input_batch)
PS:torch.jit.trace/torch.jit.script 都不支持所有的torch 功能,因此将它们与外部库一起使用总是很棘手。
推荐阅读
- python - 使用正则表达式从文本文件中获取一个数字
- javascript - 使用谷歌弹出窗口登录 Firebase 加载时间过长
- python - Numpy数组的最小差异
- eclipse - 是否可以在 Eclipse 运行配置中跳过 maven-javadoc-plugin 执行?
- node.js - 异步加载画布对象的图像
- python - 替换单词和字符串 pandas
- acumatica - 如何将站点 ID (Whse) 选择限制为仅当前分支?
- prolog - 列出大小乘法
- php - 我如何从 SQL 字符串调用 php Require
- metal - 是否可以同时在两个或多个 GPU 上运行 Metal 代码?