python - 如何使用 scipy 获得非平滑的 2D 样条插值
问题描述
我想要一个适合一些不规则间隔数据的二维三次样条 - 即一个完全适合给定点的数据的函数 - 但也可以返回两者之间的值。
我能找到的(对于不规则间隔的数据)是scipy.interpolate.SmoothBivariateSpline
. 我不知道如何关闭“平滑”(无论我在s
参数中输入什么值。
然而,我确实发现我可以得到大部分我想要的东西scipy.interpolate.griddata
——尽管每次都必须重新计算它(即不只是生成一个函数)。从根本上讲,这两者之间有什么区别-即griddata
做的事情与“样条曲线”不同吗?反正有没有关闭平滑SmoothBivariateSpline
或不平滑的等效功能?
以下是我用来测试样条与多项式拟合的脚本
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import scipy.optimize
import scipy.interpolate
import matplotlib.pyplot as plt
import numpy.polynomial.polynomial as poly
# Grid and test function
N = 9;
x,y = np.linspace(-1,1, N), np.linspace(-1,1, N)
X,Y = np.meshgrid(x,y)
F = lambda X,Y : X+Y-1*X*Y-(X*Y)**2 -2*X*Y**2 + X**2*Y + 3*np.exp(-((X+1)**2+(Y+1)**2)*5)
Z = F(X,Y)
noise = 0.4
Z *= 1+(np.random.random(Z.shape)*2-1)*noise # noise
# Finer Grid and test function
N2 = 19;
x2,y2 = np.linspace(-1,1, N2), np.linspace(-1,1, N2)
X2,Y2 = np.meshgrid(x2,y2)
Z2 = F(X2,Y2)
# Make data into lists
Xl = X.reshape(X.size)
Yl = Y.reshape(Y.size)
Zl = Z.reshape(Z.size)
# Polynomial fit
# polyval(x,y,p) = p[0,0]+p[0,1]y+p[1,0]x+p[1,1]xy+p[1,2]xy^2 ..., etc
# I use a flat (1D) array for p, so it needs to be reshaped into a 2D array before
# passing to polyval
order = 3
p0 = np.zeros(order**2) # guess parameters (all 0 for now)
f_poly = lambda x,y,p : poly.polyval2d(x,y,p.reshape((order,order))) # Wrapper for our polynomial
errf = lambda p : np.mean((f_poly(Xl,Yl,p.reshape((order,order)))-Zl)**2) # error function to find least square error
sol = scipy.optimize.minimize(errf, p0)
psol = sol['x']
# Spline interpolation
# Bivariate (2D), Smoothed (doesn't fit points *exactly*) cubic (3rd order - i.e. kx=ky=3) spline
spl = scipy.interpolate.SmoothBivariateSpline(Xl, Yl, Zl, kx=3,ky=3)
f_spline = spl.ev
# regular Interpolate
f_interp = lambda x,y : scipy.interpolate.griddata((Xl, Yl), Zl, (x,y), method='cubic')
# Plot
fig = plt.figure(1, figsize=(7,8))
plt.clf()
# poly fit
ax = fig.add_subplot(311, projection='3d')
ax.scatter3D(X2,Y2,Z2,s=3, color='red', label='actual data')
fit = f_poly(X2,Y2, psol)
l = 'order {} poly fit'.format(order)
ax.plot_wireframe(X2,Y2, fit, color='black', label=l)
ax.scatter3D(X,Y,Z, color='blue', label='noisy data')
plt.legend()
print("Average {} error: {}".format(l, np.sqrt(np.mean((fit-Z2)**2))))
# spline fit
ax = fig.add_subplot(312, projection='3d')
ax.scatter3D(X2,Y2,Z2,s=3, color='red', label='actual data')
l = 'smoothed spline'
fit = f_spline(X2,Y2)
ax.plot_wireframe(X2,Y2, fit, color='black', label=l)
ax.scatter3D(X,Y,Z, color='blue', label='noisy data')
plt.legend()
print("Average {} error: {}".format(l, np.sqrt(np.mean((fit-Z2)**2))))
# interp fit
ax = fig.add_subplot(313, projection='3d')
ax.scatter3D(X2,Y2,Z2,s=3, color='red', label='actual data')
l='3rd order interp '
fit=f_interp(X2,Y2)
ax.plot_wireframe(X2,Y2, fit, color='black', label=l)
ax.scatter3D(X,Y,Z, color='blue', label='noisy data')
plt.legend()
print("Average {} error: {}".format(l, np.sqrt(np.mean((fit-Z2)**2))))
plt.show(False)
plt.pause(1)
raw_input('press key to continue') # Change to input() if using python3
解决方案
对于非结构化网格,griddata
是正确的插值工具。但是,每次都会执行三角剖分(Delaunay)和插值。一种解决方法是使用CloughTocher2DInterpolator
C1 平滑插值或LinearNDInterpolator
线性插值。这些是实际使用的功能griddata
。不同之处在于它可以用作输入 aDelaunay object
并返回一个插值函数。
这是基于您的代码的示例:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from scipy.interpolate import CloughTocher2DInterpolator
from scipy.spatial import Delaunay
# Example unstructured mesh:
nodes = np.array([[-1. , -1. ],
[ 1. , -1. ],
[ 1. , 1. ],
[-1. , 1. ],
[ 0. , 0. ],
[-1. , 0. ],
[ 0. , -1. ],
[-0.5 , 0. ],
[ 0. , 1. ],
[-0.75 , 0.4 ],
[-0.5 , 1. ],
[-1. , -0.6 ],
[-0.25 , -0.5 ],
[-0.5 , -1. ],
[-0.20833333, 0.5 ],
[ 1. , 0. ],
[ 0.5 , 1. ],
[ 0.36174242, 0.44412879],
[ 0.5 , -0.03786566],
[ 0.2927264 , -0.5411368 ],
[ 0.5 , -1. ],
[ 1. , 0.5 ],
[ 1. , -0.5 ]])
# Theoretical function:
def F(x, y):
return x + y - x*y - (x*y)**2 - 2*x*y**2 + x**2*y + 3*np.exp( -((x+1)**2 + (y+1)**2)*5 )
z = F(nodes[:, 0], nodes[:, 1])
# Finer regular grid:
N2 = 19
x2, y2 = np.linspace(-1, 1, N2), np.linspace(-1, 1, N2)
X2, Y2 = np.meshgrid(x2, y2)
# Interpolation:
tri = Delaunay(nodes)
CT_interpolator = CloughTocher2DInterpolator(tri, z)
z_interpolated = CT_interpolator(X2, Y2)
# Plot
fig = plt.figure(1, figsize=(8,14))
ax = fig.add_subplot(311, projection='3d')
ax.scatter3D(nodes[:, 0], nodes[:, 1], z, s=15, color='red', label='points')
ax.plot_wireframe(X2, Y2, z_interpolated, color='black', label='interpolated')
plt.legend();
得到的图是:
样条法和 Clough-Tocher 插值均基于在网格元素上构造分段多项式函数。不同之处在于,对于样条,网格是规则的,并且由算法给出(参见 参考资料.get_knots()
)。并且对系数进行拟合,使函数尽可能接近点并且平滑(拟合)。对于 Clough-Tocher 插值,网格元素是作为输入给出的元素。因此,可以保证生成的函数通过这些点。
推荐阅读
- angular - 我如何延迟加载组件 html 直到获得所有图像?
- vue.js - Vue axios 发布请求被 CORS 策略阻止
- powershell - 删除 PowerShell 中特定行之后的所有行
- java - 什么会导致文本在 oracle 的移动应用程序框架中的 amx 命令按钮中不显示?
- java - API 调用中的 HTML 实体编码不正确
- c++ - 有没有办法在提供左值和右值重载时删除重复代码?
- c# - 返回不杀死递归函数
- tomcat - Gulp 项目部署到 Tomcat 服务器
- numpy - Numpy 更改矩阵维度
- d3.js - D3,js 路径数据字符串 - 我应该解析它还是有一个 API?