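machine-learning - Trying to understand the decision function of an SVC RBF kernel by implementing it manually, but not getting the same output as sklearn's decision function
Problem description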
import math
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = make_classification(n_samples=5000, n_features=5, n_redundant=2,
                           n_classes=2, weights=[0.7], class_sep=0.7, random_state=15)
xtrain, xtest, ytrain, ytest = train_test_split(X, y, test_size=0.2)
xtr, xcv, ytr, ycv = train_test_split(xtrain, ytrain, test_size=0.25)
clf = SVC(gamma=0.001, C=100.)
clf.fit(xtr, ytr)
intercept = clf.intercept_
alpha = clf.dual_coef_
ind = clf.support_
ytemp = ytrain[ind]
Yi = np.where(ytemp == 0, -1, ytemp)
Find the squared distance between two points
def dist(x1, x2):
    s = 0
    for i in range(len(x1)):
        s += (x1[i] - x2[i]) ** 2
    return s
Implement the decision function
def decision_function(Xcv, xi, yi, intercept, a):
    df = []
    for xq in Xcv:
        s = 0
        for i in range(len(xi)):
            s = s + (a[i] * math.exp(-(0.001) * (dist(xq, xi[i]))) + intercept[0])
            print(a[i] * math.exp(-(0.001) * (dist(xq, xi[i]))) + intercept[0])
            print("sum={} , i={}".format(s, i))
        df.append(s)
    return df
fcv = decision_function(xcv,Xi,ytemp,intercept,alpha[0])
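For reference, the quantity a manual implementation like this is meant to reproduce for a binary SVC with an RBF kernel can be written as below. The symbols are notation introduced here rather than taken from the post: `x_i` are the support vectors, `\alpha_i y_i` corresponds to `clf.dual_coef_[0][i]`, `b` to `clf.intercept_[0]`, and `\gamma` is the `gamma=0.001` passed to `SVC`.

```latex
f(x_q) = \sum_{i=1}^{n_{SV}} (\alpha_i y_i)\,
         \exp\!\left(-\gamma \,\lVert x_q - x_i \rVert^2\right) + b
```

The intercept `b` sits outside the sum, so it enters each score exactly once.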
For this decision function, I get the following output:
[577.2867160208668, 579.7215603806541, 578.5782187273019, 580.2360508825304, 577.8387580749602, 581.2748889203071, 578.0276237732187, 575.7743283778641, 578.0748474695522, 577.9968339333016, 577.2826299382772, 580.1755085241488, 580.2985317482837, 579.4260145851334, 578.9086035688148, 578.0971144092871, 577.7570129104452, 581.1748250826266, 577.4214599360627, 577.2634002760278, 578.1149032333689, 579.6387639355653, 576.9266069070123, 578.4634016633581, 578.2309365554806, 577.4782245762461, 577.3150477775968, 576.9916402698525, 576.5657957818951, 577.7553985428359, 576.662975210225, 577.9975979302852, 577.2010018149546, 577.4099961691077, 579.86257415243614
Only the first few values of the output are shown.
ypred1=clf.decision_function(xcv)
For sklearn's function above, I get the following output:
[-2.97173810e+00 2.16713203e+00 -6.05072299e-01 1.49020826e+00
-3.98455298e+00 1.57312742e+00 -2.88261984e+00 -1.43540335e+00
-2.89577440e+00 -3.37822755e+00 1.36767728e+00 1.38654851e+00
7.05173444e-01 1.93773096e+00 -1.63107457e+00 -3.23433464e+00
-3.61976997e+00 2.17733962e+00 -2.47422660e+00 -6.75162425e+00
1.60666547e+00 1.88396451e+00 -2.48529368e+00 -1.88689390e+00
-3.60426066e+00 -2.02841572e+00 1.53602071e+00 -3.12872962e+00
1.53430434e+00 -3.12180091e+00 -2.35210860e+00 -4.37190493e+00
1.28997708e+00 -2.38385513e+00 -5.13115557e-01 -2.10031104e+00
1.51949120e+00 -2.84127524e+00 2.46342483e+00 -3.24734083e+00
3.76229476e-01 -2.63336466e+00 -2.59096797e+00 -3.48918738e+00
-4.41519783e-01 1.32295014e+00 1.35556153e+00 2.24474786e+00
2.93086996e+00 -5.20447541e-01 -3.32440933e+00 1.79676223e+00
-3.34203483e+00 -1.66433472e+00 -3.25399070e+00 -1.02745252e+00
1.51664811e+00 1.44390086e+00 1.85848327e+00 -3.89770602e+00
-3.47102160e+00 -1.51369587e+00 -3.67958203e+00 1.51582425e+00
Can you help me figure out why my output is different?
Solution
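Judging only from the code as posted, a likely cause is that `intercept[0]` is added inside the inner loop, so it is counted once per support vector instead of once per query point. That would shift every score by `(n_SV - 1) * intercept`, which fits the manual values clustering near 577 while sklearn's sit between roughly -7 and 3. The sketch below is a pure-Python illustration with made-up numbers; the names `sq_dist` and `rbf_decision_function` and the toy data are assumptions, not from the question.

```python
import math


def sq_dist(x1, x2):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((a - b) ** 2 for a, b in zip(x1, x2))


def rbf_decision_function(X, support_vectors, dual_coef, intercept, gamma):
    """Manual binary-SVC scores: sum_i coef_i * exp(-gamma * ||x - sv_i||^2) + b."""
    scores = []
    for xq in X:
        s = 0.0
        for coef, sv in zip(dual_coef, support_vectors):
            s += coef * math.exp(-gamma * sq_dist(xq, sv))
        scores.append(s + intercept)  # intercept added once, outside the inner loop
    return scores


# Toy values (hypothetical, chosen only to make the offset visible).
support_vectors = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]]
dual_coef = [0.5, -1.0, 0.5]
intercept = 0.25
gamma = 0.001
X = [[0.5, 0.5], [1.5, -0.5]]

correct = rbf_decision_function(X, support_vectors, dual_coef, intercept, gamma)

# Same sum written the way the question does it: intercept added per support vector.
buggy = []
for xq in X:
    s = 0.0
    for coef, sv in zip(dual_coef, support_vectors):
        s += coef * math.exp(-gamma * sq_dist(xq, sv)) + intercept
    buggy.append(s)

# The two versions differ by (n_SV - 1) * intercept for every query point.
offset = (len(support_vectors) - 1) * intercept
print([b - c for b, c in zip(buggy, correct)], offset)
```

With the fitted model from the question, a call such as `rbf_decision_function(xcv, clf.support_vectors_, clf.dual_coef_[0], clf.intercept_[0], 0.001)` should agree with `clf.decision_function(xcv)` up to floating-point error. Two smaller points: `dual_coef_` already holds the coefficients multiplied by their labels, so `Yi` is not needed in the sum, and `clf.support_` indexes the data passed to `fit`, so the matching labels would be `ytr[ind]` rather than `ytrain[ind]`.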