首页 > 解决方案 > Distributions.jl 不接受我的(非正定)协方差函数(而 numpy 接受)

问题描述

我正在参照 krasserm 的教程,尝试在 Julia 中实现高斯过程(GP)。

numpy 的实现运行得非常顺利:

import numpy as np

def kernel(X1, X2, l=1.0, sigma_f=1.0):
    """Isotropic squared exponential kernel.

    Computes a covariance matrix from points in X1 and X2.

    Args:
        X1: Array of m points (m x d).
        X2: Array of n points (n x d).
        l: Length parameter; squared distances are divided by l**2.
        sigma_f: Scale factor; the result is multiplied by sigma_f**2.

    Returns:
        Covariance matrix (m x n).
    """
    # Expand ||a - b||^2 as ||a||^2 + ||b||^2 - 2 a.b for every pair at once.
    row_norms = np.sum(X1**2, 1).reshape(-1, 1)
    col_norms = np.sum(X2**2, 1)
    cross = np.dot(X1, X2.T)
    sqdist = row_norms + col_norms - 2 * cross
    scale = -0.5 / l**2
    return sigma_f**2 * np.exp(scale * sqdist)

# Evaluation grid: values from -1 up to (but excluding) 1 in steps of 0.1,
# arranged as a single column.
X = np.arange(-1, 1, 0.1)[:, np.newaxis]

# Zero mean with the same shape as the grid, and the kernel covariance of the
# grid with itself.
mu = np.zeros_like(X)
cov = kernel(X, X)

# Draw three samples from the multivariate normal defined by mu and cov.
samples = np.random.multivariate_normal(mean=mu.ravel(), cov=cov, size=3)

但是,我在使用 Distributions.jl 创建多元正态分布时出错。

这是我尝试过的:

using LinearAlgebra
using Distributions

