首页 > 解决方案 > 根据一维条件删除 ndarray 中的元素

问题描述

在 Numpy ndarray 中,如何根据不同维度中的条件删除维度中的元素?

我有:

[[[1 3]
  [1 4]]

 [[2 6]
  [2 8]]

 [[3 5]
  [3 5]]]

我想根据条件删除x[:,:,1] < 7

所需的输出([:,1,:]已删除):

[[[1 3]
  [1 4]]

 [[3 5]
  [3 5]]]

编辑:修正错字

标签: pythonnumpynumpy-ndarray

解决方案


这可能有效:

x[np.where(np.all(x[..., 1] < 7, axis=1)), ...]

产量

array([[[[1, 3],
         [1, 4]],

        [[3, 5],
         [3, 5]]]])

你确实得到了一个额外的维度,但这很容易删除:

np.squeeze(x[np.where(np.all(x[..., 1] < 7, axis=1)), ...])

简要介绍它的工作原理:

首先条件:x[..., 1] < 7
然后测试条件是否对沿特定轴的所有元素都有效:np.all(x[..., 1] < 7, axis=1).
然后,使用where获取索引而不是布尔数组:np.where(np.all(x[..., 1] < 7, axis=1))
并将这些索引插入相关维度:x[np.where(np.all(x[..., 1] < 7, axis=1)), ...].


推荐阅读