python - Torchscript 与张量列表的 torch.cat 不兼容
问题描述
在 torchscript 中使用时,Torch.cat 对张量列表抛出错误
这是重现错误的最小可重现示例
import torch
import torch.nn as nn
"""
Smallest working bug for torch.cat torchscript
"""
class Model(nn.Module):
"""dummy model for showing error"""
def __init__(self):
super(Model, self).__init__()
pass
def forward(self):
a = torch.rand([6, 1, 12])
b = torch.rand([6, 1, 12])
out = torch.cat([a, b], axis=2)
return out
if __name__ == '__main__':
model = Model()
print(model()) # works
torch.jit.script(model) # throws error
预期的结果将是 torch.cat 的 torchscript 输出。这是提供的错误消息:
File "/home/anil/.conda/envs/rnn/lib/python3.7/site-packages/torch/jit/__init__.py", line 1423, in _create_methods_from_stubs
self._c._create_methods(self, defs, rcbs, defaults)
RuntimeError:
Arguments for call are not valid.
The following operator variants are available:
aten::cat(Tensor[] tensors, int dim=0) -> (Tensor):
Keyword argument axis unknown.
aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> (Tensor(a!)):
Argument out not provided.
The original call is:
at smallest_working_bug_torch_cat_torchscript.py:19:14
def forward(self):
a = torch.rand([6, 1, 12])
b = torch.rand([6, 1, 12])
out = torch.cat([a, b], axis=2)
~~~~~~~~~ <--- HERE
return out
请让我知道此问题的修复或解决方法。
谢谢!
解决方案
更改axis
以dim
修复错误,原始解决方案已在此处发布
推荐阅读
- python - Pandas Dataframes:当行包含不同国家/地区时,组合来自两个全球数据集的列
- kubernetes - 关于 GKE 优化的图像强化
- multithreading - 如何使用 PyQt 线程长时间睡眠?
- git - git agony 第 2 部分:当 git 不会提交、拉取或推送文件时是什么意思?
- stata - 如果在Stata的两个不同时期观察到变量的值,如何生成指标
- python - 使用 Psycopg2 执行复制到数据库后文件对象为空
- javascript - squarespace 7.1 版代码注入不起作用
- java - 在 Java 中仅打印非负余额
- google-api - 谷歌 API 密钥无效
- lua - 使用主表内的函数编辑主表内的子表