首页 > 解决方案 > Python SVM OVO 打破关系

问题描述

我正在尝试在 Python 中手动实现线性 SVM 解决方案。

def train_OvO(X_train, y_train, train_func, param):

classes = sorted(set(y_train))
estimators = dict()
for i,ci in enumerate(classes):
    for j,cj in enumerate(classes):
        if j>i:
            X = X_train.copy()
            X = X[np.logical_or(y_train==ci,y_train==cj)]
            y = y_train.copy()
            y = y[np.logical_or(y_train==ci,y_train==cj)]
            
            yp = y.copy()
            
            yp[y==ci] = 1
            yp[y==cj] = -1
            est = train_func(X, yp, param)
            
            estimators[(i,j)] = est
            
return estimators

我已经实现了上述函数来为 One-Vs_One SVM 训练多个分类器。为了测试它们,我实现了以下代码

def test_OvO(X_test, test_func, score_func, estimators):
    
    #all_scores = np.zeros((X_test.shape[0], len(estimators)))
    #score_list = list()
    
    
    all_scores = np.zeros((X_test.shape[0], len(estimators)))
    for i,j in estimators:
        est = estimators[(i,j)]
        preds = test_func(X_test, est)
        all_scores[:,i][preds==1]+=1
        all_scores[:,j][preds==-1]+=1
    preds = np.argmax(all_scores, axis=1)
    return preds

然而,该算法并没有打破关系。在问题的定义中,算法应该根据 ;

“如果出现平局(在所有获得相同票数的案例中),它会通过对底层二元分类器计算的成对分类置信度水平求和来选择具有最高总分类置信度的类。”

据我了解,应该在函数中完成平局,test_OvO但我可能是错的。任何帮助,将不胜感激。

标签: pythonnumpysvm

解决方案


推荐阅读