首页 > 解决方案 > 在 pytorch 模块中使用 TorchScript 类作为成员

问题描述

我正在尝试使一些现有的 pytorch 模型支持 TorchScript jit 编译器,但我遇到了非原始类型成员的问题。

这个小例子说明了这个问题:

import torch

@torch.jit.script
class Factory(object):
    def __init__(self):
        pass

    def create(self, x: float) -> torch.Tensor:
        return torch.tensor([x])

class Foo(torch.nn.Module):
    def __init__(self):
        super(Foo, self).__init__()
        self.factory: Factory = Factory()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.factory.create(0)

mod = torch.jit.script(Foo())

运行时,jit 编译器给出错误

RuntimeError:
module has no attribute 'factory':
at example.py:17:15
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.factory.create(0)
               ~~~~~~~~~~~~ <--- HERE

我已经测试过Factory该类对方法内的 jit 可用forward,但是当我将其存储为成员时它不承认它。为什么是这样?有什么方法可以让 jit 编译器将这种成员保存到编译模块中?

标签: pythonpytorchtorchscript

解决方案


这是 PyTorch 中的一个错误,在您发布问题后不久就解决了:https : //discuss.pytorch.org/t/jit-scripted-attributes-inside-module/60645,https: //github.com/pytorch/ pytorch/问题/27495

更新 PyTorch 应该可以修复它。


推荐阅读