首页 > 解决方案 > Numpy中没有for循环的两个数组之间的元素比较

问题描述

我有一个名为dataset的大数组,尺寸为 Numpy(700、28、28、3)。假设这个矩阵如下所示:

>>> dataset=np.random.rand(5600,28,28,3)
>>> dataset.shape
(5600, 28, 28, 3)

现在,假设我有另一个更简单的数组,称为查询,我将使用它在数据集数组中搜索

>>> query=np.random.rand(28,28,3)
>>> query.shape
(28, 28, 3)

在较大的矩阵查询中搜索该矩阵查询的一种方法是计算它与数组dataset的所有元素之间的均方误差。较小的 MSE 告诉我矩阵在数组数据集中的位置。

问题是,我不想在 Python 中创建一个 for 循环来逐个计算 MSE,将 MSE 存储在另一个数组中,然后在循环结束时获取最小 MSE 的位置。在此比较之前,我已经有两个 for 循环,因此,我希望使其尽可能高效和快速。没有大的for循环是否可以解决这样的问题?

标签: pythonarraysnumpy

解决方案


你可以这样做:

se = (dataset-query)**2                            # Squared error - shape (L,28,28,3)
sum_of_se = np.sum(se.reshape(-1,28*28*3), axis=1) # Sum of squared error - shape (L,)
print (np.argmin(sum_of_se))                       # Position of minimum within sum_of_se

推荐阅读