python - sklearn 的 MLP predict_proba 函数在内部是如何工作的?
问题描述
我试图了解sklearn
MLP 分类器如何为其predict_proba
功能检索其结果。
该网站仅列出:
概率估计
而许多其他的,例如逻辑回归,有更详细的答案:概率估计。
所有类的返回估计值按类标签排序。
对于 multi_class 问题,如果 multi_class 设置为“多项式”,则 softmax 函数用于查找每个类的预测概率。否则使用one-vs-rest 方法,即使用逻辑函数计算每个类假设它为正的概率。并在所有类中标准化这些值。
其他模型类型也有更多细节。以支持向量机分类器为例
还有这篇非常不错的 Stack Overflow 帖子,它深入解释了它。
计算 X 中样本的可能结果的概率。
模型需要在训练时计算概率信息:拟合属性概率设置为 True。
其他例子
随机森林:
预测 X 的类别概率。
输入样本的预测类别概率计算为森林中树木的平均预测类别概率。一棵树的类概率是叶子中同一类的样本的分数。
我希望了解与上述帖子相同的内容,但对于MLPClassifier
. 内部工作如何MLPClassifier
?
解决方案
查看源代码,我发现:
def _initialize(self, y, layer_units):
# set all attributes, allocate weights etc for first call
# Initialize parameters
self.n_iter_ = 0
self.t_ = 0
self.n_outputs_ = y.shape[1]
# Compute the number of layers
self.n_layers_ = len(layer_units)
# Output for regression
if not is_classifier(self):
self.out_activation_ = 'identity'
# Output for multi class
elif self._label_binarizer.y_type_ == 'multiclass':
self.out_activation_ = 'softmax'
# Output for binary class and multi-label
else:
self.out_activation_ = 'logistic'
似乎 MLP 分类器使用逻辑函数进行二元分类,使用 softmax 函数进行多标签分类,以构建输出层。这表明网络的输出是一个概率向量,网络在此基础上推导出预测。
如果我看predict_proba
方法:
def predict_proba(self, X):
"""Probability estimates.
Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
The input data.
Returns
-------
y_prob : ndarray of shape (n_samples, n_classes)
The predicted probability of the sample for each class in the
model, where classes are ordered as they are in `self.classes_`.
"""
check_is_fitted(self)
y_pred = self._predict(X)
if self.n_outputs_ == 1:
y_pred = y_pred.ravel()
if y_pred.ndim == 1:
return np.vstack([1 - y_pred, y_pred]).T
else:
return y_pred
这确认了 softmax 或逻辑作为输出层的激活函数的作用,以便获得概率向量。
希望这可以帮助你。
推荐阅读
- android - Android测试用例等待10分钟后才能在debug模式下调试
- python - scikit learn ExtraTreesClassifier 预测使用 Pandas DataFarme vs datatale Frame vs Numpy array 给出不同的执行时间
- vim - 尝试(但失败)让 cscope/ctags 在混合 C/C++ 项目中定位 C++ 函数
- html - Nikola:添加带有 id 的链接
- java - (GAE-Standard+Java11) 运行多个实例的会话
- python - 在 Pandas Dataframe 中计算时间间隔内的行数
- grails - Grails 视图为 DTO 对象列表呈现额外的逗号
- c# - 在 Azure 搜索中将模糊搜索与同义词扩展相结合
- regex - 我需要帮助将公式应用于 Google 工作表中的每一行
- python - 有 python 的 is_Prime() 函数