python - 计算火炬张量的协方差(二维特征图)
问题描述
我有一个形状为(batch_size、number_maps、x_val、y_val)的火炬张量。张量使用 sigmoid 函数进行归一化,因此在 range 内[0, 1]
。我想找到每个地图的协方差,所以我想要一个形状为 (batch_size, number_maps, 2, 2) 的张量。据我所知,没有torch.cov()
像 numpy 那样的功能。如何有效地计算协方差而不将其转换为 numpy?
编辑:
def get_covariance(tensor):
bn, nk, w, h = tensor.shape
tensor_reshape = tensor.reshape(bn, nk, 2, -1)
x = tensor_reshape[:, :, 0, :]
y = tensor_reshape[:, :, 1, :]
mean_x = torch.mean(x, dim=2).unsqueeze(-1)
mean_y = torch.mean(y, dim=2).unsqueeze(-1)
xx = torch.sum((x - mean_x) * (x - mean_x), dim=2).unsqueeze(-1) / (h*w - 1)
xy = torch.sum((x - mean_x) * (y - mean_y), dim=2).unsqueeze(-1) / (h*w - 1)
yx = xy
yy = torch.sum((y - mean_y) * (y - mean_y), dim=2).unsqueeze(-1) / (h*w - 1)
cov = torch.cat((xx, xy, yx, yy), dim=2)
cov = cov.reshape(bn, nk, 2, 2)
return cov
我现在尝试了以下方法,但我很确定它不正确。
解决方案
你可以试试 Github 上推荐的功能:
def cov(x, rowvar=False, bias=False, ddof=None, aweights=None):
"""Estimates covariance matrix like numpy.cov"""
# ensure at least 2D
if x.dim() == 1:
x = x.view(-1, 1)
# treat each column as a data point, each row as a variable
if rowvar and x.shape[0] != 1:
x = x.t()
if ddof is None:
if bias == 0:
ddof = 1
else:
ddof = 0
w = aweights
if w is not None:
if not torch.is_tensor(w):
w = torch.tensor(w, dtype=torch.float)
w_sum = torch.sum(w)
avg = torch.sum(x * (w/w_sum)[:,None], 0)
else:
avg = torch.mean(x, 0)
# Determine the normalization
if w is None:
fact = x.shape[0] - ddof
elif ddof == 0:
fact = w_sum
elif aweights is None:
fact = w_sum - ddof
else:
fact = w_sum - ddof * torch.sum(w * w) / w_sum
xm = x.sub(avg.expand_as(x))
if w is None:
X_T = xm.t()
else:
X_T = torch.mm(torch.diag(w), xm).t()
c = torch.mm(X_T, xm)
c = c / fact
return c.squeeze()
推荐阅读
- c# - Kaizala API:向成员发送消息
- html - HTML数字类型输入递增/递减按钮不断增加数字
- swift - 在 SwiftUI 中使用 EnvironmentObject 在视图之间传递数据时遇到问题
- r - 检查多个变量是否存在相同的值
- javascript - 从具有相同属性名称的对象反应填充对象
- java - Spring Boot - 如何从多个自定义 yml 中读取属性
- terraform - 如何使用terraform根据标签获取所有iam_users
- stanza.io - Stanzjs 如何接收消息
- android - android ftrace:current_tracer 无法写入,但我有可用的跟踪器
- firebase - 如何使用 flutter_geofire 将实时位置发送到 Cloud Firestore?