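python-3.x - I keep getting "TypeError: only integer scalar arrays can be converted to a scalar index" when using a custom metric in KNeighborsClassifier
Problem description
I am using a custom metric in scikit-learn's KNeighborsClassifier. Here is my code: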
def chi_squared(x,y):
    return np.divide(np.square(np.subtract(x,y)), np.sum(x,y))
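This is my implementation of the chi-squared distance. I used NumPy functions because, according to the scikit-learn docs, the metric function takes two one-dimensional NumPy arrays.
I passed the chi_squared function as the metric argument to KNeighborsClassifier():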
knn = KNeighborsClassifier(algorithm='ball_tree', metric=chi_squared)
However, I keep getting the following error:
TypeError Traceback (most recent call last)
<ipython-input-29-d2a365ebb538> in <module>
4
5 knn = KNeighborsClassifier(algorithm='ball_tree', metric=chi_squared)
----> 6 knn.fit(X_train, Y_train)
7 predictions = knn.predict(X_test)
8 print(accuracy_score(Y_test, predictions))
~/.local/lib/python3.8/site-packages/sklearn/neighbors/_classification.py in fit(self, X, y)
177 The fitted k-nearest neighbors classifier.
178 """
--> 179 return self._fit(X, y)
180
181 def predict(self, X):
~/.local/lib/python3.8/site-packages/sklearn/neighbors/_base.py in _fit(self, X, y)
497
498 if self._fit_method == 'ball_tree':
--> 499 self._tree = BallTree(X, self.leaf_size,
500 metric=self.effective_metric_,
501 **self.effective_metric_params_)
sklearn/neighbors/_binary_tree.pxi in sklearn.neighbors._ball_tree.BinaryTree.__init__()
sklearn/neighbors/_binary_tree.pxi in sklearn.neighbors._ball_tree.BinaryTree._recursive_build()
sklearn/neighbors/_ball_tree.pyx in sklearn.neighbors._ball_tree.init_node()
sklearn/neighbors/_binary_tree.pxi in sklearn.neighbors._ball_tree.BinaryTree.rdist()
sklearn/neighbors/_dist_metrics.pyx in sklearn.neighbors._dist_metrics.DistanceMetric.rdist()
sklearn/neighbors/_dist_metrics.pyx in sklearn.neighbors._dist_metrics.PyFuncDistance.dist()
sklearn/neighbors/_dist_metrics.pyx in sklearn.neighbors._dist_metrics.PyFuncDistance._dist()
<ipython-input-29-d2a365ebb538> in chi_squared(x, y)
1 def chi_squared(x,y):
----> 2 return np.divide(np.square(np.subtract(x,y)), np.sum(x,y))
3
4
5 knn = KNeighborsClassifier(algorithm='ball_tree', metric=chi_squared)
<__array_function__ internals> in sum(*args, **kwargs)
~/.local/lib/python3.8/site-packages/numpy/core/fromnumeric.py in sum(a, axis, dtype, out, keepdims, initial, where)
2239 return res
2240
-> 2241 return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims,
2242 initial=initial, where=where)
2243
~/.local/lib/python3.8/site-packages/numpy/core/fromnumeric.py in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs)
85 return reduction(axis=axis, out=out, **passkwargs)
86
---> 87 return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
88
89
TypeError: only integer scalar arrays can be converted to a scalar index
Solution
I can reproduce your error message like this:
In [173]: x=np.arange(3); y=np.array([2,3,4])
In [174]: np.sum(x,y)
Traceback (most recent call last):
File "<ipython-input-174-1a1a267ebd82>", line 1, in <module>
np.sum(x,y)
File "<__array_function__ internals>", line 5, in sum
File "/usr/local/lib/python3.8/dist-packages/numpy/core/fromnumeric.py", line 2247, in sum
return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims,
File "/usr/local/lib/python3.8/dist-packages/numpy/core/fromnumeric.py", line 87, in _wrapreduction
return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
TypeError: only integer scalar arrays can be converted to a scalar index
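The second positional parameter of np.sum is axis, so np.sum(x, y) asks NumPy to reduce x along the "axis" y. Because y is a multi-element array rather than a single integer, NumPy cannot convert it into an axis index and raises this TypeError. A minimal sketch that makes the axis slot explicit (the variable error_message is only for illustration):

```python
import numpy as np

x = np.arange(3)
y = np.array([2, 3, 4])

# np.sum(x, y) is the same call as np.sum(x, axis=y): y lands in the axis slot.
try:
    np.sum(x, axis=y)
except TypeError as err:
    # y has three elements, so it cannot be turned into one integer axis index.
    error_message = str(err)
else:
    error_message = None

print(error_message)
```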
Correct use of np.sum:
In [175]: np.sum(x)
Out[175]: 3
In [177]: np.sum(np.arange(6).reshape(2,3), axis=0)
Out[177]: array([3, 5, 7])
In [178]: np.sum(np.arange(6).reshape(2,3), 0)
Out[178]: array([3, 5, 7])
(Re)read the np.sum docs when necessary!
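As the traceback shows (_wrapreduction(a, np.add, 'sum', ...) followed by ufunc.reduce(obj, axis, ...)), np.sum is a reduction built on np.add. A short sketch comparing the two, assuming a small 2x3 array; note that np.add.reduce defaults to axis=0 while np.sum defaults to summing over all elements:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# The second positional argument of np.sum is axis.
col_sums = np.sum(a, 0)               # same as np.sum(a, axis=0)
col_sums_ufunc = np.add.reduce(a, 0)  # the ufunc reduction that np.sum wraps
row_sums = np.sum(a, axis=1)          # reduce along each row
total = np.sum(a)                     # axis=None: sum over every element

print(col_sums, row_sums, total)
```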
Use np.add instead of np.sum:
In [179]: np.add(x,y)
Out[179]: array([2, 4, 6])
In [180]: x+y
Out[180]: array([2, 4, 6])
The following should be equivalent:
np.divide(np.square(np.subtract(x,y)), np.add(x,y))
(x-y)**2/(x+y)
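Both expressions above still return an array with one value per feature, while the metric callable passed to KNeighborsClassifier should return a single number for each pair of rows. A minimal sketch of a scalar version, assuming the common definition that sums the per-feature terms (some texts also multiply by 0.5) and assuming non-negative features such as histograms. Features where x + y == 0 are treated as contributing 0 to avoid division by zero; that choice is an assumption, not something the question specifies:

```python
import numpy as np

def chi_squared(x, y):
    """Chi-squared distance between two 1-D arrays, returned as a float.

    Assumes non-negative features; terms where x + y == 0 count as 0.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    num = (x - y) ** 2  # elementwise squared difference
    den = x + y         # elementwise sum (np.add), not the np.sum reduction
    # Divide only where the denominator is non-zero; keep 0 elsewhere.
    terms = np.divide(num, den, out=np.zeros_like(num), where=den != 0)
    return float(np.sum(terms))  # reduce the per-feature terms to one scalar

x = np.arange(3)
y = np.array([2, 3, 4])
print(chi_squared(x, y))
```

This function can be passed as metric=chi_squared exactly as in the question's KNeighborsClassifier(algorithm='ball_tree', metric=chi_squared) call.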