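python - Multinomial logit returns NaNs in some cross-validation folds
Problem description
I wrote this code, which uses stratified k-folds to split the dataset, fits a multinomial logit regression, and then computes the accuracy. My X is an array with 19 variables (the last one is the clustering variable), and y has 3 classes (0, 1, 2).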
import numpy as np
import statsmodels.api as sm
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import StratifiedKFold

X = np.asarray(df[[*all 19 columns here*]], dtype="float64")
y = np.asarray(df["categoric_var"], dtype="int")
acc_test = []
acc_train = []
skf = StratifiedKFold(n_splits=5, shuffle=True)
split_n = 0
for train_ix, test_ix in skf.split(X, y):
    split_n += 1
    X_train, X_valid = X[train_ix], X[test_ix]
    y_train, y_valid = y[train_ix], y[test_ix]
    cluster_groups = X_train[:, -1]
    X_train2 = X_train[:, :-1].astype("float64")  # remove clustering variable
    X_valid2 = X_valid[:, :-1].astype("float64")  # remove clustering variable
    # fit the multinomial logit with cluster-robust standard errors
    mnl = sm.MNLogit(y_train, X_train2).fit(cov_type="cluster", cov_kwds={"groups": cluster_groups})
    print(mnl.summary())
    train_pred = mnl.predict(X_train2)
    # turn predicted probabilities into final classification, into a list
    # (a row of NaNs matches neither test and falls through to class 2)
    pred_list_train = []
    for row in train_pred:
        if np.where(row == np.amax(row))[0] == 0:
            pred_list_train.append(0)
        elif np.where(row == np.amax(row))[0] == 1:
            pred_list_train.append(1)
        else:
            pred_list_train.append(2)
    print('MNLogit Regression, training set, fold ', split_n, ': ', classification_report(y_train, pred_list_train))
    pred = mnl.predict(X_valid2)
    # turn predicted probabilities into final classification, into a list
    pred_list_test = []
    for row in pred:
        if np.where(row == np.amax(row))[0] == 0:
            pred_list_test.append(0)
        elif np.where(row == np.amax(row))[0] == 1:
            pred_list_test.append(1)
        else:
            pred_list_test.append(2)
    # measure the fit of the model
    print('MNLogit Regression, validation set, fold ', split_n, ': ', classification_report(y_valid, pred_list_test))
    acc_test.append(accuracy_score(y_valid, pred_list_test))
    acc_train.append(accuracy_score(y_train, pred_list_train))
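The two loops that turn predicted probabilities into class labels can be written more compactly. What follows is a minimal sketch, assuming the predictions arrive as an (n_samples, n_classes) NumPy array. The helper name `probs_to_labels` and the sentinel value `-1` are illustrative choices, not part of the original code. Unlike the if/elif/else chain, which sends a row of NaNs to class 2, this version marks such rows explicitly so that a failed fit is not mistaken for a prediction.

```python
import numpy as np


def probs_to_labels(probs, invalid_label=-1):
    """Return the argmax class per row; rows with non-finite values get invalid_label.

    Hypothetical helper for illustration only.
    """
    probs = np.asarray(probs, dtype=float)
    # A row is usable only if every probability in it is finite (no NaN or inf).
    valid = np.isfinite(probs).all(axis=1)
    labels = np.full(probs.shape[0], invalid_label, dtype=int)
    labels[valid] = np.argmax(probs[valid], axis=1)
    return labels


# Synthetic probabilities: two ordinary rows and one row from a failed fit.
example = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [np.nan, np.nan, np.nan],
])
labels = probs_to_labels(example)
print(labels)
```

Counting how many labels equal the sentinel in each fold gives a quick signal that the fit itself failed, before any accuracy is computed.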
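The problem is that I have two versions of y: one in which the classes are more imbalanced (version 1) and one in which they are more balanced (version 2).
When I run this code with version 1 of y, it works perfectly. However, when I run it with version 2, some folds return nan for everything in the regression. Here is an example (apologies for the length). These are the results for the first two folds: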
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2251: RuntimeWarning: divide by zero encountered in log
logprob = np.log(self.cdf(np.dot(self.exog,params)))
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2252: RuntimeWarning: invalid value encountered in multiply
return np.sum(d * logprob)
Optimization terminated successfully.
Current function value: nan
Iterations 14
C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in greater
return (a < x) & (x < b)
C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in less
return (a < x) & (x < b)
C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:1912: RuntimeWarning: invalid value encountered in less_equal
cond2 = cond0 & (x <= _a)
MNLogit Regression Results
==============================================================================
Dep. Variable: y No. Observations: 13852
Model: MNLogit Df Residuals: 13814
Method: MLE Df Model: 36
Date: Thu, 13 Aug 2020 Pseudo R-squ.: nan
Time: 23:04:09 Log-Likelihood: nan
converged: True LL-Null: -13943.
Covariance Type: cluster LLR p-value: nan
==============================================================================
y=1 coef std err z P>|z| [0.025 0.975]
------------------------------------------------------------------------------
x1 -0.0012 0.009 -0.126 0.900 -0.020 0.017
x2 0.0001 1.8e-05 6.207 0.000 7.63e-05 0.000
x3 -0.6074 0.621 -0.978 0.328 -1.825 0.610
x4 8.5373 1.219 7.004 0.000 6.148 10.926
x5 0.0136 0.002 5.906 0.000 0.009 0.018
x6 0.0024 0.066 0.037 0.970 -0.127 0.131
x7 -0.0060 0.003 -1.972 0.049 -0.012 -3.76e-05
x8 -0.0263 0.015 -1.695 0.090 -0.057 0.004
x9 -0.0237 0.026 -0.926 0.355 -0.074 0.026
x10 -0.0008 0.002 -0.404 0.686 -0.005 0.003
x11 0.0713 0.031 2.308 0.021 0.011 0.132
x12 -9.272e-05 1.54e-05 -6.003 0.000 -0.000 -6.24e-05
x13 -0.0012 0.000 -4.696 0.000 -0.002 -0.001
x14 5.53e-05 1.06e-05 5.215 0.000 3.45e-05 7.61e-05
x15 -0.0007 0.000 -3.538 0.000 -0.001 -0.000
x16 7.334e-05 6.94e-05 1.056 0.291 -6.27e-05 0.000
x17 -0.0098 0.001 -9.659 0.000 -0.012 -0.008
x18 -0.0506 0.036 -1.409 0.159 -0.121 0.020
x19 0.0953 0.017 5.682 0.000 0.062 0.128
------------------------------------------------------------------------------
y=2 coef std err z P>|z| [0.025 0.975]
------------------------------------------------------------------------------
x1 0.0354 0.025 1.411 0.158 -0.014 0.084
x2 0.0003 0.000 1.996 0.046 5.62e-06 0.001
x3 3.3663 3.177 1.060 0.289 -2.860 9.593
x4 16.6473 8.483 1.962 0.050 0.021 33.273
x5 0.0507 0.026 1.963 0.050 7.82e-05 0.101
x6 0.3423 0.278 1.232 0.218 -0.202 0.887
x7 0.0274 0.026 1.051 0.293 -0.024 0.079
x8 0.0998 0.071 1.397 0.162 -0.040 0.240
x9 -0.0231 0.049 -0.466 0.641 -0.120 0.074
x10 0.0126 0.006 1.969 0.049 5.65e-05 0.025
x11 0.2219 0.129 1.720 0.085 -0.031 0.475
x12 -0.0002 8.6e-05 -2.286 0.022 -0.000 -2.8e-05
x13 -0.0022 0.001 -2.591 0.010 -0.004 -0.001
x14 0.0001 5.35e-05 2.313 0.021 1.89e-05 0.000
x15 -0.0018 0.001 -2.209 0.027 -0.003 -0.000
x16 6.439e-05 0.000 0.468 0.640 -0.000 0.000
x17 -0.8636 0.047 -18.523 0.000 -0.955 -0.772
x18 1.7166 4.104 0.418 0.676 -6.328 9.761
x19 0.0713 0.052 1.375 0.169 -0.030 0.173
==============================================================================
MNLogit Regression, training set, fold 21 :
              precision    recall  f1-score   support
           0       0.89      0.78      0.83      3679
           1       0.76      0.83      0.80      2738
           2       0.97      1.00      0.98      7435
    accuracy                           0.91     13852
   macro avg       0.87      0.87      0.87     13852
weighted avg       0.91      0.91      0.90     13852
MNLogit Regression, validation set, fold 21 :
              precision    recall  f1-score   support
           0       0.88      0.78      0.83       920
           1       0.77      0.82      0.79       685
           2       0.97      1.00      0.98      1859
    accuracy                           0.90      3464
   macro avg       0.87      0.86      0.87      3464
weighted avg       0.90      0.90      0.90      3464
shape xtrain: (13853, 19)
shape ytrain: (13853,)
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2219: RuntimeWarning: overflow encountered in exp
eXB = np.column_stack((np.ones(len(X)), np.exp(X)))
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2220: RuntimeWarning: invalid value encountered in true_divide
return eXB/eXB.sum(1)[:,None]
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\base\optimizer.py:300: RuntimeWarning: invalid value encountered in greater
oldparams) > tol)):
Optimization terminated successfully.
Current function value: nan
Iterations 6
MNLogit Regression Results
==============================================================================
Dep. Variable: y No. Observations: 13853
Model: MNLogit Df Residuals: 13815
Method: MLE Df Model: 36
Date: Thu, 13 Aug 2020 Pseudo R-squ.: nan
Time: 23:04:10 Log-Likelihood: nan
converged: True LL-Null: -13944.
Covariance Type: cluster LLR p-value: nan
==============================================================================
y=1 coef std err z P>|z| [0.025 0.975]
------------------------------------------------------------------------------
x1 nan nan nan nan nan nan
x2 nan nan nan nan nan nan
x3 nan nan nan nan nan nan
x4 nan nan nan nan nan nan
x5 nan nan nan nan nan nan
x6 nan nan nan nan nan nan
x7 nan nan nan nan nan nan
x8 nan nan nan nan nan nan
x9 nan nan nan nan nan nan
x10 nan nan nan nan nan nan
x11 nan nan nan nan nan nan
x12 nan nan nan nan nan nan
x13 nan nan nan nan nan nan
x14 nan nan nan nan nan nan
x15 nan nan nan nan nan nan
x16 nan nan nan nan nan nan
x17 nan nan nan nan nan nan
x18 nan nan nan nan nan nan
x19 nan nan nan nan nan nan
------------------------------------------------------------------------------
y=2 coef std err z P>|z| [0.025 0.975]
------------------------------------------------------------------------------
x1 nan nan nan nan nan nan
x2 nan nan nan nan nan nan
x3 nan nan nan nan nan nan
x4 nan nan nan nan nan nan
x5 nan nan nan nan nan nan
x6 nan nan nan nan nan nan
x7 nan nan nan nan nan nan
x8 nan nan nan nan nan nan
x9 nan nan nan nan nan nan
x10 nan nan nan nan nan nan
x11 nan nan nan nan nan nan
x12 nan nan nan nan nan nan
x13 nan nan nan nan nan nan
x14 nan nan nan nan nan nan
x15 nan nan nan nan nan nan
x16 nan nan nan nan nan nan
x17 nan nan nan nan nan nan
x18 nan nan nan nan nan nan
x19 nan nan nan nan nan nan
==============================================================================
__main__:42: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
__main__:44: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, msg_start, len(result))
__main__:54: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
__main__:56: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, msg_start, len(result))
MNLogit Regression, training set, fold 21 :
              precision    recall  f1-score   support
           0       0.00      0.00      0.00      3679
           1       0.00      0.00      0.00      2739
           2       0.54      1.00      0.70      7435
    accuracy                           0.54     13853
   macro avg       0.18      0.33      0.23     13853
weighted avg       0.29      0.54      0.37     13853
MNLogit Regression, validation set, fold 21 :
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       920
           1       0.00      0.00      0.00       684
           2       0.54      1.00      0.70      1859
    accuracy                           0.54      3463
   macro avg       0.18      0.33      0.23      3463
weighted avg       0.29      0.54      0.38      3463
I don't know what could be happening here, since nothing really changed except the values in the dependent variable.
Solution
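No answer survives in this copy of the thread, so the following is only a possible line of investigation, not a confirmed fix. The log shows `overflow encountered in exp` inside the softmax that statsmodels computes, and several fitted coefficients are on the order of 1e-05, which suggests (as an assumption) that some predictors take very large raw values. The sketch below uses synthetic numbers to show how a softmax written like the two lines quoted in that warning turns an overflow into NaN, and how standardizing the predictors with training-fold statistics keeps the linear predictor in a safe range. The helper names `naive_softmax` and `standardize`, the parameter values, and all data are hypothetical.

```python
import numpy as np


def naive_softmax(xb):
    """Softmax with class 0 as reference, written like the lines quoted in the warning."""
    exb = np.column_stack((np.ones(len(xb)), np.exp(xb)))
    return exb / exb.sum(1)[:, None]


def standardize(train, other):
    """Scale both arrays with the mean and std of the training fold only."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    std[std == 0] = 1.0  # leave constant columns unscaled
    return (train - mean) / std, (other - mean) / std


rng = np.random.default_rng(0)
# Hypothetical predictors: the second column is on a scale of tens of thousands.
X_train = np.column_stack((rng.normal(size=200), rng.normal(5e4, 1e4, size=200)))
X_valid = np.column_stack((rng.normal(size=50), rng.normal(5e4, 1e4, size=50)))
# Hypothetical parameter guess during optimization, one column per non-reference class.
params = np.array([[0.5, -0.2], [0.05, 0.02]])

# Raw predictors: exp() overflows to inf, and inf / inf gives NaN.
with np.errstate(over="ignore", invalid="ignore"):
    raw_probs = naive_softmax(X_train @ params)

# Standardized predictors: the linear predictor stays within a few units.
X_train_s, X_valid_s = standardize(X_train, X_valid)
scaled_probs = naive_softmax(X_train_s @ params)

print("NaN in raw probabilities:   ", np.isnan(raw_probs).any())
print("NaN in scaled probabilities:", np.isnan(scaled_probs).any())
```

In the cross-validation loop, this would mean calling `standardize(X_train2, X_valid2)` before `sm.MNLogit(...)`. Whether that removes the NaNs for this data set is untested. If it does not, perfectly or nearly perfectly separated classes would be another thing to check, since class 2 is predicted with recall 1.00 in the fold that did produce estimates.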