pytorch - 如何找到二维激活图(pytorch)的均值和协方差
问题描述
我有一个形状为 [h, w] 的张量,它由一个标准化的二维激活图组成。考虑到这是某种分布,我想在 pytorch 中找到这个激活图中的均值和协方差。有没有一种有效的方法来做到这一点?
解决方案
您可以使用以下代码,其中activation_map
是 shape 的张量(h,w)
,具有非负元素,并且已归一化(activation_map.sum()
为 1):
activation_map = torch.tensor(
[[0.2, 0.1, 0.0],
[0.1, 0.2, 0.4]])
h, w = activation_map.shape
range_h = torch.arange(h)
range_w = torch.arange(w)
idxs = torch.stack([
range_w[None].repeat(h, 1),
range_h[:, None].repeat(1, w)
])
map_flat = activation_map.view(-1)
idxs_flat = idxs.reshape(2, -1).T
mean = (map_flat[:, None] * idxs_flat).sum(0)
mats = idxs_flat[:, :, None] @ idxs_flat[:, None, :]
second_moments = (map_flat[:, None, None] * mats).sum(0)
covariance = second_moments - mean[:, None] @ mean[None]
# mean:
# tensor([1.1000, 0.7000])
# covariance:
# tensor([[0.6900, 0.2300],
# [0.2300, 0.2100]])
推荐阅读
- javascript - ES6 to ES3 translation for MediaWiki (foreach() method and spreading operator)
- swift - How to init with a generic NSFetchedResultsController?
- android - image asset problem (Background Problem) - what should I do?
- go - 获取 Gin 中 POST HTML 表单的所有内容
- android - 获取用户位置并将其显示在地图上(未显示给我的地图)
- ruby-on-rails - Rails 6: How do I remove a record, associated through a join table without removing it from the associated table?
- css - CSS - Chrome 和 Safari 的尺寸不同
- python - Comparing fluorescence intensity of finger print residue after 5 contacts
- javascript - Firebase 云功能:选择要上传的额外文件
- json - 从 Scala 中的 URL 获取 json