首页 > 解决方案 > 如何在熊猫中使用多列向后应用滚动功能?

问题描述

鉴于这个简单的数据框:

df = pd.DataFrame(np.random.randint(0,100, size=(50, 4)), columns=list('ABCD'))

我正在尝试执行以下计算:

  1. 添加三个名为 B1、C2 和 D2 的列,默认填充为NaN
  2. 逐一检查A列接下来的5行,第一个大于20的,然后B1、C2和D2列将填充该特定行的B、C和D列的内容。
  3. 如果 A 列接下来的 5 行都不小于 20,则 B1、C2 和 D2 列将保持为NaN

我想出了这种方法:

def check_thresh(ser):
    dft = df.loc[ser.index]
    
    for _, row in dft.iterrows():
        if row['A'] > 20:
            return np.array([row['B'], row['C'], row['D']])
        
    return np.array([np.nan, np.nan, np.nan])

rol = df['A'].rolling(window=5)
df[['B1', 'C1', 'D1']] = rol.apply(check_thresh, raw=False)

但是,我面临以下问题:

  1. 它检查5 行,而不是接下来的5 行。
  2. 性能很慢,我必须使用大型数据集。
  3. 它返回以下错误:TypeError: only size-1 arrays can be converted to Python scalars将滚动函数应用于新列时。

我的方法有什么问题?你知道有更好的方法来处理这种情况吗?

标签: pythonpandasdataframenumpyrolling-computation

解决方案


我不确定此实现是否经过优化,或者是否正确,因为我没有完全理解这个问题并且没有预期输出的示例。

from numpy.lib.stride_tricks import sliding_window_view

WINDOWSIZE = 5
THRESHOLD = 20

# Equivalent to pd.rolling
m = sliding_window_view(df, (WINDOWSIZE, len(df.columns))).squeeze().astype(float)

# Extract 'A' column
A = m[:, :, 0]

# Get the first index whose value > THRESHOLD
argm = np.argmax(A > THRESHOLD, axis=1)

# True if all values <= THRESHOLD
amin = np.amin(A <= THRESHOLD, axis=1)

# Select rows in original array m
r = np.take_along_axis(m, argm[:, np.newaxis, np.newaxis], axis=1).squeeze()
r[amin] = np.nan

例子:

>>> df
    A   B   C   D
0   0   1   2   3
1   4   5   6   7
2   8   9  10  11
3  12  13  14  15
4  16  17  18  19
5  20  21  22  23
6  24  25  26  27
7  28  29  30  31
8  32  33  34  35
9  36  37  38  39
# df1 = pd.DataFrame(A).rename(columns='A{}'.format).assign(argm=argm, amin=amin)
# df2 = pd.DataFrame(r, columns=['A', 'B1', 'C1', 'D1'])

>>> pd.concat([df1, df2], axis='columns')
     A0    A1    A2    A3    A4  argm   amin     A    B1    C1    D1
0   0.0   4.0   8.0  12.0  16.0     0   True   NaN   NaN   NaN   NaN
1   4.0   8.0  12.0  16.0  20.0     0   True   NaN   NaN   NaN   NaN
2   8.0  12.0  16.0  20.0  24.0     4  False  24.0  25.0  26.0  27.0
3  12.0  16.0  20.0  24.0  28.0     3  False  24.0  25.0  26.0  27.0
4  16.0  20.0  24.0  28.0  32.0     2  False  24.0  25.0  26.0  27.0
5  20.0  24.0  28.0  32.0  36.0     1  False  24.0  25.0  26.0  27.0        

注意:最终数据帧的长度为len(df) - WINDOWSIZE + 1


推荐阅读