首页 > 解决方案 > Matplotlib imshow() 给出一个水平翻转的密度图

问题描述

我对绘制双变量正态的混合物很感兴趣。

def func(x):
    cat = tfd.Categorical(probs=np.array([.5, .5],dtype=NP_DTYPE))
    comps = [tfd.MultivariateNormalDiag(loc=np.array([-5.0, -5.0],dtype=NP_DTYPE), scale_diag=tf.ones(2,dtype=DTYPE)*.1),
         tfd.MultivariateNormalDiag(loc=np.array([5.0, 5.0],dtype=NP_DTYPE), scale_diag=tf.ones(2,dtype=DTYPE)*.1)]

    mix = tfd.Mixture(cat=cat, components=comps)
    return mix.prob(x)

这是两个二元正态的混合。一个中心在 [5,5],另一个中心在 [-5,-5];两者都有对角线协方差矩阵,沿对角线为 0.1。每个具有相同的混合重量。

我的绘图代码是这样的

# make these smaller to increase the resolution
dx, dy = 0.1, 0.1

x = np.arange(-10.0, 10.0, dx)
y = np.arange(-10.0, 10.0, dy)

X, Y = np.meshgrid(x, y)
Z = np.concatenate((X.reshape(-1,1),Y.reshape(-1,1)),axis=1)


extent = np.min(x), np.max(x), np.min(y), np.max(y)
fig = plt.figure(frameon=True)



Z2 = tf.log(func(Z) + 1e-6)
Z2 = sess.run(Z2)
Z2 = Z2.reshape(int(np.sqrt(Z2.shape[0])),int(np.sqrt(Z2.shape[0])))

im2 = plt.imshow(Z2, cmap=plt.cm.viridis, alpha=.9, interpolation='bilinear',
                 extent=extent)
plt.colorbar()
plt.show()

我扁平化网格的原因是我想实现一个通用的密度绘图,以便它可以绘制任何复杂的二维分布。(任意二维分布,输入形状为 [N,D];N 是点数,D 是每个点的维度)

然而,这给出了一个奇怪的情节

在此处输入图像描述

由于高温区域应在 [5,5] 和 [-5,-5] 附近,因此水平翻转了图

任何帮助解决这个问题?(因为 imshow() 是一种黑盒,我需要密度函数采用特定形式的输入;我不知道如何解决这个问题)

标签: pythonmatplotlib

解决方案


添加origin="lower"到您的 imshow 图的论点。

默认值为"upper",这对于绘制图像很有意义,因为它们通常从左上角开始。但是,对于基于矩阵的绘图,您通常需要origin="lower",否则您的绘图会被翻转。

im2 = plt.imshow(Z2, cmap=plt.cm.viridis, alpha=.9, interpolation='bilinear',
                 extent=extent,origin='lower')
plt.colorbar()
plt.show()

在此处输入图像描述


推荐阅读