python - 在 pytorch 模块中使用 TorchScript 类作为成员
问题描述
我正在尝试使一些现有的 pytorch 模型支持 TorchScript jit 编译器,但我遇到了非原始类型成员的问题。
这个小例子说明了这个问题:
import torch
@torch.jit.script
class Factory(object):
def __init__(self):
pass
def create(self, x: float) -> torch.Tensor:
return torch.tensor([x])
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.factory: Factory = Factory()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.factory.create(0)
mod = torch.jit.script(Foo())
运行时,jit 编译器给出错误
RuntimeError:
module has no attribute 'factory':
at example.py:17:15
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.factory.create(0)
~~~~~~~~~~~~ <--- HERE
我已经测试过Factory
该类对方法内的 jit 可用forward
,但是当我将其存储为成员时它不承认它。为什么是这样?有什么方法可以让 jit 编译器将这种成员保存到编译模块中?
解决方案
这是 PyTorch 中的一个错误,在您发布问题后不久就解决了:https : //discuss.pytorch.org/t/jit-scripted-attributes-inside-module/60645,https: //github.com/pytorch/ pytorch/问题/27495。
更新 PyTorch 应该可以修复它。
推荐阅读
- azure - Azure 存储帐户在没有“blob.core.windows.net”端点的情况下创建
- heroku - ERR_CERT_DATE_INVALID 使用 Heroku ACM
- python - 无法通过 HTML 表单更改图像,但我可以从 Django 管理面板更改它们
- git - git 命令失败,退出代码 128 警告:url 在其用户名组件中包含换行符然后致命:无法解析凭据 url
- horizontal-scrolling - element.scrollTo 平滑滚动在 Safari 中的可捕捉可滚动容器中不起作用
- asp.net-mvc - 身份验证错误 - IIS 服务器上的 ASP.NET Web API 和 Angular 7
- kotlin - “”的名字是什么?Kotlin 中“var button: Button? = null”表达式中的运算符/符号?
- javascript - 如何读取在另一个控制器中标定的 json
- python - 无法使用 GitHub 操作从 pytest 中的存储库导入脚本
- javascript - 根据父窗口内部高度将Popup子窗口居中对齐