首页 > 解决方案 > 如何检查 numpy 数组列表是否包含给定的测试数组?

问题描述

我有一个numpy数组列表,比如说,

a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)]

我有一个测试数组,比如说

b = np.random.rand(3, 3)

我想检查是否a包含b。然而

b in a 

引发以下错误:

ValueError:具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()

我想要的正确方法是什么?

标签: pythonnumpy

解决方案


您可以只制作一组形状(3, 3, 3)a

a = np.asarray(a)

然后将其与b(我们在这里比较浮点数,所以我们应该使用isclose()

np.all(np.isclose(a, b), axis=(1, 2))

例如:

a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)]
a = np.asarray(a)
b = a[1, ...]       # set b to some value we know will yield True

np.all(np.isclose(a, b), axis=(1, 2))
# array([False,  True, False])

推荐阅读