python - 将 SciPy 优化应用于拟合的 sci-kit 模型
问题描述
最近我遇到了一个问题,我认为 SciPy 可能是一个很好的解决方案。但是,我无法正确应用它。不确定我是否遗漏了某些东西,或者我正在寻找的东西实际上根本不可能。
这是一个虚构的例子,我制作它是为了让事情更清晰,更容易可视化。我的情况要复杂得多。
from sklearn.model_selection import train_test_split
from sklearn.svm import SVR
from scipy.optimize import minimize
import numpy as np
import pandas as pd
time_studied = [12, 10, 4, 7, 6, 11, 6]
hours_slept = [8, 7, 1, 3, 8, 6, 5]
grade = [10, 9, 2, 5, 7, 8, 8.5, 6]
X = np.array([time_studied, hours_slept]).T
y = np.array([grade]).T
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
model = SVR(kernel='poly', C=100, gamma='auto', degree=3, epsilon=.1, coef0=1)
model.fit(X,y)
然后,我试图对该函数应用优化器,以找到睡眠和学习之间的最佳平衡。由于经过测试的回归方法返回一个函数,所以我想可以应用 SciPy 最小化。但是,当我尝试应用它时,像这样
bnds = [(0,12), (0,8)]
x0 = [0,0]
residual_plant = minimize(model, x0, method='SLSQP',bounds=bnds,options = {'eps': np.radians(5.0)})
我收到以下错误
TypeError: 'SVR' object is not callable
因此,显然可以直接从我的模型中调用优化器。因此,我的问题来了,如何访问适合我的数据的函数并能够找到最佳睡眠时间+学习时间 x 等级(在这种情况下,很明显是预期的结果)
我错过了什么吗?有可能做我的目标吗?
解决方案
尝试这个:
residual_plant = minimize(lambda x: model.predict(np.array([x])), x0, method='SLSQP',bounds=bnds,options = {'eps': np.radians(5.0)})
minimize
SciPy的第一个参数不仅仅是model.predict
因为 SciPy 试图将一维数组传递给它的目标函数,而是model.predict
期望一个二维数组。
(顺便说一句,在您的虚构模型的训练设置中,y
第二列是 ,X
并且该列表grades
从未使用过。我怀疑y
应该是np.array([grades]).T
。因为那不是您的真实模型,所以这可能并不重要。)
该方法的参考文档predict
:https ://scikit-learn.org/stable/modules/generated/sklearn.svm.SVR.html?highlight=svr#sklearn.svm.SVR.predict
predict
可以在此处找到该方法的示例用法: https ://scikit-learn.org/stable/auto_examples/svm/plot_svm_regression.html#sphx-glr-auto-examples-svm-plot-svm-regression-py
推荐阅读
- java - 在服务层发生验证时测试 Web 层 - Spring Boot 测试
- c# - 使用 Npgsql 和 EntityFramework Core 针对 PostGIS 将几何转换为地理
- angular - 自定义和默认标题渲染器的 Ag 网格 Angular
- xcode - Xcode 12.3 一直冻结
- r - 具有相同/共同级别的动态过滤闪亮的应用程序
- sql-server - 更新存在的地方,引用正在更新的表?
- c - WinAPI,TreeView:TVIS_CUT 的灰色项目
- webpack - 在 Rails 6 中使用 bulma-extensions
- jestjs - 即使使用 jest.retryTimes,如何让 jest 测试第一次失败?
- php - 仅在 HTTP 上更改每个请求的 PHP 会话。Cookie 未设置