`fitsvm() got multiple values for argument 'c'` error

Problem description

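My dataset has over 2000 rows and 23 columns, including an age column, and I have already run SVR on it. At first I used an SVR model with default parameter values, and with those I could not get good R-squared scores. So now I need to modify my code to find the parameter combination that gives the highest R-squared value. To get the best result, I search over the following parameter values: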

c = [0.01, 0.1, 1, 10, 100]
Gamma = [0.001, 0.01, 0.1, 1]
Epsilon = [0.01, 0.1, 1]
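These three lists define a full grid of 5 × 4 × 3 = 60 combinations. As a small sketch, here is the same grid built with `itertools.product` as a DataFrame with one row per combination. This is an illustrative alternative to the `np.meshgrid` construction used in the code below, not the original approach:

```python
import itertools

import pandas as pd

c = [0.01, 0.1, 1, 10, 100]
Gamma = [0.001, 0.01, 0.1, 1]
Epsilon = [0.01, 0.1, 1]

# One row per (C, gamma, epsilon) combination: 5 * 4 * 3 = 60 rows
parameter = pd.DataFrame(
    list(itertools.product(c, Gamma, Epsilon)),
    columns=["c", "Gamma", "Epsilon"],
)

print(len(parameter))  # 60
print(parameter.head(3))
```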

But I get an error on the following two lines:

parameter['r_squared'] = fitsvm(X_train, y_train, X_test, y_test, c = c[0], gamma = Gamma[0], epsilon = Epsilon[0], axis = 1)

parameter.sort_values('r_squared',ascending=False).head()

The error is:

fitsvm() got multiple values for argument 'c'
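The message comes from how Python binds arguments. `fitsvm` is defined as `fitsvm(c, gamma, epsilon, X_train, y_train, X_test, y_test)`, so in the call above the four positional arguments `X_train, y_train, X_test, y_test` are bound to `c`, `gamma`, `epsilon` and `X_train`. The keyword `c = ...` then tries to bind `c` a second time, and Python raises `TypeError` before the function body runs. The extra `axis = 1` would fail next, because `fitsvm` has no `axis` parameter. A minimal sketch with a stub of the same signature (the body and the string placeholders are hypothetical; only the signature matters):

```python
def fitsvm(c, gamma, epsilon, X_train, y_train, X_test, y_test):
    # Placeholder body: only the signature matters for this demonstration
    return 0.0

# Hypothetical stand-ins for the real arrays
X_train, y_train, X_test, y_test = "Xtr", "ytr", "Xte", "yte"
errors = []

try:
    # Positional args fill c, gamma, epsilon, X_train; then c= binds c again
    fitsvm(X_train, y_train, X_test, y_test, c=0.01, gamma=0.001, epsilon=0.01)
except TypeError as err:
    errors.append(str(err))
    print(err)  # fitsvm() got multiple values for argument 'c'

try:
    # Even with every argument named, axis=1 is not a parameter of fitsvm
    fitsvm(c=0.01, gamma=0.001, epsilon=0.01, X_train=X_train, y_train=y_train,
           X_test=X_test, y_test=y_test, axis=1)
except TypeError as err:
    errors.append(str(err))
    print(err)  # fitsvm() got an unexpected keyword argument 'axis'
```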

Code:

import pandas as pd
import numpy as np
from sklearn.svm import SVR  # needed by fitsvm() below

# Make fake dataset
dataset = pd.DataFrame(data= np.random.rand(2000,22))
dataset['age'] = np.random.randint(2, size=2000)

# Separate the target from the other features
target = dataset['age']
data = dataset.drop('age', axis = 1)

# train_data, train_target = data.loc[:1000], target.loc[:1000] - alternate naming scheme
X_train, y_train = data.loc[:1000], target.loc[:1000]

# test_data, test_target = data.loc[1001:], target.loc[1001:] - alternate naming scheme
# Hold out all remaining rows: R² is not defined for a single test sample
X_test,  y_test  = data.loc[1001:], target.loc[1001:]

print(X_test.shape)


c = [0.01, 0.1, 1, 10, 100]
Gamma = [0.001, 0.01, 0.1, 1]
Epsilon = [0.01, 0.1, 1]

c_,Gamma_,Ep_ = np.meshgrid(c,Gamma,Epsilon)
parameter = pd.DataFrame({'c':c_.flatten(),'Gamma':Gamma_.flatten(),'Epsilon':Ep_.flatten()})



def fitsvm(c,gamma,epsilon,X_train, y_train, X_test, y_test):
    SupportVectorRefModel = SVR(C=c,gamma=gamma,epsilon=epsilon)
    SupportVectorRefModel.fit(X_train, y_train)
    R_Sqr = SupportVectorRefModel.score(X_test,y_test) 
    return R_Sqr

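np.random.seed(0)  # the seed only takes effect when the function is actually called
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_boston  # note: load_boston was removed in scikit-learn 1.2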

s = StandardScaler()
dataset, y = load_boston(return_X_y=True) 

X = s.fit(dataset).transform(dataset)

X_train, y_train = X[:100], y[:100]
# Boston has only 506 rows, so X[1001:] would give an empty test set
X_test, y_test = X[100:], y[100:]

parameter['r_squared'] = fitsvm(X_train, y_train, X_test, y_test, c=X[0], gamma = X[1], epsilon = X[2] , axis=1)

parameter.sort_values('r_squared',ascending=False).head()

Tags: python, python-3.x, scikit-learn, svm

Solution
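One way to fix the code in the question, assuming the goal is to score every row of the `parameter` grid, is to call `fitsvm` once per row through `DataFrame.apply(..., axis=1)`. That is where `axis=1` belongs: it is an argument of `apply`, not of `fitsvm`. The hyperparameters then come from the grid row rather than from rows of the data such as `X[0]`, and every argument is passed by keyword so nothing is bound twice. In this sketch, synthetic data from `make_regression` stands in for the real dataset, and the 75/25 split, the scaling step and `itertools.product` (instead of `np.meshgrid`) are assumptions for illustration:

```python
import itertools

import pandas as pd
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR


def fitsvm(c, gamma, epsilon, X_train, y_train, X_test, y_test):
    """Fit an SVR with the given hyperparameters and return R² on the test set."""
    model = SVR(C=c, gamma=gamma, epsilon=epsilon)
    model.fit(X_train, y_train)
    return model.score(X_test, y_test)  # SVR.score returns R²


# Synthetic stand-in for the real dataset (assumption)
X, y = make_regression(n_samples=400, n_features=22, noise=10.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# Fit the scaler on the training data only, then apply it to both splits
scaler = StandardScaler().fit(X_train)
X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

c = [0.01, 0.1, 1, 10, 100]
Gamma = [0.001, 0.01, 0.1, 1]
Epsilon = [0.01, 0.1, 1]
parameter = pd.DataFrame(list(itertools.product(c, Gamma, Epsilon)),
                         columns=["c", "Gamma", "Epsilon"])

# One fit per grid row: hyperparameters come from the row, data passed by keyword
parameter["r_squared"] = parameter.apply(
    lambda row: fitsvm(c=row["c"], gamma=row["Gamma"], epsilon=row["Epsilon"],
                       X_train=X_train, y_train=y_train,
                       X_test=X_test, y_test=y_test),
    axis=1,
)

print(parameter.sort_values("r_squared", ascending=False).head())
```

Choosing hyperparameters by their score on the same test set that is then reported tends to give an optimistic R². If that matters, `sklearn.model_selection.GridSearchCV` with `scoring="r2"` runs the same search using cross-validation on the training data instead.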

