python - 如何用其他 pytorch 函数替换 torch.norm?
问题描述
我想torch.norm
使用其他 Pytorch 函数替换该函数。
我能够在不是矩阵torch.norm
的情况下进行替换x
,如下面的代码所示。
import torch
x = torch.randn(9)
out1 = torch.norm(x)
out2 = sum(abs(x)**2)**(1./2)
out1 == out2
>> tensor(True)
但是当 x 是矩阵时,我不知道如何替换它。特别是,我想在我的情况下替换它dim=1 and keepdim=True
。
x = torch.randn([3, 136, 64, 64])
out1 = torch.norm(x, dim=1, keepdim=True)
out2 = ???
out1 == out2
背景:
我正在将 Pytorch 模型转换为 CoreML,但函数中_VF.frobenius_norm
定义的运算符torch.norm
未使用 CoreMLTools 实现。(里面的实现torch.norm
可以在这里找到。)
少数人遇到了这个问题,但 CoreMLTools 仍然不受支持(您可以从这个问题中查看)。所以我想在没有使用的运算符的情况下替换它torch.norm
。
我试过了torch.linalg.norm()
,numpy.linalg.norm
但不支持。
我创建了一个简单的协作笔记本来重现这一点。请使用以下 colab 对其进行测试。 https://colab.research.google.com/drive/11o6rTxHzEgZ_Rc7nFZHd3TvPugybB88h?usp=sharing
解决方案
您可以尝试以下方法:
import torch
x = torch.randn([3, 136, 64, 64])
out1 = torch.norm(x, dim=1, keepdim=True)
out2 = torch.square(x).sum(dim=1, keepdim=True).sqrt()
请注意,由于精度上的小误差,它out1 == out2
不会完全给出。True
您可以检查错误是否按1e-7
for的顺序排列float32
。
在这里,范数是直接使用其数学定义计算的。您可以从 Wolfram MathWorld 中查看此参考以了解更多详细信息。
推荐阅读
- python - 使用 Django Python 下载带有水印的 PDF 格式的动态 html
- http - 是否可以在初始获取请求时升级到 websocket?
- c - 在 atmega328p 的 TIMER1 中使用 ISR 创建 10 秒延迟
- python - 试图在 nic 上加载 xdp(已卸载)
- html - Tailwind css:对齐父div的元素底部
- python - 如何将df的多列从十六进制转换为十进制
- python - NLU Watson API - ApiException:错误:无效请求:内容为空,代码:400
- firebase - 如何在 localhost 上查看从 Firebase auth emulator (web) 发送的确认电子邮件?
- node.js - uri 的值不能从 double 转换为 string
- parsing - 带功能的调车场