python - 使用 K 折交叉验证标准化数据
问题描述
我正在使用 StratifiedKFold 所以我的代码看起来像这样
def train_model(X,y,X_test,folds,model):
scores=[]
for fold_n, (train_index, valid_index) in enumerate(folds.split(X, y)):
X_train,X_valid = X[train_index],X[valid_index]
y_train,y_valid = y[train_index],y[valid_index]
model.fit(X_train,y_train)
y_pred_valid = model.predict(X_valid).reshape(-1,)
scores.append(roc_auc_score(y_valid, y_pred_valid))
print('CV mean score: {0:.4f}, std: {1:.4f}.'.format(np.mean(scores), np.std(scores)))
folds = StratifiedKFold(10,shuffle=True,random_state=0)
lr = LogisticRegression(class_weight='balanced',penalty='l1',C=0.1,solver='liblinear')
train_model(X_train,y_train,X_test,repeted_folds,lr)
现在在训练模型之前我想标准化数据,那么哪种方法是正确的?
1)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
在调用 train_model 函数之前执行此操作
2)
像这样在函数内部进行标准化
def train_model(X,y,X_test,folds,model):
scores=[]
for fold_n, (train_index, valid_index) in enumerate(folds.split(X, y)):
X_train,X_valid = X[train_index],X[valid_index]
y_train,y_valid = y[train_index],y[valid_index]
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_vaid = scaler.transform(X_valid)
X_test = scaler.transform(X_test)
model.fit(X_train,y_train)
y_pred_valid = model.predict(X_valid).reshape(-1,)
scores.append(roc_auc_score(y_valid, y_pred_valid))
print('CV mean score: {0:.4f}, std: {1:.4f}.'.format(np.mean(scores), np.std(scores)))
根据我在第二个选项中的知识,我没有泄漏数据。所以如果我不使用管道,哪种方式是正确的,如果我想使用交叉验证,如何使用管道?
解决方案
实际上,第二个选项更好,因为缩放器看不到X_valid
to scale的值X_train
。
现在,如果您要使用管道,您可以执行以下操作:
from sklearn.pipeline import make_pipeline
def train_model(X,y,X_test,folds,model):
pipeline = make_pipeline(StandardScaler(), model)
...
然后使用pipeline
代替model
. 在每次调用fit
时predict
,它都会自动标准化手头的数据。
请注意,您还可以使用scikit-learn 中的cross_val_scorescoring='roc_auc'
函数,参数为.
推荐阅读
- javascript - ReactJS:表单数据捕获:值设置不正确
- python - 如何删除文本文件中的特定昵称和密码?
- go - 无法使用 mikepb 的 go-serial 打开串口
- c++ - 清除输入缓冲区后未提取字符串流
- java - 导入 de.codecentric.boot.admin.server.config.EnableAdminServer 无法解析
- php - 使用 selected 在 php 中获取多选下拉值
- delphi - 在 Delphi 中访问其他单位常量
- javascript - Chart JS Tooltip - 将其放置在画布外的固定位置
- c# - c# regex - 匹配一个单词或一个数字(int 或 float)
- vb.net - Visual Basic Else If 语句