python - Error when using AUC with DecisionTreeClassifier - Python
Problem description
I am trying to tune my DecisionTreeClassifier. I want to use AUC (area under the curve) as the evaluation metric. Here is my code:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)

def max_depth_prediction(X_train, y_train, X_test, y_test, y):
    max_depths = np.linspace(1, 32, 32, endpoint=True)
    train_results = []
    test_results = []
    for max_depth in max_depths:
        dt = DecisionTreeClassifier(max_depth=max_depth)
        dt.fit(X_train, y_train)
        train_pred = dt.predict(X_train)
        print(y_train)
        print(train_pred)
        false_positive_rate, true_positive_rate, thresholds = roc_curve(y_train.astype(int), train_pred.astype(int))
        roc_auc = auc(false_positive_rate, true_positive_rate)
        # Add auc score to previous train results
        train_results.append(roc_auc)
        y_pred = dt.predict(X_test)
        false_positive_rate, true_positive_rate, thresholds = roc_curve(y_test, y_pred)
        roc_auc = auc(false_positive_rate, true_positive_rate)
        # Add auc score to previous test results
        test_results.append(roc_auc)
But I get an error when I run it:
ValueError: y_true takes value in {'0', '1'} and pos_label is not specified: either make y_true take value in {0, 1} or {-1, 1} or pass pos_label explicitly.
I checked my two vectors and they look fine:
y_train = ['0' '0' '0' ... '1' '1' '0']
train_pred = ['0' '0' '1' ... '1' '1' '0']
Solution
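y_train and train_pred are arrays of strings, but roc_curve needs integer labels (or an explicit pos_label). In your code the training call already casts both arrays with astype(int); the test call passes y_test and y_pred as strings, and that is the call that raises the error. Convert the labels to integers once, before fitting, so that the predictions come back as integers too. Try this:

import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.tree import DecisionTreeClassifier

def max_depth_prediction(X_train, y_train, X_test, y_test, y):
    # Convert the string labels to integers once, before the loop
    y_train = np.asarray(y_train).astype(int)
    y_test = np.asarray(y_test).astype(int)
    # Integer depths 1..32 (np.linspace yields floats, which recent
    # scikit-learn versions reject for max_depth)
    max_depths = range(1, 33)
    train_results = []
    test_results = []
    for max_depth in max_depths:
        dt = DecisionTreeClassifier(max_depth=max_depth)
        dt.fit(X_train, y_train)
        # Predictions are integers because the tree was fit on integer labels
        train_pred = dt.predict(X_train)
        false_positive_rate, true_positive_rate, thresholds = roc_curve(y_train, train_pred)
        roc_auc = auc(false_positive_rate, true_positive_rate)
        # Add auc score to previous train results
        train_results.append(roc_auc)
        y_pred = dt.predict(X_test)
        false_positive_rate, true_positive_rate, thresholds = roc_curve(y_test, y_pred)
        roc_auc = auc(false_positive_rate, true_positive_rate)
        # Add auc score to previous test results
        test_results.append(roc_auc)
    return train_results, test_results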
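The error message also names a second way out: pass pos_label explicitly and keep the string labels. The following is a minimal sketch of that route, not the original poster's code. It assumes synthetic data from make_classification (the real X and y are not shown in the question), a hypothetical helper named auc_for_depth, and it scores with predict_proba instead of predict so that roc_curve receives numeric scores.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (an assumption; the question's X and y are not shown).
X, y_int = make_classification(n_samples=400, n_features=8, random_state=0)
y = y_int.astype(str)  # labels become '0' / '1', as in the question

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0
)


def auc_for_depth(max_depth, pos_label="1"):
    """Return (train AUC, test AUC) for one tree depth, keeping string labels.

    Hypothetical helper for illustration only.
    """
    dt = DecisionTreeClassifier(max_depth=max_depth, random_state=0)
    dt.fit(X_train, y_train)
    # Column of predict_proba that belongs to the positive class
    pos_col = list(dt.classes_).index(pos_label)
    scores = []
    for X_part, y_part in ((X_train, y_train), (X_test, y_test)):
        proba = dt.predict_proba(X_part)[:, pos_col]
        # pos_label tells roc_curve which string counts as the positive class
        fpr, tpr, _ = roc_curve(y_part, proba, pos_label=pos_label)
        scores.append(auc(fpr, tpr))
    return tuple(scores)


train_auc, test_auc = auc_for_depth(max_depth=3)
print(train_auc, test_auc)
```

Using class probabilities rather than hard 0/1 predictions gives roc_curve more than one threshold to work with, so the resulting AUC is usually more informative for comparing depths.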