python - Python中张量的总变分正则化
问题描述
嗨,我正在尝试为张量或更准确的多通道图像实现总变化函数。我发现对于上面的 Total Variation(如图),有这样的源代码:
def compute_total_variation_loss(img, weight):
tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()
return weight * (tv_h + tv_w)
因为,我是 python 的初学者,所以我不明白索引是如何在图像中引用 i 和 j 的。我还想为 c 添加总变化(除了 i 和 j),但我不知道哪个索引指的是 c。
或者更简洁,如何在 python 中编写以下等式: 在此处输入图像描述
解决方案
此功能假定批处理图像。img
维度的 4 维张量也是如此(B, C, H, W)
(B
是批次中的图像C
数量、颜色通道的数量、H
高度和W
宽度)。
因此,img[0, 1, 2, 3]
是第(2, 3)
一个图像中第二种颜色(RGB 中的绿色)的像素。
在 Python(以及 Numpy 和 PyTorch)中,可以使用符号 选择元素切片i:j
,这意味着i, i + 1, i + 2, ..., j - 1
选择了元素。在您的示例中,:
表示所有元素,1:
表示除第一个之外的所有元素,:-1
表示除最后一个之外的所有元素(负索引向后检索元素)。请参考“在 NumPy 中切片”的教程。
所以img[:,:,1:,:] - img[:,:,:-1,:]
相当于(一批)图像减去它们自己垂直移动一个像素,或者,在你的符号中X(i + 1, j, k) - X(i, j, k)
。然后对张量进行平方 ( .pow(2)
) 并求和 ( .sum()
)。请注意,在这种情况下,总和也超过了批次,因此您收到的是批次的总变化,而不是每个图像的总变化。
推荐阅读
- android - 内联时Kotlin空指针异常但分离时没有
- html - CSS 标题和打字动画
- python - Plotly Dash - 使用 For 循环在选项卡内动态生成多个标题
- python - google.com python 112345678 中的 Python v344 错误
- pytorch - 在自定义数据集上使用 roboflow 对象检测 Yolov4 pytorch 模型时出现值错误
- javascript - forEach 替换 Javascript 中的批量代码的解决方案
- postgresql - 使用 ODBC 和外部 SQL 文件的 Powershell/Postgres
- django - Django 显示图像按钮不适用于 for 循环
- linux - GNU findutils 找不到文件
- python - 在 sqlite3 中选择日期 DD/MM/YYYY