python - 如何使用空切片的默认值计算 torch.mean?
问题描述
什么是正确的计算方法torch.mean
,以便在空切片的情况下返回一些默认值?
import torch
a = torch.arange(5, dtype=torch.float)
mask = torch.ones(5, dtype=torch.bool)
mask[2] = 0
mask_empty = torch.zeros(5, dtype=torch.bool)
a_masked = a[mask]
a_empty = a[mask_empty]
m = torch.mean(a_masked)
m_empty = torch.mean(a_empty)
print(m)
print(m_empty)
电流输出:
tensor(2.)
tensor(nan)
所需输出:
tensor(2.)
tensor(0.)
我知道我能做到
m_empty = torch.zeros_like(m_empty) if torch.isnan(m_empty) else m_empty
但这似乎是错误的方法,因为它会强制进行更多的 cpu-gpu 通信,如果可能的话,我想避免这种通信。
此外,这使得代码不像
m_empty = torch.mean(a_empty, default=0)
可用的东西那样干净。
是否有一种干净的 Pythonic 方式来实现torch.mean
[或其他标准函数] 的空切片的默认值?
解决方案
这可能不是最好的方法,也没有你想要的那么干净,但是,实现这一点的一种方法是在计算平均值之前检查张量的长度,
import torch
a = torch.arange(5, dtype=torch.float)
mask = torch.ones(5, dtype=torch.bool)
mask[2] = 0
mask_empty = torch.zeros(5, dtype=torch.bool)
a_masked = a[mask]
a_empty = a[mask_empty]
m = torch.mean(a_masked) if len(a_masked) > 0 else torch.tensor(0.)
m_empty = torch.mean(a_empty) if len(a_empty) > 0 else torch.tensor(0.)
print(m)
print(m_empty)
输出,
tensor(2.)
tensor(0.)
推荐阅读
- java - 如何在sqlite数据库java中存储对象
- java - 如何对除 api 网关之外的所有人隐藏我的公共微服务?
- python-3.x - 尝试通过 kudu 控制台在 python azure 函数中添加模块(speech_py_impl),但面临虚拟环境问题
- c# - CommandBinding 使用自定义 ICommand 实现时未调用已执行的事件处理程序
- python - 在 Python 中使用 OpenCV 确定图像中所有圆形(重叠)斑点的半径
- python - 非均匀间距,带有 numpy.gradient 的多元导数
- jwt - Zoom API:帐户未启用 REST API
- maps - HERE 有 Parcel 边界数据吗?
- java - 如何在 Spring 中将 Redis Http 会话超时设置为无限制
- javascript - 将 python 脚本与 html 的 java 脚本混合