首页 > 解决方案 > 为什么 def 函数生成的图形与正常的代码块(python)不同?

问题描述

谁能看看下面的代码并告诉我为什么这两个图表的 x 轴格式不同以及如何调整def 函数中的代码以使两个图表具有相同的布局?范围是相同的,只是混淆了为什么它们以不同的方式表达。我无法发布任何图片作为首发。任何不便敬请谅解。

dtree = DecisionTreeClassifier(random_state=id, criterion='gini')

tree_depth = np.arange(1, 16)
train_scores, test_scores = validation_curve(dtree, x_train, y_train,
                                             param_name='max_depth', param_range=tree_depth,
                                             scoring='accuracy', cv=5)
plt.figure()
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)

plt.plot(tree_depth, train_scores_mean, 'o-', color='red', label='training score')
plt.fill_between(tree_depth, train_scores_mean - train_scores_std,
                 train_scores_mean + train_scores_std, alpha=0.1, color='red')

plt.plot(tree_depth, test_scores_mean, 'o-', color='blue', label='validation score')
plt.fill_between(tree_depth, test_scores_mean - test_scores_std,
                 test_scores_mean + test_scores_std, alpha=0.1, color='blue')

plt.title('Validation Curve: Decision Tree for Cancer')
plt.xlabel('tree_depth')
plt.ylabel('accuracy score')
plt.legend(loc='best')
plt.grid()

xlim 显示为 [2 4 6 8 10 12 14]

def validationcurve(interval, estimator, x, y, parameter, title):
    train_scores, test_scores = validation_curve(estimator, x, y, param_name=parameter, param_range=interval, scoring='accuracy', n_jobs=4, cv=5)

    plt.figure()
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    plt.semilogx(interval, train_scores_mean, 'o-', color='r', label='Training score')
    plt.fill_between(interval, train_scores_mean - train_scores_std,
                     train_scores_mean + train_scores_std, alpha=0.1, color='r')
    plt.semilogx(interval, test_scores_mean, 'o-', color='b', label='Validation score')
    plt.fill_between(interval, test_scores_mean - test_scores_std,
                     test_scores_mean + test_scores_std, alpha=0.1, color='b')

    plt.title(title)
    plt.xlabel(parameter)
    plt.ylabel("Score")
    plt.legend(loc="best")
    plt.grid()
    plt.savefig(fig_path + title+'.png')
    plt.clf()

validationcurve(np.arange(1, 16), dtree, x_train, y_train, parameter='max_depth', title='dtree_validation_curve_max_depth')

xlim 显示为 [10^0, 10^1]

标签: pythonlayoutcharts

解决方案


推荐阅读