首页 > 解决方案 > 具有多个元素的数组的真值不明确?

问题描述

我遇到了无法对数组进行排序的问题。我收到此错误The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()为什么会发生这种情况?我不明白。是不是因为破局?我一直在这个问题上一段时间,无法弄清楚。

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn import datasets, metrics, svm
from sklearn.model_selection import train_test_split
from collections import Counter
from math import sqrt

#import number data
digits = datasets.load_digits()
images_and_labels = list(zip(digits.images, digits.target))
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
data_train, data_test, label_train, label_test = train_test_split(data, digits.target, test_size=0.2)

def euclidean_distance(first, second):
    distance = 0.0
    for i in range(64):
        distance += (first[i] - second[i])**2
    return np.sqrt(distance)

def get_neighbors(train_set, test_set, num_neighbors):
    distances = list()
    for test_set in train_set:
        dist = euclidean_distance(test_set, train_set)
        distances.append((train_set, dist))
    np.sort(distances)
    neighbors = list()
    for i in range(num_neighbors):
        neighbors.append(distances[i][0])
    return neighbors

results = get_neighbors(data_train, data_test, 100 )

标签: pythonnumpymachine-learning

解决方案


在这一行:

np.sort(distances)

distances 是一个元组列表——每个元组都包含一对 numpy 数组。例如:

>>> distances[0]
(array([[ 0.,  0.,  6., ...,  0.,  0.,  0.],
        [ 0.,  3., 13., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ..., 16.,  7.,  0.],
        ...,
        [ 0.,  0.,  1., ..., 16.,  5.,  0.],
        [ 0.,  1., 13., ...,  1.,  0.,  0.],
        [ 0.,  0.,  1., ..., 14.,  6.,  0.]]),
 array([54.40588203, 51.7107339 , 58.72818744, 83.80930736, 77.37570678,
        58.25804665, 54.18486874, 54.47935389, 54.40588203, 52.13444159,
        73.54590403, 87.35559513, 79.01898506, 66.55824517, 54.61684722,
        54.56189146, 54.40588203, 50.55689864, 78.65748534, 74.60562981,
        74.37741593, 72.70488292, 55.08175742, 54.41507144, 54.40588203,
        49.43682838, 76.51143705, 70.9577339 , 75.94076639, 65.90902821,
        56.90342696, 54.40588203, 54.40588203, 55.06359959, 73.06161783,
        73.71566998, 82.50454533, 70.        , 54.        , 54.40588203,
        54.40588203, 53.10367219, 75.16648189, 78.7273777 , 74.06753675,
        67.00746227, 55.6596802 , 54.40588203, 54.40588203, 52.5832673 ,
        68.82586723, 80.74651695, 74.24957912, 72.9725976 , 56.59505279,
        53.8702144 , 54.40588203, 52.06726419, 61.10646447, 83.5463943 ,
        84.92938243, 61.83041323, 53.38539126, 54.2678542 ]))

错误是因为np.sort不知道如何处理。


推荐阅读