首页 > 解决方案 > 如何在 python 曲线拟合 scipy.optimize 处修复 location=0

问题描述

以下是示例代码

x=[1.5,4,10,50,90]
y=[6/100,2.6/100,1.4/100,0.4/100,0.2/100]

def f(x, a, loc,scale):
    loc=0 
    return gamma.pdf(x, a, loc, scale)
optimize.curve_fit(f, x, y)

结果给了我一个 loc=1。有什么办法可以让 loc=0 吗?我注意到,当 x 没有整数元素时,不能将 loc 固定为 0,否则曲线拟合不起作用。我可以知道这背后的算法吗?


作为一个例子来说明为什么我的代码在某些情况下不起作用,

from scipy import optimize
from scipy.stats import gamma

def f(x, a, loc,scale):
    loc=0 
    return gamma.pdf(x, a, loc, scale)

init_guess=[0.1,0,0.1]

fig= plt.subplots(figsize=(5,3))
fit_2worst = optimize.curve_fit(f, x, y,p0=init_guess)

x2 = np.linspace (0, 100, 200)
y2 = gamma.pdf(x2, a=fit_2worst[0][0], loc=fit_2worst[0][1],scale=fit_2worst[0][2])

plt.title('Gamma with k='+str("{:.2}".format(fit_2worst[0][0]))+'\nTheta='+str(int(fit_2worst[0][2])))
plt.plot(x2, y2, "y-") 
print ('k:',fit_2worst[0][0],'location:',fit_2worst[0][1],'theta:',fit_2worst[0][2])
plt.show()

回报是

k: 36.171512499294444 location: 0.0 theta: 3.725335489050758

显示的图片是 在此处输入图像描述

使用@Joe 提出的代码,我能够得到正确的代码

def f(x, a, scale):
    #loc=0 
    return gamma.pdf(x, a, scale=scale, loc=0)

fig= plt.subplots(figsize=(5,3))
opt = optimize.curve_fit(f, x, y)


x2 = np.linspace (0, 100, 200)
y2 = gamma.pdf(x2, a=opt[0][0],scale=opt[0][1])

plt.title('Gamma with k='+str("{:.2}".format(opt[0][0]))+'\nTheta='+str(int(opt[0][1])))
plt.plot(x2, y2, "y-") 
print ('k:',opt[0][0],'location:',0,'theta:',opt[0][1])
plt.show()

有回报 k: 0.23311781831847955 location: 0 theta: 132.0300661365553

在此处输入图像描述

我不确定为什么前面的代码不适用于浮点数而是整数?

标签: pythoncurve-fitting

解决方案


这只是最小二乘。

您可以通过不使其变量来使 loc = 0 ,因此优化器不能免费使用它。尝试

def f(x, a, scale):
    #loc=0 
    return gamma.pdf(x, a, scale=scale, loc=0)

optimize.curve_fit(f, x, y)

带图:

import matplotlib.pyplot as plt
from scipy import optimize
from scipy.stats import gamma

import numpy as np
x=[1.5,4,10,50,90]
y=[6/100,2.6/100,1.4/100,0.4/100,0.2/100]

def f(x, a, scale):
    #loc=0 
    return gamma.pdf(x, a, scale=scale, loc=0)

opt = optimize.curve_fit(f, x, y)
print(opt)

x_0 = np.arange(0.0, 90)
y_0 = f(x_0, *(opt[0]))

plt.plot(x,y)
plt.plot(x_0,y_0, 'r.')
plt.show()

在此处输入图像描述


推荐阅读