首页 > 解决方案 > 绘制二维高斯的椭圆轮廓

问题描述

假设我有一个带有 pdf 的 2D 高斯

在此处输入图像描述

我想绘制一个对应于水平集(轮廓)的椭圆

在此处输入图像描述

这里我知道我可以用它的特征分解替换精度矩阵来获得 在此处输入图像描述 伽马在哪里 在此处输入图像描述

然后要找到椭圆上点的坐标,我必须这样做

在此处输入图像描述

我尝试绘制此图,但它不起作用。

绘制等高线

from scipy.stats import multivariate_normal
import numpy as np
from numpy.linalg import eigh
import math
import matplotlib.pyplot as plt

# Target distribution
sx2 = 1.0
sy2 = 2.0
rho = 0.6
Sigma = np.array([[sx2, rho*math.sqrt(sx2)*math.sqrt(sy2)], [rho*math.sqrt(sx2)*math.sqrt(sy2), sy2]])
target = multivariate_normal(mean=np.zeros(2), cov=Sigma)
# Two different contours
xy = target.rvs()
xy2 = target.rvs() 
# Values where to plot the density
x, y = np.mgrid[-2:2:0.1, -2:2:0.1]
zz = target.pdf(np.dstack((x, y)))
fig, ax = plt.subplots()
ax.contour(x,y, zz, levels=np.sort([target.pdf(xy), target.pdf(xy2)]))
ax.set_aspect("equal")
plt.show()

上面的代码显示了轮廓 在此处输入图像描述

绘制椭圆

# Find gamma and perform eigendecomposition
gamma = math.log(1 / (4*(np.pi**2)*sx2*sy2*(1 - rho**2)*(target.pdf(xy)**2)))
eigenvalues, P = eigh(np.linalg.inv(Sigma))
# Compute u and v as per link using thetas from 0 to 2pi
thetas = np.linspace(0, 2*np.pi, 100)
uv = (gamma / np.sqrt(eigenvalues)) * np.hstack((np.cos(thetas).reshape(-1,1), np.sin(thetas).reshape(-1, 1)))
# Plot
plt.scatter(uv[:, 0], uv[:, 1])

然而,这显然是行不通的。

在此处输入图像描述

标签: pythonstatisticsgaussiannormal-distribution

解决方案


  1. 您应该在 gamma 中平方 sx2 和 sy2。

  2. 伽玛应该是平方根。

  3. 将得到的椭圆乘以 P^-1 以获得原始坐标系中的点。链接的帖子中提到了这一点。您必须转换回原始坐标系。我实际上不知道如何编码,或者它是否真的有效,所以我把编码留给你。

gamma = math.log(1 / (4*(np.pi**2)*(sx2**2)*(sy2**2)*(1 - rho**2)*(target.pdf(xy)**2)))
eigenvalues, P = eigh(np.linalg.inv(Sigma))
# Compute u and v as per link using thetas from 0 to 2pi
thetas = np.linspace(0, 2*np.pi, 100)
uv = (np.sqrt(gamma) / np.sqrt(eigenvalues)) * np.hstack((np.cos(thetas).reshape(-1,1), np.sin(thetas).reshape(-1, 1)))


orig_coord=np.linalg.inv(P) * uv #I don't how to code this in python

plt.scatter(orig_coord[:,0], orig_coord[:,1])
plt.show()

我尝试对其进行编码:

gamma = math.log(1 / (4*(np.pi**2)*(sx2**2)*(sy2**2)*(1 - rho**2)*(target.pdf(xy)**2)))
eigenvalues, P = eigh(np.linalg.inv(Sigma))
# Compute u and v as per link using thetas from 0 to 2pi
thetas = np.linspace(0, 2*np.pi, 100)
uv = (np.sqrt(gamma) / np.sqrt(eigenvalues)) * np.hstack((np.cos(thetas).reshape(-1,1), np.sin(thetas).reshape(-1, 1)))


orig_coord=np.zeros((100,2))
for i in range(len(uv)):
    orig_coord[i,0]=np.matmul(np.linalg.inv(P), uv[i,:])[0]
    orig_coord[i,1]=np.matmul(np.linalg.inv(P), uv[i,:])[1]

# Plot
plt.scatter(orig_coord[:, 0], orig_coord[:, 1])

gamma1 = math.log(1 / (4*(np.pi**2)*(sx2**2)*(sy2**2)*(1 - rho**2)*(target.pdf(xy2)**2)))
uv1 = (np.sqrt(gamma1) / np.sqrt(eigenvalues)) * np.hstack((np.cos(thetas).reshape(-1,1), np.sin(thetas).reshape(-1, 1)))
orig_coord1=np.zeros((100,2))
for i in range(len(uv)):
    orig_coord1[i,0]=np.matmul(np.linalg.inv(P), uv1[i,:])[0]
    orig_coord1[i,1]=np.matmul(np.linalg.inv(P), uv1[i,:])[1]
plt.scatter(orig_coord1[:, 0], orig_coord1[:, 1])

plt.axis([-2,2,-2,2])

plt.show()

有时绘图不起作用,您会收到错误无效 sqrt,但是当它起作用时,它看起来很好。 在此处输入图像描述


推荐阅读