首页 > 解决方案 > 在空间轴上计算 Pytorch 中的方差

问题描述

我正在尝试计算 Pytorch 中的方差,但无法在多个轴上进行。

我在 Tensorflow 中做过类似的事情,但无法在 Pytorch 上做,因为 torch.var 函数将 int 作为维度而不是轴。 下面的代码是通道最后一个代码,我希望 axes=[2,3]

Lambda(lambda x: tf.nn.moments(x, axes=[1, 2]))

例如,如果 input_dims = (5, 10, 25, 25) 那么 output_dims 应该是 (5,10, 1, 1)。

标签: tensorflowpytorchvariance

解决方案


您可以做的一件事是在应用该方法tensor.view()之前将所有要计算方差的维度展平为一个维度:var()

torch.var(x.view(x.shape[0], x.shape[1], 1, -1,), dim=3, keepdim=True)

我曾经keepdim=True保留我们计算方差的维度以获得所需的输出形状。


推荐阅读