python - 具有峰值和平顶(超)高斯信号的 Python 曲线拟合问题
问题描述
我有一个标准的高斯函数,如下所示:
def gauss_fnc(x, amp, cen, sigma):
return amp * np.exp(-(x - cen) ** 2 / (2 * sigma ** 2))
我有一个 fit_gaussian 函数,它使用 scipy 的 curve_fit 来拟合我的 gauss_fnc:
from scipy.optimize import curve_fit
def fit_gaussian(x, y):
mean = sum(x * y) / sum(y)
sigma = np.sqrt(sum(y * (x - mean) ** 2) / sum(y))
opt, cov = curve_fit(gauss_fnc, x, y, p0=[max(y), mean, sigma])
values = gauss_fnc(x, *opt)
return values, sigma, opt, cov
如果数据类似于正常的高斯函数,我可以确认这很有效,请参见示例:
但是,如果信号太尖或太窄,它就不会按预期工作。峰值高斯的示例:
这是平顶或超高斯的示例:
目前,高斯变得越平坦,由于高斯切割边缘,越来越多的信息丢失。如何改进函数或曲线拟合,以便能够像这张图片一样拟合峰值和平顶信号:
编辑:
我提供了一个最小的工作示例来试试这个:
from PyQt5.QtWidgets import (QApplication, QMainWindow)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from scipy.optimize import curve_fit
import numpy as np
from PyQt5.QtWidgets import QWidget, QGridLayout
def gauss_fnc(x, amp, cen, sigma):
return amp * np.exp(-(x - cen) ** 2 / (2 * sigma ** 2))
def fit_gauss(x, y):
mean = sum(x * y) / sum(y)
sigma = np.sqrt(sum(y * (x - mean) ** 2) / sum(y))
opt, cov = curve_fit(gauss_fnc, x, y, p0=[max(y), mean, sigma])
vals = gauss_fnc(x, *opt)
return vals, sigma, opt, cov
class MainWindow(QMainWindow):
def __init__(self):
super().__init__()
self.results = list()
self.setWindowTitle('Gauss fitting')
self.setGeometry(50, 50, 1280, 1024)
self.setupLayout()
self.raw_data1 = np.array([1, 1, 1, 1, 3, 5, 7, 8, 9, 10, 11, 10, 9, 8, 7, 5, 3, 1, 1, 1, 1], dtype=int)
self.raw_data2 = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 200, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int)
self.raw_data3 = np.array([1, 1, 1, 1, 1, 3, 5, 9, 10, 10, 10, 10, 10, 9, 5, 3, 1, 1, 1, 1, 1], dtype=int)
self.plot()
def setupLayout(self):
# Create figures
self.fig1 = FigureCanvas(Figure(figsize=(5, 4), dpi=100))
self.fig1AX = self.fig1.figure.add_subplot(111, frameon=False)
self.fig1AX.get_xaxis().set_visible(True)
self.fig1AX.get_yaxis().set_visible(True)
self.fig1AX.yaxis.tick_right()
self.fig1AX.yaxis.set_label_position("right")
self.fig2 = FigureCanvas(Figure(figsize=(5, 4), dpi=100))
self.fig2AX = self.fig2.figure.add_subplot(111, frameon=False)
self.fig2AX.get_xaxis().set_visible(True)
self.fig2AX.get_yaxis().set_visible(True)
self.fig2AX.yaxis.tick_right()
self.fig2AX.yaxis.set_label_position("right")
self.fig3 = FigureCanvas(Figure(figsize=(5, 4), dpi=100))
self.fig3AX = self.fig3.figure.add_subplot(111, frameon=False)
self.fig3AX.get_xaxis().set_visible(True)
self.fig3AX.get_yaxis().set_visible(True)
self.fig3AX.yaxis.tick_right()
self.fig3AX.yaxis.set_label_position("right")
self.widget = QWidget(self)
grid = QGridLayout()
grid.addWidget(self.fig1, 0, 0, 1, 1)
grid.addWidget(self.fig2, 1, 0, 1, 1)
grid.addWidget(self.fig3, 2, 0, 1, 1)
self.widget.setLayout(grid)
self.setCentralWidget(self.widget)
def plot(self):
x = len(self.raw_data1)
xvals, sigma, optw, covar = fit_gauss(range(x), self.raw_data1)
self.fig1AX.clear()
self.fig1AX.plot(range(len(self.raw_data1)), self.raw_data1, 'k-')
self.fig1AX.plot(range(len(self.raw_data1)), xvals, 'b-', linewidth=2)
self.fig1AX.margins(0, 0)
self.fig1.figure.tight_layout()
self.fig1.draw()
xvals, sigma, optw, covar = fit_gauss(range(x), self.raw_data1)
self.fig2AX.clear()
self.fig2AX.plot(range(len(self.raw_data2)), self.raw_data2, 'k-')
self.fig2AX.plot(range(len(self.raw_data2)), xvals, 'b-', linewidth=2)
self.fig2AX.margins(0, 0)
self.fig2.figure.tight_layout()
self.fig2.draw()
self.fig3AX.clear()
self.fig3AX.plot(range(len(self.raw_data3)), self.raw_data3, 'k-')
self.fig3AX.plot(range(len(self.raw_data3)), xvals, 'b-', linewidth=2)
self.fig3AX.margins(0, 0)
self.fig3.figure.tight_layout()
self.fig3.draw()
if __name__ == '__main__':
app = QApplication([])
window = MainWindow()
window.show()
app.exec_()
最后一张图片来自这里。
解决方案
您可以使用高斯定义函数进行曲线拟合:
import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
x = range(21)
y_peak = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 200, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int)
y_flat_top = np.array([1, 1, 1, 1, 1, 3, 5, 9, 10, 10, 10, 10, 10, 9, 5, 3, 1, 1, 1, 1, 1], dtype=int)
# define gauss function
def Gauss(x, a, x0, sigma):
return a * np.exp(-(x - x0)**2 / (2 * sigma**2))
# fit function
popt, pcov = curve_fit(Gauss, x, y_peak)
# set data for curve plot
x_fit = np.linspace(0,21,1000)
y_fit = Gauss(x_fit, popt[0], popt[1], popt[2])
y_fit = Gauss(x_fit, max(y_flat_top) , popt[1], popt[2])
# plot data
fig, ax = plt.subplots()
plt.plot(x, y_peak, '.')
plt.plot(x_fit, y_fit, '-', label='peak')
plt.legend()
plt.show()
输出:
使用广义正态分布:很难拟合。您可以使用边界并尝试添加一些额外的参数来获得更好的拟合效果。另一种选择是使用差分进化算法来找到最佳拟合。
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
# set data
x = np.linspace(-4, 4, 20)
y_peak = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 200, 2, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int)
y_flat_top = np.array([1, 1, 1, 1, 1, 3, 5, 9, 10, 10, 10, 10, 10, 9, 5, 3, 1, 1, 1, 1], dtype=int)
y = y_peak
# define generalized normal distribution
def general_norm(x, gamma, beta):
value = (beta/(2*gamma*(1/beta)))*np.exp(-np.abs(x)**beta)
return value
# set bounds
bounds_peak = ((0,0),(100,9))
bounds_flat_top = ((0,7),(100,9))
# fit function
popt, pcov = curve_fit(general_norm, x, y, bounds=bounds_peak)
# calculate rms
rms = sum((y - general_norm(x, popt[0], popt[1]))**2)
# set data for curve plot
x_fit = np.linspace(-4,4,1000)
y_fit = general_norm(x_fit, popt[0], popt[1])
# plot data
fig, ax = plt.subplots(1, 1)
ax.plot(x, y, '.')
ax.plot(x_fit, y_fit, 'b-')
plt.show()
输出:
推荐阅读
- python - 烧瓶表单没有验证
- python - 如何使用 Selenium TimeoutException 修复未绑定的本地错误
- javascript - 在 ReactJS 中使用按钮过滤数据
- linux - i3 将编号分配给窗口并移动到编号
- python - 在 python3.6 上安装 pycrypto
- ckeditor5 - CKEditor 5 Angular 与提及插件
- excel - VBA 从工作表选择中取消选择行和列
- outlook - AdaptiveCard 无法在 Office 365 Proplus Outlook 中呈现
- mysql - 根据另外两列的范围在两列中扩展日期
- firebase - Firestore 云功能在不部署功能的情况下不会触发