首页 > 解决方案 > 如何使用空切片的默认值计算 torch.mean?

问题描述

什么是正确的计算方法torch.mean,以便在空切片的情况下返回一些默认值?

import torch

a = torch.arange(5, dtype=torch.float)
mask = torch.ones(5, dtype=torch.bool)
mask[2] = 0
mask_empty = torch.zeros(5, dtype=torch.bool)

a_masked = a[mask]
a_empty = a[mask_empty]
m = torch.mean(a_masked)
m_empty = torch.mean(a_empty)

print(m)
print(m_empty)

电流输出:

tensor(2.)
tensor(nan)

所需输出:

tensor(2.)
tensor(0.)

我知道我能做到

m_empty = torch.zeros_like(m_empty) if torch.isnan(m_empty) else m_empty

但这似乎是错误的方法,因为它会强制进行更多的 cpu-gpu 通信,如果可能的话,我想避免这种通信。
此外,这使得代码不像 m_empty = torch.mean(a_empty, default=0)可用的东西那样干净。


是否有一种干净的 Pythonic 方式来实现torch.mean[或其他标准函数] 的空切片的默认值?

标签: pythonpytorch

解决方案


这可能不是最好的方法,也没有你想要的那么干净,但是,实现这一点的一种方法是在计算平均值之前检查张量的长度,

import torch

a = torch.arange(5, dtype=torch.float)
mask = torch.ones(5, dtype=torch.bool)
mask[2] = 0
mask_empty = torch.zeros(5, dtype=torch.bool)
a_masked = a[mask]
a_empty = a[mask_empty]
m = torch.mean(a_masked) if len(a_masked) > 0 else torch.tensor(0.)
m_empty = torch.mean(a_empty) if len(a_empty) > 0 else torch.tensor(0.)

print(m)
print(m_empty)

输出,

tensor(2.)
tensor(0.)

推荐阅读