首页 > 解决方案 > 使用 matplotlib 叠加多个颜色图

问题描述

我总共有 16 张热力图,每张都用不同的颜色映射(colormap)绘制,如下所示:

[图片:颜色图示例 1]

[图片:颜色图示例 2]

有没有办法把所有热力图叠加到同一张图上,同时保留各自的颜色?也就是说,我想得到一张由 16 个不同颜色的分布组成的最终图像。我搜索了很久,但遗憾的是还没有找到合适的方法。非常感谢!

用于复现问题的代码如下:

import torch
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors


def softmax(logit_map):
    """Apply a spatial softmax independently to every channel of a 4-D map.

    :param logit_map: tensor of shape [bn, kn, h, w]
    :return: tensor of shape [bn, kn, h, w] in which, for each (batch, channel)
        pair, the h * w entries are a softmax over all spatial positions
        (non-negative and summing to 1)
    """
    batch, parts, height, width = logit_map.shape
    # Merge the two spatial axes so the softmax normalises over every pixel.
    flat_logits = logit_map.reshape(batch, parts, -1)
    flat_probs = F.softmax(flat_logits, dim=-1)
    return flat_probs.reshape(batch, parts, height, width)

def get_mu_and_prec(part_maps, device, scal):
    """
    Calculate the mean and a scaled inverse Cholesky factor for each channel of part_maps.

    Each channel is weighted over a regular grid spanning [-1, 1] in both
    directions (first coordinate: rows / y, second: columns / x). The
    resulting covariance Sigma is factored as Sigma = L L^T with
    L = [[a, 0], [b, c]], and the function returns scal * L^-1, where
    L^-1 = 1 / (a c) * [[c, 0], [-b, a]].

    :param part_maps: tensor of part map activations [bn, n_part, h, w];
        assumes each channel is non-negative and sums to 1 (e.g. softmax output)
    :param device: device on which the coordinate grid is placed; should match
        part_maps.device
    :param scal: scalar multiplier applied to L^-1
    :return: (mu, L_inv) -- mu has shape [bn, n_part, 2] (mean on a grid of
        scale [-1, 1]); L_inv has shape [bn, n_part, 2, 2]
    """
    bn, nk, h, w = part_maps.shape
    # Build the grid in part_maps' dtype so einsum never sees mixed dtypes
    # (e.g. float64 part maps against a float32 grid).
    dtype = part_maps.dtype if part_maps.is_floating_point() else torch.get_default_dtype()
    y_t = torch.linspace(-1., 1., h, dtype=dtype).reshape(h, 1).repeat(1, w).unsqueeze(-1)
    x_t = torch.linspace(-1., 1., w, dtype=dtype).reshape(1, w).repeat(h, 1).unsqueeze(-1)
    meshgrid = torch.cat((y_t, x_t), dim=-1).to(device)  # h x w x 2

    mu = torch.einsum('ijl, akij -> akl', meshgrid, part_maps)  # bn x nk x 2
    mu_out_prod = torch.einsum('akm,akn->akmn', mu, mu)

    # Covariance: sum_p w(p) * p p^T - mu mu^T, shape bn x nk x 2 x 2.
    mesh_out_prod = torch.einsum('ijm,ijn->ijmn', meshgrid, meshgrid)
    cov = torch.einsum('ijmn,akij->akmn', mesh_out_prod, part_maps) - mu_out_prod

    a_sq = cov[:, :, 0, 0]
    a_b = cov[:, :, 0, 1]
    b_sq_add_c_sq = cov[:, :, 1, 1]
    eps = 1e-12

    # Sigma = L L^T, Prec = Sigma^-1 = L^T^-1 L^-1; we need L^-1 with L = [[a, 0], [b, c]].
    # The variance and the Schur complement below should be non-negative for a
    # valid covariance, but the E[x^2] - E[x]^2 form can cancel to a slightly
    # negative value in floating point for very peaked maps; clamp so sqrt
    # never returns NaN.
    a = torch.sqrt(a_sq.clamp(min=0) + eps)
    b = a_b / (a + eps)
    c = torch.sqrt((b_sq_add_c_sq - b ** 2).clamp(min=0) + eps)
    z = torch.zeros_like(a)

    det = (a * c).unsqueeze(-1).unsqueeze(-1)
    row_1 = torch.cat((c.unsqueeze(-1), z.unsqueeze(-1)), dim=-1).unsqueeze(-2)
    row_2 = torch.cat((-b.unsqueeze(-1), a.unsqueeze(-1)), dim=-1).unsqueeze(-2)
    # L^-1 = 1 / (a c) * [[c, 0], [-b, a]], scaled by `scal`.
    L_inv = scal / (det + eps) * torch.cat((row_1, row_2), dim=-2)
    return mu, L_inv

