Multinomial logit returns nans in some folds of cross-validation

Problem description

I wrote this code, which uses stratified k-folds to split the dataset, fits a multinomial logit regression, and then computes accuracy. My X is an array with 19 variables (the last one is the clustering variable), and y has 3 classes (0, 1, 2).

import numpy as np
import statsmodels.api as sm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, accuracy_score

X = np.asarray(df[[*all 19 columns here*]], dtype="float64")
y = np.asarray(df["categoric_var"], dtype="int")

acc_test=[]
acc_train=[]
skf = StratifiedKFold(n_splits=5, shuffle=True)
split_n = 0

for train_ix, test_ix in skf.split(X,y):
    split_n +=1
    X_train, X_valid = X[train_ix], X[test_ix]
    y_train, y_valid = y[train_ix], y[test_ix]
    cluster_groups = X_train[:,-1]
    X_train2 = X_train[:,:-1].astype("float64") # remove clustering variable
    X_valid2 = X_valid[:,:-1].astype("float64") # remove clustering variable

    mnl = sm.MNLogit(y_train, X_train2).fit(cov_type="cluster", cov_kwds={"groups":cluster_groups})
    print(mnl.summary())
    train_pred = mnl.predict(X_train2)

    # turn predicted probabilities into final classification, into a list
    pred_list_train = []
    for row in train_pred:
        if np.where(row == np.amax(row))[0]==0:
            pred_list_train.append(0)
        elif np.where(row == np.amax(row))[0]==1:
            pred_list_train.append(1)
        else:
            pred_list_train.append(2)

    print('MNLogit Regression, training set, fold ', split_n, ': ', classification_report(y_train, pred_list_train))
    
    pred = mnl.predict(X_valid2)

    # turn predicted probabilities into final classification, into a list
    pred_list_test = []
    for row in pred:
        if np.where(row == np.amax(row))[0]==0:
            pred_list_test.append(0)
        elif np.where(row == np.amax(row))[0]==1:
            pred_list_test.append(1)
        else:
            pred_list_test.append(2)

    #Measure of the fit of the model

    print('MNLogit Regression, validation set, fold ', split_n, ': ', classification_report(y_valid, pred_list_test))

    acc_test.append(accuracy_score(y_valid, pred_list_test))
    acc_train.append(accuracy_score(y_train, pred_list_train))
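(For reference, the np.where/np.amax comparison above is also what triggers the "truth value of an empty array" DeprecationWarnings further down: a row of predicted probabilities that is all nan matches nothing, the comparison yields an empty array, and the row falls through to class 2. A more compact version of the probability-to-class step is sketched below; the explicit nan guard is an assumption for illustration, not part of the original code.)

import numpy as np

def probs_to_classes(prob_matrix):
    # Hypothetical helper: pick the class with the highest predicted probability.
    prob_matrix = np.asarray(prob_matrix, dtype="float64")
    # Rows that are all nan (as in the failing folds) are flagged instead of
    # being silently mapped to a class.
    if np.isnan(prob_matrix).any():
        raise ValueError("predicted probabilities contain nan")
    return np.argmax(prob_matrix, axis=1)

# e.g. pred_list_train = probs_to_classes(train_pred)
#      pred_list_test  = probs_to_classes(pred)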

The problem is that I have two versions of y: one in which the classes are more unbalanced (version 1), and one that is more balanced (version 2).

When I try this code with version 1 of y, it works perfectly. However, when I run it with version 2, some folds return all nan in the regression... Here is an example (apologies for the length). These are the results of the first two folds:

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2251: RuntimeWarning: divide by zero encountered in log

  logprob = np.log(self.cdf(np.dot(self.exog,params)))

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2252: RuntimeWarning: invalid value encountered in multiply

  return np.sum(d * logprob)

Optimization terminated successfully.

         Current function value: nan

         Iterations 14

C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in greater

  return (a < x) & (x < b)

C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in less

  return (a < x) & (x < b)

C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:1912: RuntimeWarning: invalid value encountered in less_equal

  cond2 = cond0 & (x <= _a)

                          MNLogit Regression Results                          

==============================================================================

Dep. Variable:                      y   No. Observations:                13852

Model:                        MNLogit   Df Residuals:                    13814

Method:                           MLE   Df Model:                           36

Date:                Thu, 13 Aug 2020   Pseudo R-squ.:                     nan

Time:                        23:04:09   Log-Likelihood:                    nan

converged:                       True   LL-Null:                       -13943.

Covariance Type:              cluster   LLR p-value:                       nan

==============================================================================

       y=1       coef    std err          z      P>|z|      [0.025      0.975]

------------------------------------------------------------------------------

x1            -0.0012      0.009     -0.126      0.900      -0.020       0.017

x2             0.0001    1.8e-05      6.207      0.000    7.63e-05       0.000

x3            -0.6074      0.621     -0.978      0.328      -1.825       0.610

x4             8.5373      1.219      7.004      0.000       6.148      10.926

x5             0.0136      0.002      5.906      0.000       0.009       0.018

x6             0.0024      0.066      0.037      0.970      -0.127       0.131

x7            -0.0060      0.003     -1.972      0.049      -0.012   -3.76e-05

x8            -0.0263      0.015     -1.695      0.090      -0.057       0.004

x9            -0.0237      0.026     -0.926      0.355      -0.074       0.026

x10           -0.0008      0.002     -0.404      0.686      -0.005       0.003

x11            0.0713      0.031      2.308      0.021       0.011       0.132

x12        -9.272e-05   1.54e-05     -6.003      0.000      -0.000   -6.24e-05

x13           -0.0012      0.000     -4.696      0.000      -0.002      -0.001

x14          5.53e-05   1.06e-05      5.215      0.000    3.45e-05    7.61e-05

x15           -0.0007      0.000     -3.538      0.000      -0.001      -0.000

x16         7.334e-05   6.94e-05      1.056      0.291   -6.27e-05       0.000

x17           -0.0098      0.001     -9.659      0.000      -0.012      -0.008

x18           -0.0506      0.036     -1.409      0.159      -0.121       0.020

x19            0.0953      0.017      5.682      0.000       0.062       0.128

------------------------------------------------------------------------------

       y=2       coef    std err          z      P>|z|      [0.025      0.975]

------------------------------------------------------------------------------

x1             0.0354      0.025      1.411      0.158      -0.014       0.084

x2             0.0003      0.000      1.996      0.046    5.62e-06       0.001

x3             3.3663      3.177      1.060      0.289      -2.860       9.593

x4            16.6473      8.483      1.962      0.050       0.021      33.273

x5             0.0507      0.026      1.963      0.050    7.82e-05       0.101

x6             0.3423      0.278      1.232      0.218      -0.202       0.887

x7             0.0274      0.026      1.051      0.293      -0.024       0.079

x8             0.0998      0.071      1.397      0.162      -0.040       0.240

x9            -0.0231      0.049     -0.466      0.641      -0.120       0.074

x10            0.0126      0.006      1.969      0.049    5.65e-05       0.025

x11            0.2219      0.129      1.720      0.085      -0.031       0.475

x12           -0.0002    8.6e-05     -2.286      0.022      -0.000    -2.8e-05

x13           -0.0022      0.001     -2.591      0.010      -0.004      -0.001

x14            0.0001   5.35e-05      2.313      0.021    1.89e-05       0.000

x15           -0.0018      0.001     -2.209      0.027      -0.003      -0.000

x16         6.439e-05      0.000      0.468      0.640      -0.000       0.000

x17           -0.8636      0.047    -18.523      0.000      -0.955      -0.772

x18            1.7166      4.104      0.418      0.676      -6.328       9.761

x19            0.0713      0.052      1.375      0.169      -0.030       0.173

==============================================================================

MNLogit Regression, training set, fold  21 :                precision    recall  f1-score   support

 

           0       0.89      0.78      0.83      3679

           1       0.76      0.83      0.80      2738

           2       0.97      1.00      0.98      7435

 

    accuracy                           0.91     13852

   macro avg       0.87      0.87      0.87     13852

weighted avg       0.91      0.91      0.90     13852

 

MNLogit Regression, validation set, fold  21 :                precision    recall  f1-score   support

 

           0       0.88      0.78      0.83       920

           1       0.77      0.82      0.79       685

           2       0.97      1.00      0.98      1859

 

    accuracy                           0.90      3464

   macro avg       0.87      0.86      0.87      3464

weighted avg       0.90      0.90      0.90      3464

 

shape xtrain:  (13853, 19)

shape ytrain:  (13853,)

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2219: RuntimeWarning: overflow encountered in exp

  eXB = np.column_stack((np.ones(len(X)), np.exp(X)))

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2220: RuntimeWarning: invalid value encountered in true_divide

  return eXB/eXB.sum(1)[:,None]

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\base\optimizer.py:300: RuntimeWarning: invalid value encountered in greater

  oldparams) > tol)):

Optimization terminated successfully.

         Current function value: nan

         Iterations 6

                          MNLogit Regression Results                         

==============================================================================

Dep. Variable:                      y   No. Observations:                13853

Model:                        MNLogit   Df Residuals:                    13815

Method:                           MLE   Df Model:                           36

Date:                Thu, 13 Aug 2020   Pseudo R-squ.:                     nan

Time:                        23:04:10   Log-Likelihood:                    nan

converged:                       True   LL-Null:                       -13944.

Covariance Type:              cluster   LLR p-value:                       nan

==============================================================================

       y=1       coef    std err          z      P>|z|      [0.025      0.975]

------------------------------------------------------------------------------

x1                nan        nan        nan        nan         nan         nan

x2                nan        nan        nan        nan         nan         nan

x3                nan        nan        nan        nan         nan         nan

x4                nan        nan        nan        nan         nan         nan

x5                nan        nan        nan        nan         nan         nan

x6                nan        nan        nan        nan         nan         nan

x7                nan        nan        nan        nan         nan         nan

x8                nan        nan        nan        nan         nan         nan

x9                nan        nan        nan        nan         nan         nan

x10               nan        nan        nan        nan         nan         nan

x11               nan        nan        nan        nan         nan         nan

x12               nan        nan        nan        nan         nan         nan

x13               nan        nan        nan        nan         nan         nan

x14               nan        nan        nan        nan         nan         nan

x15               nan        nan        nan        nan         nan         nan

x16               nan        nan        nan        nan         nan         nan

x17               nan        nan        nan        nan         nan         nan

x18               nan        nan        nan        nan         nan         nan

x19               nan        nan        nan        nan         nan         nan

------------------------------------------------------------------------------

       y=2       coef    std err          z      P>|z|      [0.025      0.975]

------------------------------------------------------------------------------

x1                nan        nan        nan        nan         nan         nan

x2                nan        nan        nan        nan         nan         nan

x3                nan        nan        nan        nan         nan         nan

x4                nan        nan        nan        nan         nan         nan

x5                nan        nan        nan        nan         nan         nan

x6                nan        nan        nan        nan         nan         nan

x7                nan        nan        nan        nan         nan         nan

x8                nan        nan        nan        nan         nan         nan

x9                nan        nan        nan        nan         nan         nan

x10               nan        nan        nan        nan         nan         nan

x11               nan        nan        nan        nan         nan         nan

x12               nan        nan        nan        nan         nan         nan

x13               nan        nan        nan        nan         nan         nan

x14               nan        nan        nan        nan         nan         nan

x15               nan        nan        nan        nan         nan         nan

x16               nan        nan        nan        nan         nan         nan

x17               nan        nan        nan        nan         nan         nan

x18               nan        nan        nan        nan         nan         nan

x19               nan        nan        nan        nan         nan         nan

==============================================================================

__main__:42: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.

__main__:44: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.

C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

  _warn_prf(average, modifier, msg_start, len(result))

__main__:54: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.

__main__:56: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.

C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

  _warn_prf(average, modifier, msg_start, len(result))

MNLogit Regression, training set, fold  21 :                precision    recall  f1-score   support

 

           0       0.00      0.00      0.00      3679

           1       0.00      0.00      0.00      2739

           2       0.54      1.00      0.70      7435

 

    accuracy                           0.54     13853

   macro avg       0.18      0.33      0.23     13853

weighted avg       0.29      0.54      0.37     13853

 

MNLogit Regression, validation set, fold  21 :                precision    recall  f1-score   support

 

           0       0.00      0.00      0.00       920

           1       0.00      0.00      0.00       684

           2       0.54      1.00      0.70      1859

 

    accuracy                           0.54      3463

   macro avg       0.18      0.33      0.23      3463

weighted avg       0.29      0.54      0.38      3463

I don't know what could be happening here, since nothing has really changed other than the values of the dependent variable.
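For context, a minimal diagnostic sketch of what one might compare between a fold that fits cleanly and a fold that returns nan, assuming the same split_n, X_train2 and y_train as in the loop above (the feature-scale check is an assumption prompted by the "overflow encountered in exp" warning, not something the original code does):

import numpy as np

def describe_fold(split_n, X_train2, y_train):
    # Class balance in this training fold (the only thing that changed
    # between version 1 and version 2 of y).
    classes, counts = np.unique(y_train, return_counts=True)
    print("fold", split_n, "class counts:", dict(zip(classes.tolist(), counts.tolist())))
    # Largest absolute value per column: very large raw features can push
    # exp(X @ params) to overflow, which is what the RuntimeWarning reports.
    print("fold", split_n, "max |column|:", np.abs(X_train2).max(axis=0))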

Tags: python, regression, cross-validation

Solution

