pytorch - 如何对两个 PyTorch 量化张量进行矩阵相乘?
问题描述
我是张量量化的新手,并尝试做一些简单的事情
import torch
x = torch.rand(10, 3)
y = torch.rand(10, 3)
x@y.T
在 CPU 上运行PyTorch量化张量。我因此尝试了
scale, zero_point = 1e-4, 2
dtype = torch.qint32
qx = torch.quantize_per_tensor(x, scale, zero_point, dtype)
qy = torch.quantize_per_tensor(y, scale, zero_point, dtype)
qx@qy.T # I tried...
..并得到错误
RuntimeError:无法使用来自“QuantizedCPUTensorId”后端的参数运行“aten::mm”。'aten::mm' 仅适用于以下后端:[CUDATensorId, SparseCPUTensorId, VariableTensorId, CPUTensorId, SparseCUDATensorId]。
是不支持矩阵乘法,还是我做错了什么?
解决方案
为量化矩阵实现矩阵乘法并不简单。因此,“常规”矩阵乘法 ( @
) 不支持它(如您的错误消息所示)。
您应该查看量化操作,例如torch.nn.quantized.functional.linear
:
torch.nn.quantized.functional.linear(qx[None,...], qy.T)
推荐阅读
- react-native - React Native 多行 Toast 消息
- amazon-web-services - 为什么 AWS ALB 不提供使用静态 IP?
- vba - 文本文件的基本加密
- r - 按用户和特定日期合并时间序列数据
- mysql - 需要 SQL 查询改进/建议
- php - PHP 致命错误:未捕获的错误:无法使用 WP_Error 类型的对象作为数组
- javascript - SVG:使用 javascript 或 jquery 从 svg 中删除元素
- django - 将外键添加到 Django 导入导出
- java - Android studio java.lang.NoSuchMethodError: 没有静态方法 encodeBase64URLSafeString([B)Ljava/lang/String;
- c# - 信号量不增加或减少