首页 > 解决方案 > 如何执行 numpy 数组的未排序比较?

问题描述

在使用 测试线路检测的输出时python-opencv,我在执行不符合我期望的“未排序”测试时遇到了麻烦。从一开始(为了避免 XY 问题):

给定一个数组 ,a它看起来像:

import numpy as np                                                              

a = np.array([
    # x0, y0, x1, y1                                                           
    [[6, 263, 6, 84]],
    # x0, y0, x1, y1                                                        
    [[0, 92, 181, 4]]                                                           
])

我希望这个比较大约等于target

target = np.array([                                                             
    [[7, 86, 5, 263]],                                                          
    [[1, 91, 182, 4]],                                                          
])

注意每组坐标代表一条线的起点和终点;两者的顺序无关紧要,无论这个顺序如何,我都希望数组测试“相等”。

我目前解决这个问题的方法包括:

  1. 将 4 个点的子数组拆分为 2 个点中的两个:
def reformat(arr):                                                     
    new_arr = []                                                                  
    for row in arr:                                                             
        x0, y0, x1, y1 = row[0]                                                 
        new_arr.append([[x0, y0], [x1, y1]])                                      
    return np.array(new_arr)                                                                                                                             

a = reformat(a)                                                                          
print(a)

这按预期工作并输出:

[[[  6 263]
  [  6  84]]

 [[  0  92]
  [181   4]]]
  1. 对重新格式化的数组进行排序:
print(np.sort(a, axis=1))   

这输出:

[[[  6  84]
  [  6 263]]

 [[  0   4]
  [181  92]]]

而所需的输出是:

[[[  6  84]
  [  6 263]]

 [[  0   92]
  [181  4]]]

即坐标对保持不变,但对每一行进行词法排序。

最终,我将使用以下方法实现我的测试:

np.testing.assert_allclose(                                                     
    reformat_and_sort(a),                                                       
    reformat_and_sort(target),                                                  
    atol=5                                                                      
)

我怎样才能重新格式化和排序atarget这样np.testing.assert_allclose就不会引发AssertionError?

标签: pythonnumpy

解决方案


这个函数应该这样做:

def sort_lines(arr):
    for k, line in enumerate(arr):
        arr[k] = line[np.argsort(line[1,:])[::-1]]
    return arr

所以,

print(sort_lines(reformat(a)))

输出:

[[[  6  84]
  [  6 263]]

 [[  0  92]
  [181   4]]]

推荐阅读