def get_heat_map(mu, L_inv, device, h=64, w=64):
    """
    Render one heat map per part from its mean and scaled inverse Cholesky factor.

    For every grid point p in [-1, 1] x [-1, 1] the value is
    1 / (1 + ||L_inv (p - mu)||^2), so it equals 1 at p == mu and decays with
    the distance measured through L_inv.

    :param mu: part means, shape [bn, nk, 2] (first coordinate y / rows,
        second x / columns)
    :param L_inv: per-part 2x2 matrices, shape [bn, nk, 2, 2]
    :param device: device on which the coordinate grid is placed
    :param h: output height in pixels (default 64, the previously fixed size)
    :param w: output width in pixels (default 64)
    :return: heat maps of shape [bn, nk, h, w]
    """
    nk = L_inv.shape[1]
    dtype = mu.dtype if mu.is_floating_point() else torch.get_default_dtype()

    y_t = torch.linspace(-1., 1., h, dtype=dtype).reshape(h, 1).repeat(1, w).unsqueeze(-1)
    x_t = torch.linspace(-1., 1., w, dtype=dtype).reshape(1, w).repeat(h, 1).unsqueeze(-1)

    y_t_flat = y_t.reshape(1, 1, 1, -1)
    x_t_flat = x_t.reshape(1, 1, 1, -1)

    mesh = torch.cat((y_t_flat, x_t_flat), dim=-2).to(device)  # 1 x 1 x 2 x (h*w)
    dist = mesh - mu.unsqueeze(-1)  # bn x nk x 2 x (h*w)
    # Equivalent of tf.matmul(precision, dist) ** 2, batched over (bn, nk).
    proj_precision = torch.einsum('bnik, bnkf -> bnif', L_inv, dist) ** 2
    proj_precision = torch.sum(proj_precision, -2)  # sum over the y and x components
    heat = 1 / (1 + proj_precision)
    heat = heat.reshape(-1, nk, h, w)  # bn x nk x h x w
    return heat

# One colour per part. The number of parts is derived from this list so the
# channel count, the colours and the plotting loop cannot drift apart.
color_list = ['black', 'gray', 'brown', 'chocolate', 'orange', 'gold', 'olive', 'lawngreen', 'aquamarine',
              'dodgerblue', 'midnightblue', 'mediumpurple', 'indigo', 'magenta', 'pink', 'springgreen']
n_parts = len(color_list)

# Random logits -> per-channel spatial softmax -> mean / L_inv -> heat maps.
fmap = torch.randn(1, n_parts, 64, 64)
fmap_norm = softmax(fmap)
mu, L_inv = get_mu_and_prec(fmap_norm, 'cpu', scal=5.)
heat_map = get_heat_map(mu, L_inv, "cpu")

# Show each part's heat map in its own figure with a white -> part-colour colormap.
for color, part_heat in zip(color_list, heat_map[0]):
    cmap = colors.LinearSegmentedColormap.from_list('my_colormap',
                                                    ['white', color],
                                                    256)
    plt.imshow(part_heat.numpy(), cmap=cmap)
    plt.show()

标签: python, opencv, matplotlib, imshow

解决方案


推荐阅读