How to apply a model to a previously unseen validation set: "X has a different shape than during fitting" error

Problem description

I have this code, which runs SVR with cross-validation:

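import pickle

from sklearn.feature_selection import SelectKBest, mutual_info_regression
from sklearn.model_selection import GridSearchCV, RepeatedKFold
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR

def run_SVR(xTrain,yTrain,xTest,yTest,output_file,data_name):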
  '''
  run SVR algorithm
  '''
  short_dataname =  data_name.strip().split('/')
  file_model_name = output_file  + '_svr_' + short_dataname[-1]

  # note: cv is never passed to GridSearchCV below, which uses cv=5 instead
  cv = RepeatedKFold(n_splits=10, n_repeats=3, random_state=1)

  # define the pipeline to evaluate
  model = SVR()
  fs = SelectKBest(score_func=mutual_info_regression)
  pipeline = Pipeline(steps=[('sel',fs), ('svr', model)])

  # define the grid
  # note: this grid is never passed to GridSearchCV, so SelectKBest keeps its default k=10
  grid = dict()
  grid['sel__k'] = [i for i in range(1, xTrain.shape[1]+1)]
  search = GridSearchCV(
        pipeline,
        param_grid={
            'svr__C': [0.01, 0.1, 1, 10, 100, 1000], ##Regularization
            'svr__epsilon': [0.0001, 0.001, 0.01, 0.1, 1, 10],
            'svr__gamma': [0.0001,  0.001, 0.01, 0.1, 1, 10],
        },
        scoring='neg_mean_squared_error',
        return_train_score=True,
        verbose=1,
        cv=5,
        n_jobs=-1)

  results = search.fit(xTrain, yTrain)

  # save the model to disk
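  with open(file_model_name, 'wb') as fh:  # the with-block closes the file after writing
      pickle.dump(results, fh)

  # predict on the held-out test set
  y_pred = results.predict(xTest)
  spearman = Get_score(y_pred, yTest)  # Get_score: the asker's own scoring helper (not shown)
  print(spearman)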
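Because GridSearchCV refits the best parameter combination on the whole training set by default (`refit=True`), the pickled `results` object is itself a usable model: calling `predict` on it forwards to `best_estimator_`. A small sketch on synthetic data shows this; the data, the fixed `k=3` and the reduced grid are illustrative assumptions, not the asker's settings.

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, mutual_info_regression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR

# Synthetic regression data (assumption: 40 rows, 5 features).
rng = np.random.default_rng(0)
X = rng.normal(size=(40, 5))
y = X[:, 0] + 0.1 * rng.normal(size=40)

pipeline = Pipeline(steps=[
    ("sel", SelectKBest(score_func=mutual_info_regression, k=3)),
    ("svr", SVR()),
])
search = GridSearchCV(pipeline, param_grid={"svr__C": [0.1, 1, 10]},
                      scoring="neg_mean_squared_error", cv=5)
search.fit(X, y)

# With refit=True (the default), the best pipeline is retrained on all of X, y,
# and search.predict() delegates to that refitted best_estimator_.
print(search.best_params_)
same = np.allclose(search.predict(X), search.best_estimator_.predict(X))
print(same)  # True
```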

I built xTrain, yTrain, xTest and yTest earlier from a CSV file.
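How xTrain was derived from the CSV matters, because the model only remembers the columns it was fitted on. As a hedged sketch, assuming the labelled file stores the target in a column called `label` (a hypothetical name; substitute the real one), the usual split drops that column, so the feature matrix has one column fewer than the raw file:

```python
import io

import pandas as pd

# Hypothetical labelled data: three feature columns plus a target column
# named "label" (an assumed name, not taken from the question).
csv_text = "f1\tf2\tf3\tlabel\n0.1\t0.2\t0.3\t1.5\n0.4\t0.5\t0.6\t2.5\n"
lbp_train = pd.read_csv(io.StringIO(csv_text), sep="\t")

# Separate features from target; only xTrain is passed to fit().
xTrain = lbp_train.drop(columns=["label"])
yTrain = lbp_train["label"]

print(lbp_train.shape)  # (2, 4): features + label
print(xTrain.shape)     # (2, 3): features only
```

Under that assumption, comparing the new file's column count with xTrain.shape[1], rather than with lbp_train.shape[1], is the meaningful check.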

I'm now happy with my model, and I want to apply it to a completely new, previously unseen dataset.

So I wrote this code:

import pickle

import pandas as pd

### Load the model back from file
with open('svr_out_svr_lbp_data_short', 'rb') as file_name:  # reads in the model
    Pickled_LR_Model = pickle.load(file_name)

lbp_test = pd.read_csv('lbp_test.csv', sep='\t')  # reads in the test data
Ypredict = Pickled_LR_Model.predict(lbp_test)

The error I get is: ValueError: X has a different shape than during fitting.
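That message is about the number of columns (features), not rows: scikit-learn estimators accept any number of rows at predict time, but the column count must match what was seen in `fit`. In older scikit-learn releases the first step of the pipeline, SelectKBest, raised exactly this wording; newer releases phrase it as a feature-count mismatch. A short sketch with synthetic arrays (an assumption, standing in for the real data) reproduces the same class of error:

```python
import numpy as np
from sklearn.svm import SVR

rng = np.random.default_rng(2)
X_fit = rng.normal(size=(20, 3))          # 3 features at fit time
model = SVR().fit(X_fit, X_fit[:, 0])

# A different number of rows is fine ...
ok = model.predict(rng.normal(size=(7, 3)))

# ... but a different number of columns is rejected.
X_new = rng.normal(size=(4, 4))           # 4 features at predict time
try:
    model.predict(X_new)
    raised = False
except ValueError as err:
    raised = True
    print(err)  # exact wording varies between scikit-learn versions
```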

I can confirm this by printing lbp_train.shape and lbp_test.shape:

print(lbp_test.shape)
print(lbp_train.shape)

(142, 124)
(38, 124)

(I know the row counts are small; I'm just trying to get an example working.)
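Since both files report 124 columns, the more useful comparison is against what the fitted model expects. scikit-learn (0.24 and later) records this in `n_features_in_` on fitted estimators, including a fitted GridSearchCV. In the sketch below, `check_feature_count` is a hypothetical helper, and the 123-column training matrix is a hypothetical stand-in for a training set whose label column was dropped before fitting:

```python
import numpy as np
import pandas as pd
from sklearn.svm import SVR


def check_feature_count(model, X_new):
    """Hypothetical helper: return (expected, got) column counts."""
    expected = model.n_features_in_  # set by fit() in scikit-learn >= 0.24
    got = X_new.shape[1]
    return expected, got


rng = np.random.default_rng(3)
# Assumed scenario: the model saw 123 feature columns during training.
model = SVR().fit(rng.normal(size=(38, 123)), rng.normal(size=38))
lbp_test = pd.DataFrame(rng.normal(size=(142, 124)))

expected, got = check_feature_count(model, lbp_test)
print(expected, got)  # 123 124
```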

I wondered whether I was supposed to call fit_transform before predicting:

Ypredict = Pickled_LR_Model.fit_transform(lbp_test) 

so that the shapes match, but the output is:

'GridSearchCV' object has no attribute 'fit_transform'

I suppose that makes sense: I only want to apply the existing model to my validation set, not regenerate it.

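Could someone show me how to apply my trained model to a completely new, unlabeled validation dataset so that I can make predictions? I feel this question has the same problem, but I'm struggling to apply their code to my own.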
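For reference, the save/load/predict round trip on unlabeled data needs nothing beyond `predict` on the loaded object, provided the new rows carry the same feature columns. A minimal sketch, using a plain StandardScaler + SVR pipeline and synthetic data as stand-ins for the pickled grid search (both assumptions):

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

# Train on a DataFrame with named feature columns (synthetic data).
rng = np.random.default_rng(1)
cols = ["f1", "f2", "f3"]
xTrain = pd.DataFrame(rng.normal(size=(30, 3)), columns=cols)
yTrain = xTrain["f1"] * 2.0
model = Pipeline([("scale", StandardScaler()), ("svr", SVR())]).fit(xTrain, yTrain)

with tempfile.TemporaryDirectory() as tmp:
    model_path = Path(tmp) / "svr_model.pkl"
    with open(model_path, "wb") as fh:  # closed (and flushed) on exit
        pickle.dump(model, fh)
    with open(model_path, "rb") as fh:
        loaded = pickle.load(fh)

# New, unlabeled rows: same feature columns, no target column.
xNew = pd.DataFrame(rng.normal(size=(5, 3)), columns=cols)
y_new = loaded.predict(xNew)  # predict only; the model is not refitted
print(y_new.shape)  # (5,)
```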

Tags: python, machine-learning, scikit-learn, predict