"""
    kernel(x₁, x₂; l=1.0, σ=1.0)

Isotropic squared-exponential kernel. Compute the covariance matrix between the
points in `x₁` and the points in `x₂`.

# Arguments
- `x₁`: `m` points, one per row (`m × d` matrix), or a vector of `m` 1-D points.
- `x₂`: `n` points, one per row (`n × d` matrix), or a vector of `n` 1-D points.

# Keywords
- `l=1.0`: length parameter; squared distances are divided by `l^2`.
- `σ=1.0`: scale factor; the result is multiplied by `σ^2`.

# Returns
The `m × n` matrix with entries `σ^2 * exp(-0.5 / l^2 * ‖x₁[i] - x₂[j]‖²)`.
"""
function kernel(x₁, x₂; l=1.0, σ=1.0)
    # Expand ‖a - b‖² as ‖a‖² + ‖b‖² - 2a⋅b for every pair of points at once.
    sq₁ = sum(x₁ .^ 2, dims=2)
    sq₂ = sum(x₂ .^ 2, dims=2)
    cross = x₁ * x₂'
    # Rounding in the expansion can leave a tiny negative value where the true
    # squared distance is (near) zero; clamp so that no entry exceeds σ^2.
    sqdist = max.(sq₁ .+ sq₂' .- 2 .* cross, 0)
    return σ^2 .* exp.((-0.5 / l^2) .* sqdist)
end

# Evaluation grid from -1.0 to 0.9 in steps of 0.1.
x = -1:0.1:0.9
# Zero mean vector with one entry per grid point.
μ = zeros(length(x))
# Covariance of the grid with itself under the kernel's default parameters.
σ = kernel(x, x)
# Build the multivariate normal from the mean and covariance; this is the call
# reported as failing with a PosDefException.
D = MvNormal(μ, σ)

问题似乎是我的协方差矩阵不是正(半)定的。

Julia:

σ
20×20 Array{Float64,2}:
 1.0       0.995012  0.980199  0.955997  …  0.235746  0.197899  0.164474
 0.995012  1.0       0.995012  0.980199     0.278037  0.235746  0.197899
 0.980199  0.995012  1.0       0.995012     0.324652  0.278037  0.235746
 0.955997  0.980199  0.995012  1.0          0.375311  0.324652  0.278037
 0.923116  0.955997  0.980199  0.995012     0.429557  0.375311  0.324652
 0.882497  0.923116  0.955997  0.980199  …  0.486752  0.429557  0.375311
 0.83527   0.882497  0.923116  0.955997     0.546074  0.486752  0.429557
 0.782705  0.83527   0.882497  0.923116     0.606531  0.546074  0.486752
 0.726149  0.782705  0.83527   0.882497     0.666977  0.606531  0.546074
 0.666977  0.726149  0.782705  0.83527      0.726149  0.666977  0.606531
 0.606531  0.666977  0.726149  0.782705  …  0.782705  0.726149  0.666977
 0.546074  0.606531  0.666977  0.726149     0.83527   0.782705  0.726149
 0.486752  0.546074  0.606531  0.666977     0.882497  0.83527   0.782705
 0.429557  0.486752  0.546074  0.606531     0.923116  0.882497  0.83527 
 0.375311  0.429557  0.486752  0.546074     0.955997  0.923116  0.882497
 0.324652  0.375311  0.429557  0.486752  …  0.980199  0.955997  0.923116
 0.278037  0.324652  0.375311  0.429557     0.995012  0.980199  0.955997
 0.235746  0.278037  0.324652  0.375311     1.0       0.995012  0.980199
 0.197899  0.235746  0.278037  0.324652     0.995012  1.0       0.995012
 0.164474  0.197899  0.235746  0.278037     0.980199  0.995012  1.0 

Python:

>>> cov
array([[1.        , 0.99501248, 0.98019867, 0.95599748, 0.92311635,
        0.8824969 , 0.83527021, 0.78270454, 0.72614904, 0.66697681,
        0.60653066, 0.54607443, 0.48675226, 0.42955736, 0.3753111 ,
        0.32465247, 0.2780373 , 0.23574608, 0.1978987 , 0.16447446],
       [0.99501248, 1.        , 0.99501248, 0.98019867, 0.95599748,
        0.92311635, 0.8824969 , 0.83527021, 0.78270454, 0.72614904,
        0.66697681, 0.60653066, 0.54607443, 0.48675226, 0.42955736,
        0.3753111 , 0.32465247, 0.2780373 , 0.23574608, 0.1978987 ],
       [0.98019867, 0.99501248, 1.        , 0.99501248, 0.98019867,
        0.95599748, 0.92311635, 0.8824969 , 0.83527021, 0.78270454,
        0.72614904, 0.66697681, 0.60653066, 0.54607443, 0.48675226,
        0.42955736, 0.3753111 , 0.32465247, 0.2780373 , 0.23574608],
       [0.95599748, 0.98019867, 0.99501248, 1.        , 0.99501248,
        0.98019867, 0.95599748, 0.92311635, 0.8824969 , 0.83527021,
        0.78270454, 0.72614904, 0.66697681, 0.60653066, 0.54607443,
        0.48675226, 0.42955736, 0.3753111 , 0.32465247, 0.2780373 ],
       [0.92311635, 0.95599748, 0.98019867, 0.99501248, 1.        ,
        0.99501248, 0.98019867, 0.95599748, 0.92311635, 0.8824969 ,
        0.83527021, 0.78270454, 0.72614904, 0.66697681, 0.60653066,
        0.54607443, 0.48675226, 0.42955736, 0.3753111 , 0.32465247],
       [0.8824969 , 0.92311635, 0.95599748, 0.98019867, 0.99501248,
        1.        , 0.99501248, 0.98019867, 0.95599748, 0.92311635,
        0.8824969 , 0.83527021, 0.78270454, 0.72614904, 0.66697681,
        0.60653066, 0.54607443, 0.48675226, 0.42955736, 0.3753111 ],
       [0.83527021, 0.8824969 , 0.92311635, 0.95599748, 0.98019867,
        0.99501248, 1.        , 0.99501248, 0.98019867, 0.95599748,
        0.92311635, 0.8824969 , 0.83527021, 0.78270454, 0.72614904,
        0.66697681, 0.60653066, 0.54607443, 0.48675226, 0.42955736],
       [0.78270454, 0.83527021, 0.8824969 , 0.92311635, 0.95599748,
        0.98019867, 0.99501248, 1.        , 0.99501248, 0.98019867,
        0.95599748, 0.92311635, 0.8824969 , 0.83527021, 0.78270454,
        0.72614904, 0.66697681, 0.60653066, 0.54607443, 0.48675226],
       [0.72614904, 0.78270454, 0.83527021, 0.8824969 , 0.92311635,
        0.95599748, 0.98019867, 0.99501248, 1.        , 0.99501248,
        0.98019867, 0.95599748, 0.92311635, 0.8824969 , 0.83527021,
        0.78270454, 0.72614904, 0.66697681, 0.60653066, 0.54607443],
       [0.66697681, 0.72614904, 0.78270454, 0.83527021, 0.8824969 ,
        0.92311635, 0.95599748, 0.98019867, 0.99501248, 1.        ,
        0.99501248, 0.98019867, 0.95599748, 0.92311635, 0.8824969 ,
        0.83527021, 0.78270454, 0.72614904, 0.66697681, 0.60653066],
       [0.60653066, 0.66697681, 0.72614904, 0.78270454, 0.83527021,
        0.8824969 , 0.92311635, 0.95599748, 0.98019867, 0.99501248,
        1.        , 0.99501248, 0.98019867, 0.95599748, 0.92311635,
        0.8824969 , 0.83527021, 0.78270454, 0.72614904, 0.66697681],
       [0.54607443, 0.60653066, 0.66697681, 0.72614904, 0.78270454,
        0.83527021, 0.8824969 , 0.92311635, 0.95599748, 0.98019867,
        0.99501248, 1.        , 0.99501248, 0.98019867, 0.95599748,
        0.92311635, 0.8824969 , 0.83527021, 0.78270454, 0.72614904],
       [0.48675226, 0.54607443, 0.60653066, 0.66697681, 0.72614904,
        0.78270454, 0.83527021, 0.8824969 , 0.92311635, 0.95599748,
        0.98019867, 0.99501248, 1.        , 0.99501248, 0.98019867,
        0.95599748, 0.92311635, 0.8824969 , 0.83527021, 0.78270454],
       [0.42955736, 0.48675226, 0.54607443, 0.60653066, 0.66697681,
        0.72614904, 0.78270454, 0.83527021, 0.8824969 , 0.92311635,
        0.95599748, 0.98019867, 0.99501248, 1.        , 0.99501248,
        0.98019867, 0.95599748, 0.92311635, 0.8824969 , 0.83527021],
       [0.3753111 , 0.42955736, 0.48675226, 0.54607443, 0.60653066,
        0.66697681, 0.72614904, 0.78270454, 0.83527021, 0.8824969 ,
        0.92311635, 0.95599748, 0.98019867, 0.99501248, 1.        ,
        0.99501248, 0.98019867, 0.95599748, 0.92311635, 0.8824969 ],
       [0.32465247, 0.3753111 , 0.42955736, 0.48675226, 0.54607443,
        0.60653066, 0.66697681, 0.72614904, 0.78270454, 0.83527021,
        0.8824969 , 0.92311635, 0.95599748, 0.98019867, 0.99501248,
        1.        , 0.99501248, 0.98019867, 0.95599748, 0.92311635],
       [0.2780373 , 0.32465247, 0.3753111 , 0.42955736, 0.48675226,
        0.54607443, 0.60653066, 0.66697681, 0.72614904, 0.78270454,
        0.83527021, 0.8824969 , 0.92311635, 0.95599748, 0.98019867,
        0.99501248, 1.        , 0.99501248, 0.98019867, 0.95599748],
       [0.23574608, 0.2780373 , 0.32465247, 0.3753111 , 0.42955736,
        0.48675226, 0.54607443, 0.60653066, 0.66697681, 0.72614904,
        0.78270454, 0.83527021, 0.8824969 , 0.92311635, 0.95599748,
        0.98019867, 0.99501248, 1.        , 0.99501248, 0.98019867],
       [0.1978987 , 0.23574608, 0.2780373 , 0.32465247, 0.3753111 ,
        0.42955736, 0.48675226, 0.54607443, 0.60653066, 0.66697681,
        0.72614904, 0.78270454, 0.83527021, 0.8824969 , 0.92311635,
        0.95599748, 0.98019867, 0.99501248, 1.        , 0.99501248],
       [0.16447446, 0.1978987 , 0.23574608, 0.2780373 , 0.32465247,
        0.3753111 , 0.42955736, 0.48675226, 0.54607443, 0.60653066,
        0.66697681, 0.72614904, 0.78270454, 0.83527021, 0.8824969 ,
        0.92311635, 0.95599748, 0.98019867, 0.99501248, 1.        ]])

Numpy 似乎毫无怨言地接受了它,但 MvNormal 说:

PosDefException: matrix is not Hermitian; Cholesky factorization failed.

或者

PosDefException: matrix is not positive definite; Cholesky factorization failed.

这似乎可能是浮点精度的问题,于是我尝试了 sol2 中的方法:

# Shift the diagonal by the magnitude of the smallest eigenvalue when that
# eigenvalue is negative (the shift is 0.0 otherwise), then retry.
σ = σ + maximum([0.0, -minimum(eigvals(σ))])*I
D = MvNormal(μ, σ)

这应该使矩阵正定,但没有成功。

sol1 中的解决方案,以及 sol1 与 sol2 的组合,也都不起作用:

# Wrap σ in a `Symmetric` view. Note that `.data` is the array stored inside
# the wrapper, so MvNormal receives that array rather than the view itself.
σ = Symmetric(σ)
D = MvNormal(σ.data)

我还按照 sol3 中的做法,尝试将其转换为 Hermitian 矩阵:

# Take a Hermitian view of σ and materialize it as a dense Matrix before
# passing it to MvNormal.
D = MvNormal(Matrix(Hermitian(σ)))

您对我应该做什么有任何见解吗?


经过进一步测试,似乎

# Add a tiny constant (1e-11) to every diagonal entry before building the
# distribution.
σ = σ + 0.00000000001*I
D = MvNormal(σ)

可以正常工作,但我觉得这个解决方案很糟糕——它为什么能起作用呢?

标签: numpy julia linear-algebra normal-distribution

解决方案


推荐阅读