python - 使用 scipy.optimize.fmin_cobyla 得到错误的结果
问题描述
我对 scipy 很陌生,现在我in scipy.optimize
通过做一些小实验来努力使用函数。
我试图通过找到具有最低误差值的参数来拟合 sin 函数。
使用的功能是fmin_cobyla
代码如下:
import matplotlib.pyplot as plt
from scipy.optimize import fmin_cobyla
from scipy.optimize import fmin_slsqp
from scipy.optimize import leastsq
import numpy as np
from sympy import *
noise = np.random.randn(100)
def func_model(x, para):
''' Model: y = a*sin(2*k*pi*x+theta)'''
a, k, theta = para
return a*np.sin(2*k*np.pi*x+theta)
def func_noise(x, para):
a, k, theta = para
return a*np.sin(2*k*np.pi*x+theta) + noise
def func_error(para_guess):
'''error_func'''
x_seq = np.linspace(-2*np.pi, 0, 100)
para_fact = [10, 0.34, np.pi/6]
data = func_noise(x_seq, para_fact)
error_value = data - func_model(x_seq, para_guess)
return error_value
# 1<a<15 0<k<1 0<theta<pi/2
constraints = [lambda x: 15 - x[0], lambda x: x[0]- 1, \
lambda x: 1 - x[1], lambda x: x[1], \
lambda x: np.pi/2 - x[2], lambda x: x[2]]
para_guess_init = np.array([7, 0.2, 0])
solution = fmin_cobyla(func_error, para_guess_init, constraints)
print(solution) # supposed to be like [10, 0.34, np.pi/6]
xx = np.linspace(-2*np.pi, 0, 100)
plt.plot(xx, func_model(xx, [10, 0.34, np.pi/6]), label="raw")
plt.plot(xx, func_noise(xx, [10, 0.34, np.pi/6]), label="with noise")
plt.plot(xx, func_model(xx, solution), label="fitted")
plt.legend()
plt.show()
运行后我得到了结果
解决方案 = [1.6655938 0.59868667 0.0731335]
这肯定不是正确答案
有人可以帮助我。提前致谢..
解决方案
这里有两件事显然是错误的:首先,每次调用目标函数时都会更改噪声,因此您的优化是试图击中移动目标。在调用之前设置模拟数据fmin_cobyla
:
the_noise = np.random.randn(100)
data = func_noise(x_seq, para_fact)
此外,您func_error
应该返回模型和每个点的数据之间的差异,而不是平方和差异:
def func_error(para_guess):
error_value = data - func_model(x_seq, para_guess)
return error_value
您仍然可能会发现fmin_cobyla
很难找到受约束的最小值......一些预处理以更好地估计相位或频率的初始猜测可能会帮助您。
推荐阅读
- php - Laravel 的干预 - JPG 提供的不支持的图像类型
- exception - 在 wso2 上抛出异常
- twitter-bootstrap - 当 .card-body 内容溢出时,在 .card-header 上同时设置 height 和 d-flex 不起作用
- rust - 如何对包含 &mut 枚举的元组进行模式匹配并在匹配臂中使用该枚举?
- laravel-5 - 无法扩展 Laravel 模型
- spring-boot - Spring boot aspectj gradle编译时编织问题
- javascript - 如何在 vh 和 vw 中而不是在像素中获得位置?
- amazon-web-services - 如何选择通过 Amazon Pinpoint 重新接收 SMS 消息
- javascript - 我应该在我的函数中指定值类型吗?JavaScript/JQuery
- postgresql - 在 postgres 中使用选定列中的值添加时间间隔