python - 如何用其他 pytorch 函数替换 torch.sparse?
问题描述
我想使用其他 Pytorch 函数替换 torch.sparse 函数。
i = torch.LongTensor([[0, 1, 1], [2, 0, 2]])
v = torch.FloatTensor([3, 4, 5])
out1 = torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense()
out2 = ???
out1 == out2
> tensor(True)
背景:
我正在将 Pytorch 模型转换为 CoreML,但在 torch.norm 函数中定义的 sparse_coo_tensor 运算符未使用 CoreMLTools 实现。
少数人会遇到这个问题,但 CoreMLTools 仍然不受支持。所以我想在不使用torch.sparse.FloatTensor 的运算符的情况下替换它。
我试过torch.sparse_coo_tensor
了,但不支持。
我创建了一个简单的协作笔记本来重现这一点。请使用以下 colab 对其进行测试。 https://colab.research.google.com/drive/1TzpeJqEcmCy4IuXhhl6LChFocfZVaR1Q?usp=sharing
我之前在stackoverflow上询问过不同的运算符,所以请参考。 如何用其他 pytorch 函数替换 torch.norm?
解决方案
如果我正确理解 sparse_coo 格式,则 的两行i
只是复制v
. 这意味着您可以改为创建矩阵,例如:
def dense_from_coo(i, v):
rows = i[0].max()+1
cols = i[1].max()+1
out = torch.zeros(rows, cols)
out[i[0], i[1]] = v
return out
print(dense_from_coo(i,v))
>>> tensor([[0., 0., 3.],
[4., 0., 5.]])
print(torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense())
>>> tensor([[0., 0., 3.],
[4., 0., 5.]])
推荐阅读
- java - org.springframework.web.client.ResourceAccessException:“某些 URL”的 GET 请求上的 I/O 错误:收到致命警报:bad_certificate
- python - 多维索引
- python - python - 并行写入数据的单独线程使我的代码变慢 - 但为什么呢?
- python - 在 Aplhabets 上定义轮廓
- android - 将地图转换为对象
- python - 标签编码 n 维分类值
- groovy - 在 Groovy 中循环遍历没有迭代器的对象
- android - LifecycleObserver 使用使用较新 API 的方法产生异常
- java - com.rengwuxian.materialedittext.MaterialEditText 无法转换为 android.view.ViewGroup
- batch-file - 尝试在 Windows 10 中使用 gnu awk 来拆分大型序列文件