首页 > 解决方案 > curve_fit TypeError: 输入不当:N=2 不得超过 M=1

问题描述

我正在尝试将正弦曲线拟合到从 .mat 文件导入的数据。

from scipy.io import loadmat
data=loadmat('TP4.mat')
x_data=data['x']
y_data=data['y']
def TEST(x,a,b):
return (a*np.sin(b*x))
params2,params_covariance2=optimize.curve_fit(TEST,x_data,y_data)

我一直在寻找一个小时的原因,但它返回了错误:

 params2,params_covariance2=optimize.curve_fit(TEST,x_data,y_data)
File "D:\python\lib\site-packages\scipy\optimize\minpack.py", line 784, in curve_fit
res = leastsq(func, p0, Dfun=jac, full_output=1, **kwargs)

 File "D:\python\lib\site-packages\scipy\optimize\minpack.py", line 414, in leastsq
raise TypeError('Improper input: N=%s must not exceed M=%s' % (n, m))

TypeError: Improper input: N=2 must not exceed M=1

有人知道为什么它会继续使用 curve_fit 吗?

x_data 是一个数组 (1,50) :

[[-5.         -4.79591837 -4.59183673 -4.3877551  -4.18367347 -3.97959184
  -3.7755102  -3.57142857 -3.36734694 -3.16326531 -2.95918367 -2.75510204
  -2.55102041 -2.34693878 -2.14285714 -1.93877551 -1.73469388 -1.53061224
  -1.32653061 -1.12244898 -0.91836735 -0.71428571 -0.51020408 -0.30612245
  -0.10204082  0.10204082  0.30612245  0.51020408  0.71428571  0.91836735
   1.12244898  1.32653061  1.53061224  1.73469388  1.93877551  2.14285714
   2.34693878  2.55102041  2.75510204  2.95918367  3.16326531  3.36734694
   3.57142857  3.7755102   3.97959184  4.18367347  4.3877551   4.59183673
   4.79591837  5.        ]]

并且 y_data 也是一个数组 (1,50):

[[ 1.93850755  3.90748334  0.74775799  3.80239638  3.32772828  2.71107852
   2.85786376  2.50109072  2.62484054  0.11701355 -1.34835766 -1.03896612
  -1.82007581 -1.49086175 -2.56849896 -4.68713822 -3.15843002 -2.77910422
  -2.87940036  1.73778484  0.31566242  1.81945211  2.58367122  3.26765109
   1.79831117  3.35824518  4.25703245  0.96006812  1.24266396  0.75338212
  -0.70009445 -1.12690812 -1.09354647 -3.46942727 -3.67550846 -4.53297036
  -2.55420957 -2.87671367 -2.33624323 -1.37104854  1.36564919  1.61563281
   2.46671249  2.6525523   2.4522091   2.96971773  2.56662355  3.35746594
   2.03000529  2.81887444]]

标签: pythonmatplotlibscipyscipy-optimize

解决方案


你的数组是一个意想不到的形状,所以我们必须重塑它们。由于np.reshape()创建了不同的视图,我们必须使用np.shape()

from scipy import optimize
import numpy as np
from matplotlib import pyplot as plt

def TEST(x,a,b):
    return a*np.sin(b*x)

#random test data
np.random.seed(1234)
n=100
x_data = np.linspace(0, 10, n)
y_data = TEST(x_data, 0.2, -0.4) + np.random.normal(scale=0.01,size=n)    
#your data structure
x_data.shape = (1, n)
y_data.shape = (1, n)
#print(x_data)
#print(y_data)

#changing the shape of the data for scipy
#if you remove this block you can reproduce your error message
x_data.shape = (n,)
y_data.shape = (n,)
#print(x_data)
#print(y_data)    

params2,params_covariance2=optimize.curve_fit(TEST,x_data,y_data, p0=[0.1, 0.1])    
print(params2)

y_fit = TEST(x_data, *params2)    
plt.plot(x_data, y_data, label="values")
plt.plot(x_data, y_fit, label="fit")

plt.legend()
plt.show()

输出

[ 0.1992707  -0.40013289]

在此处输入图像描述


推荐阅读