首页 > 解决方案 > 如何找到二维激活图(pytorch)的均值和协方差

问题描述

我有一个形状为 [h, w] 的张量,它由一个标准化的二维激活图组成。考虑到这是某种分布,我想在 pytorch 中找到这个激活图中的均值和协方差。有没有一种有效的方法来做到这一点?

标签: pytorchmeandistributioncovariance

解决方案


您可以使用以下代码,其中activation_map是 shape 的张量(h,w),具有非负元素,并且已归一化(activation_map.sum()为 1):

activation_map = torch.tensor(
    [[0.2, 0.1, 0.0],
     [0.1, 0.2, 0.4]])
h, w = activation_map.shape

range_h = torch.arange(h)
range_w = torch.arange(w)
idxs = torch.stack([
  range_w[None].repeat(h, 1),
  range_h[:, None].repeat(1, w)
  ])
map_flat = activation_map.view(-1)
idxs_flat = idxs.reshape(2, -1).T
mean = (map_flat[:, None] * idxs_flat).sum(0)
mats = idxs_flat[:, :, None] @ idxs_flat[:, None, :]
second_moments = (map_flat[:, None, None] * mats).sum(0)
covariance = second_moments - mean[:, None] @ mean[None]

# mean:
# tensor([1.1000, 0.7000])
# covariance:
# tensor([[0.6900, 0.2300],
#         [0.2300, 0.2100]])

推荐阅读