numpy - Distributions.jl 不接受我的(非正定)协方差函数(而 numpy 接受)
问题描述
我正在尝试在krasserm之后在 Julia 中实现高斯过程(GP) 。
numpy 的实现就像一个魅力:
import numpy as np
def kernel(X1, X2, l=1.0, sigma_f=1.0):
''' Isotropic squared exponential kernel. Computes a covariance matrix from points in X1 and X2. Args: X1: Array of m points (m x d). X2: Array of n points (n x d). Returns: Covariance matrix (m x n). '''
sqdist = np.sum(X1**2, 1).reshape(-1, 1) + np.sum(X2**2, 1) - 2 * np.dot(X1, X2.T)
return sigma_f**2 * np.exp(-0.5 / l**2 * sqdist)
X = np.arange(-1, 1, 0.1).reshape(-1, 1)
mu = np.zeros(X.shape)
cov = kernel(X, X)
samples = np.random.multivariate_normal(mu.ravel(), cov, 3)
虽然我在使用 Distributions.jl 创建多元正态时出错。
这是我尝试过的:
using LinearAlgebra
using Distributions
function kernel(x₁, x₂; l=1.0, σ=1.0)
t₁ = sum(x₁.^2, dims=2)
t₂ = sum(x₂.^2, dims=2)
t₃ = 2*x₁*x₂'
t = (t₁ .+ t₂') - t₃
return σ^2 * exp.(-0.5/l^2 * t)
end
x = -1:0.1:0.9
μ = zeros(length(x))
σ = kernel(x, x)
D = MvNormal(μ, σ)
问题似乎是我的协方差不是正(半)确定的。
朱莉娅:
σ
20×20 Array{Float64,2}:
1.0 0.995012 0.980199 0.955997 … 0.235746 0.197899 0.164474
0.995012 1.0 0.995012 0.980199 0.278037 0.235746 0.197899
0.980199 0.995012 1.0 0.995012 0.324652 0.278037 0.235746
0.955997 0.980199 0.995012 1.0 0.375311 0.324652 0.278037
0.923116 0.955997 0.980199 0.995012 0.429557 0.375311 0.324652
0.882497 0.923116 0.955997 0.980199 … 0.486752 0.429557 0.375311
0.83527 0.882497 0.923116 0.955997 0.546074 0.486752 0.429557
0.782705 0.83527 0.882497 0.923116 0.606531 0.546074 0.486752
0.726149 0.782705 0.83527 0.882497 0.666977 0.606531 0.546074
0.666977 0.726149 0.782705 0.83527 0.726149 0.666977 0.606531
0.606531 0.666977 0.726149 0.782705 … 0.782705 0.726149 0.666977
0.546074 0.606531 0.666977 0.726149 0.83527 0.782705 0.726149
0.486752 0.546074 0.606531 0.666977 0.882497 0.83527 0.782705
0.429557 0.486752 0.546074 0.606531 0.923116 0.882497 0.83527
0.375311 0.429557 0.486752 0.546074 0.955997 0.923116 0.882497
0.324652 0.375311 0.429557 0.486752 … 0.980199 0.955997 0.923116
0.278037 0.324652 0.375311 0.429557 0.995012 0.980199 0.955997
0.235746 0.278037 0.324652 0.375311 1.0 0.995012 0.980199
0.197899 0.235746 0.278037 0.324652 0.995012 1.0 0.995012
0.164474 0.197899 0.235746 0.278037 0.980199 0.995012 1.0
Python:
>>> cov
array([[1. , 0.99501248, 0.98019867, 0.95599748, 0.92311635,
0.8824969 , 0.83527021, 0.78270454, 0.72614904, 0.66697681,
0.60653066, 0.54607443, 0.48675226, 0.42955736, 0.3753111 ,
0.32465247, 0.2780373 , 0.23574608, 0.1978987 , 0.16447446],
[0.99501248, 1. , 0.99501248, 0.98019867, 0.95599748,
0.92311635, 0.8824969 , 0.83527021, 0.78270454, 0.72614904,
0.66697681, 0.60653066, 0.54607443, 0.48675226, 0.42955736,
0.3753111 , 0.32465247, 0.2780373 , 0.23574608, 0.1978987 ],
[0.98019867, 0.99501248, 1. , 0.99501248, 0.98019867,
0.95599748, 0.92311635, 0.8824969 , 0.83527021, 0.78270454,
0.72614904, 0.66697681, 0.60653066, 0.54607443, 0.48675226,
0.42955736, 0.3753111 , 0.32465247, 0.2780373 , 0.23574608],
[0.95599748, 0.98019867, 0.99501248, 1. , 0.99501248,
0.98019867, 0.95599748, 0.92311635, 0.8824969 , 0.83527021,
0.78270454, 0.72614904, 0.66697681, 0.60653066, 0.54607443,
0.48675226, 0.42955736, 0.3753111 , 0.32465247, 0.2780373 ],
[0.92311635, 0.95599748, 0.98019867, 0.99501248, 1. ,
0.99501248, 0.98019867, 0.95599748, 0.92311635, 0.8824969 ,
0.83527021, 0.78270454, 0.72614904, 0.66697681, 0.60653066,
0.54607443, 0.48675226, 0.42955736, 0.3753111 , 0.32465247],
[0.8824969 , 0.92311635, 0.95599748, 0.98019867, 0.99501248,
1. , 0.99501248, 0.98019867, 0.95599748, 0.92311635,
0.8824969 , 0.83527021, 0.78270454, 0.72614904, 0.66697681,
0.60653066, 0.54607443, 0.48675226, 0.42955736, 0.3753111 ],
[0.83527021, 0.8824969 , 0.92311635, 0.95599748, 0.98019867,
0.99501248, 1. , 0.99501248, 0.98019867, 0.95599748,
0.92311635, 0.8824969 , 0.83527021, 0.78270454, 0.72614904,
0.66697681, 0.60653066, 0.54607443, 0.48675226, 0.42955736],
[0.78270454, 0.83527021, 0.8824969 , 0.92311635, 0.95599748,
0.98019867, 0.99501248, 1. , 0.99501248, 0.98019867,
0.95599748, 0.92311635, 0.8824969 , 0.83527021, 0.78270454,
0.72614904, 0.66697681, 0.60653066, 0.54607443, 0.48675226],
[0.72614904, 0.78270454, 0.83527021, 0.8824969 , 0.92311635,
0.95599748, 0.98019867, 0.99501248, 1. , 0.99501248,
0.98019867, 0.95599748, 0.92311635, 0.8824969 , 0.83527021,
0.78270454, 0.72614904, 0.66697681, 0.60653066, 0.54607443],
[0.66697681, 0.72614904, 0.78270454, 0.83527021, 0.8824969 ,
0.92311635, 0.95599748, 0.98019867, 0.99501248, 1. ,
0.99501248, 0.98019867, 0.95599748, 0.92311635, 0.8824969 ,
0.83527021, 0.78270454, 0.72614904, 0.66697681, 0.60653066],
[0.60653066, 0.66697681, 0.72614904, 0.78270454, 0.83527021,
0.8824969 , 0.92311635, 0.95599748, 0.98019867, 0.99501248,
1. , 0.99501248, 0.98019867, 0.95599748, 0.92311635,
0.8824969 , 0.83527021, 0.78270454, 0.72614904, 0.66697681],
[0.54607443, 0.60653066, 0.66697681, 0.72614904, 0.78270454,
0.83527021, 0.8824969 , 0.92311635, 0.95599748, 0.98019867,
0.99501248, 1. , 0.99501248, 0.98019867, 0.95599748,
0.92311635, 0.8824969 , 0.83527021, 0.78270454, 0.72614904],
[0.48675226, 0.54607443, 0.60653066, 0.66697681, 0.72614904,
0.78270454, 0.83527021, 0.8824969 , 0.92311635, 0.95599748,
0.98019867, 0.99501248, 1. , 0.99501248, 0.98019867,
0.95599748, 0.92311635, 0.8824969 , 0.83527021, 0.78270454],
[0.42955736, 0.48675226, 0.54607443, 0.60653066, 0.66697681,
0.72614904, 0.78270454, 0.83527021, 0.8824969 , 0.92311635,
0.95599748, 0.98019867, 0.99501248, 1. , 0.99501248,
0.98019867, 0.95599748, 0.92311635, 0.8824969 , 0.83527021],
[0.3753111 , 0.42955736, 0.48675226, 0.54607443, 0.60653066,
0.66697681, 0.72614904, 0.78270454, 0.83527021, 0.8824969 ,
0.92311635, 0.95599748, 0.98019867, 0.99501248, 1. ,
0.99501248, 0.98019867, 0.95599748, 0.92311635, 0.8824969 ],
[0.32465247, 0.3753111 , 0.42955736, 0.48675226, 0.54607443,
0.60653066, 0.66697681, 0.72614904, 0.78270454, 0.83527021,
0.8824969 , 0.92311635, 0.95599748, 0.98019867, 0.99501248,
1. , 0.99501248, 0.98019867, 0.95599748, 0.92311635],
[0.2780373 , 0.32465247, 0.3753111 , 0.42955736, 0.48675226,
0.54607443, 0.60653066, 0.66697681, 0.72614904, 0.78270454,
0.83527021, 0.8824969 , 0.92311635, 0.95599748, 0.98019867,
0.99501248, 1. , 0.99501248, 0.98019867, 0.95599748],
[0.23574608, 0.2780373 , 0.32465247, 0.3753111 , 0.42955736,
0.48675226, 0.54607443, 0.60653066, 0.66697681, 0.72614904,
0.78270454, 0.83527021, 0.8824969 , 0.92311635, 0.95599748,
0.98019867, 0.99501248, 1. , 0.99501248, 0.98019867],
[0.1978987 , 0.23574608, 0.2780373 , 0.32465247, 0.3753111 ,
0.42955736, 0.48675226, 0.54607443, 0.60653066, 0.66697681,
0.72614904, 0.78270454, 0.83527021, 0.8824969 , 0.92311635,
0.95599748, 0.98019867, 0.99501248, 1. , 0.99501248],
[0.16447446, 0.1978987 , 0.23574608, 0.2780373 , 0.32465247,
0.3753111 , 0.42955736, 0.48675226, 0.54607443, 0.60653066,
0.66697681, 0.72614904, 0.78270454, 0.83527021, 0.8824969 ,
0.92311635, 0.95599748, 0.98019867, 0.99501248, 1. ]])
Numpy 似乎毫无怨言地接受了它,但 MvNormal 说:
PosDefException: matrix is not Hermitian; Cholesky factorization failed.
或者
PosDefException: matrix is not positive definite; Cholesky factorization failed.
似乎它可能是浮点精度的问题,我尝试使用sol2:
σ = σ + maximum([0.0, -minimum(eigvals(σ))])*I
D = MvNormal(μ, σ)
这应该使矩阵正定,但没有成功。
σ = Symmetric(σ)
D = MvNormal(σ.data)
我还尝试在sol3中在需要时使其成为 Hermitian:
D = MvNormal(Matrix(Hermitian(σ)))
您对我应该做什么有任何见解吗?
经过进一步测试,似乎
σ = σ + 0.00000000001*I
D = MvNormal(σ)
工作,但我发现这个解决方案真的很糟糕,为什么它甚至工作?
解决方案
推荐阅读
- javascript - 我们如何在 React JS 中将多个 JSON 文件的值呈现为表格格式?
- javascript - 从 Twitter API V1.1 切换到 Twitter API V2 以将推文从个人资料发送到 Google 表格
- c# - C# 引用列表中的特定项目
- python - 并排直方图:更改轴
- selenium - 是否可以使用 appium 扫描二维码?
- reactjs - NextJS + SSR + IIS 部署
- sql - 如何删除或重命名数据库?
- design-patterns - 应该如何为两个以上的层次结构实现“桥”设计模式?
- angular - 在数组中取一个随机数,每次浏览器重新加载它都会改变
- laravel-8 - 注销并再次登录后出现 419 错误