首页 > 解决方案 > 使用 np.where() 选择数组的有效方法?

问题描述

我有三个数组:mgrad1grad2m是 形状(x,)grad1grad2是 形状(x,y,z)。我试图找出最有效的方法来创建一个新数组,其中的值grad1条目grad2m. 我尝试使用以下代码执行此操作:

param0_grad = np.where(m[:] > 0, grad1, grad2)

根据我对 的理解,np.where()我认为这应该填充param0_grad或基于. 但是,我收到以下广播错误(当 x=3、y=4、z=2 时):grad1grad2m

ValueError: operands could not be broadcast together with shapes (3,) (3,4,2) (3,4,2)

该代码适用于 x=2,但没有 x>2 的值。

标签: pythonarraysarray-broadcasting

解决方案


尝试这个:

param0_grad = np.where(m[:,None,None] > 0, grad1, grad2)

基本上,您需要添加一个空维度来广播尾轴。


推荐阅读