python - 在pytorch中将3维蒙版应用于RGB的矢量化方法
问题描述
我有一个表示 RGB 图像的 HxWx3 张量和一个 HxWx3 掩码(布尔)张量作为输入。假设对于掩码张量中的每个 (i,j),都只有一个真值(即 R\G\B 中的一个恰好打开)。我想将掩码应用于图像以产生 HxW(或 HxWx1)张量 V,其中 V[i,j]='根据掩码匹配的 R\G\B 值'。
使用问题将二进制掩码应用于带有 numpy 的 RGB 图像,我能够实现以下目标:
>>> X*mask
tensor([[[ 9., 10.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 20.]],
[[ 0., 0.],
[30., 0.]]])
但如前所述,我想要一个昏暗的 HxW 而不是 HxWx3 作为结果。
解决方案
假设对于每个 i,j 只保留一个 R/G/B 值,您可以简单地执行以下操作:
(X*mask).sum(axis=2)
这应该会给你你想要的(HxW)输出。