首页 > 解决方案 > Python中张量的总变分正则化

问题描述

公式

嗨,我正在尝试为张量或更准确的多通道图像实现总变化函数。我发现对于上面的 Total Variation(如图),有这样的源代码:

def compute_total_variation_loss(img, weight):      
    tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
    tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()    
    return weight * (tv_h + tv_w)

因为,我是 python 的初学者,所以我不明白索引是如何在图像中引用 i 和 j 的。我还想为 c 添加总变化(除了 i 和 j),但我不知道哪个索引指的是 c。

或者更简洁,如何在 python 中编写以下等式: 在此处输入图像描述

标签: pythonimageimage-processingpytorchtensor

解决方案


此功能假定批处理图像。img维度的 4 维张量也是如此(B, C, H, W)B是批次中的图像C数量、颜色通道的数量、H高度和W宽度)。

因此,img[0, 1, 2, 3]是第(2, 3)一个图像中第二种颜色(RGB 中的绿色)的像素。

在 Python(以及 Numpy 和 PyTorch)中,可以使用符号 选择元素切片i:j,这意味着i, i + 1, i + 2, ..., j - 1选择了元素。在您的示例中,:表示所有元素1:表示除第一个之外的所有元素,:-1表示除最后一个之外的所有元素(负索引向后检索元素)。请参考“在 NumPy 中切片”的教程。

所以img[:,:,1:,:] - img[:,:,:-1,:]相当于(一批)图像减去它们自己垂直移动一个像素,或者,在你的符号中X(i + 1, j, k) - X(i, j, k)。然后对张量进行平方 ( .pow(2)) 并求和 ( .sum())。请注意,在这种情况下,总和也超过了批次,因此您收到的是批次的总变化,而不是每个图像的总变化。


推荐阅读