首页 > 解决方案 > 使用固定输入变量进行回归预测的等高线图

问题描述

我想为具有多个特征的预测创建等高线图。其余值应固定以绘制 2 个有趣的值。不幸的是,我得到的矩阵在所有位置上都具有相同的值,而不是预期的值。

我认为我的矩阵有问题,但我没有发现错误。

[...]
f_learn = [x_1,x_2,x_3,x_4]
r_lear = [r_1]

clf = svm.MLPRegressor(...)
clf.fit(f_learn,r_learn)
[...]

x_1 = np.linspace(1, 100, 100)
x_2 = np.linspace(1, 100, 100)
X_1, X_2 = np.meshgrid(x_1, x_2)

x_3 = np.full( (100,100), 5).ravel()
x_4 = np.full( (100,100), 15).ravel()

predict_matrix = np.vstack([X_1.ravel(), X_2.ravel(), x_3,x_4])
prediction = clf.predict(predict_matrix.T)

prediction_plot = prediction.reshape(X_1.shape)

plt.figure()
    cp = plt.contourf(X_1, X_2, prediction_plot, 10)
    plt.colorbar(cp)
    plt.show()

如果我手动逐行测试矩阵,我会得到正确的结果。但是,如果我以这种方式将它们放在一起,它将不起作用。

编辑:复制代码时出错

数据示例。所有答案都是 7.5 并且没有不同;(

import matplotlib.pyplot as plt
import numpy as np
from sklearn import linear_model

f_learn =  np.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]])
r_learn =  np.array([6,7,8,9])

reg = linear_model.LinearRegression()
reg.fit (f_learn, r_learn)

x_1 = np.linspace(0, 20, 10)
x_2 = np.linspace(0, 20, 10)
X_1, X_2 = np.meshgrid(x_1, x_2)

x_3 = np.full( (10,10), 5).ravel()
x_4 = np.full( (10,10), 2).ravel()

predict_matrix = np.vstack([X_1.ravel(), X_2.ravel(), x_3, x_4])
prediction = reg.predict(predict_matrix.T)

prediction_plot = prediction.reshape(X_1.shape)

plt.figure()
cp = plt.contourf(X_1, X_2, prediction_plot, 10)
plt.colorbar(cp)
plt.show()

结果

标签: pythonmatplotlibmachine-learningscikit-learn

解决方案


尝试这样的事情。代码中的一些注释

x_1 = np.linspace(1, 100, 100)
x_2 = np.linspace(1, 100, 100)
X_1, X_2 = np.meshgrid(x_1, x_2)

# Why the shape was (1000, 100)?
x_3 = np.full((100, 100), 5).ravel() 
x_4 = np.full((100, 100),  15).ravel()

# you should use X_1.ravel() to make it column vector (it is one feature)
# there was x_3 insted of x_4
predict_matrix = np.vstack([X_1.ravel(), X_2.ravel(), x_3, x_4])  
prediction = clf.predict(predict_matrix.T)

prediction_plot = prediction.reshape(X_1.shape)

plt.figure()
cp = plt.contourf(X_1, X_2, prediction_plot, 10)
plt.colorbar(cp)
plt.show()

推荐阅读