首页 > 解决方案 > np.where 中的 numpy 广播

问题描述

我的问题是,如何在np.where使用多个条件/输出时广播值而不必依赖乘法?

输入:

import pandas as pd
df = pd.DataFrame({'test':range(0,10)})

   test
0     0
1     1
2     2
3     3
4     4
5     5
6     6
7     7
8     8
9     9

预期输出:

   test  column1  column2
0     0        2        4
1     1        2        4
2     2        2        4
3     3        2        4
4     4        1        3
5     5        1        3
6     6        1        3
7     7        1        3
8     8        1        3
9     9        1        3

我的(工作)代码:

mask  = df['test'] > 3
m_len = len(mask)

df['column1'], df['column2'] = np.where([mask, mask], [[1]*m_len, [3]*m_len], [[2]*m_len, [4]*m_len])

问题:

通常np.where()接受一个数组和一个静态值,例如:

np.where(mask, 1, 2) # where mask is a series

如果我现在使用它,我的期望是:

np.where([mask, mask], [1, 3], [2, 4])

它会广播这个值。

但我收到以下错误:

ValueError: operands could not be broadcast together with shapes (2,10) (2,) (2,) 

有没有办法广播值而不必使用m_len变量(如我的工作代码所示)?

注意:我知道我可以np.where在多行中多次使用,但我想用那一行来解决它。

标签: pythonnumpy

解决方案


如果您将输入的值的形状设置为(2, 1),它将广播。因此,这是一种方法np.r_

df[["col1", "col2"]] = np.where(mask, np.r_["c", 1, 3], np.r_["c", 2, 4]).T

最后一个T是需要的,因为np.where将返回-(2, -1)形数组,但 pandas 期望(-1, 2)它的两列。


如果两个掩码相同,我们也可以只给出一个mask,因为它也会广播它:

mask   ->  (10,)
values ->  (2, 1)

然后

mask'  ->  (1, 10)
values ->  (2, 1)

最后

mask''  ->  (2, 10)
values' ->  (2, 10)

推荐阅读