首页 > 解决方案 > 删除循环遍历 numpy 数组的要求

问题描述

概述

下面的代码包含一个 numpy 数组clusters,其中的值使用np.where. 该SoFunc函数返回所有条件所在的行True并将clusters数组作为输入。

问题

我可以遍历这个数组,将每个数组元素与各自的np.where条件进行比较。如何删除循环要求但仍获得相同的输出?

我很欣赏循环虽然 numpy 数组效率低下,但我想改进这段代码。实际的数据集会大得多。

准备可重现的模拟数据

def genMockDataFrame(days,startPrice,colName,startDate,seed=None): 

    periods = days*24
    np.random.seed(seed)
    steps = np.random.normal(loc=0, scale=0.0018, size=periods)
    steps[0]=0
    P = startPrice+np.cumsum(steps)
    P = [round(i,4) for i in P]

    fxDF = pd.DataFrame({ 
        'ticker':np.repeat( [colName], periods ),
        'date':np.tile( pd.date_range(startDate, periods=periods, freq='H'), 1 ),
        'price':(P)})
    fxDF.index = pd.to_datetime(fxDF.date)
    fxDF = fxDF.price.resample('D').ohlc()
    fxDF.columns = [i.title() for i in fxDF.columns]
    return fxDF


def SoFunc(clust):
    #generate mock data
    df = genMockDataFrame(10,1.1904,'eurusd','19/3/2020',seed=157)
    df["Upper_Band"] = 1.1928
    df.loc["2020-03-27", "Upper_Band"] = 1.2118
    df.loc["2020-03-26", "Upper_Band"] = 1.2200
    df["Level"] = np.where((df["High"] >= clust)
                                      & (df["Low"] <= clust)
                                     & (df["High"] >= df["Upper_Band"] ),1,np.NaN
                                      )
    return df.dropna()

循环遍历 clusters 数组

clusters = np.array([1.1929   , 1.2118 ])

l = []

for i in range(len(clusters)):
    l.append(SoFunc(clusters[i]))
    
pd.concat(l)

输出

              Open  High    Low    Close    Upper_Band  Level
date                        
2020-03-19  1.1904  1.1937  1.1832  1.1832  1.1928      1.0
2020-03-25  1.1939  1.1939  1.1864  1.1936  1.1928      1.0
2020-03-27  1.2118  1.2144  1.2039  1.2089  1.2118      1.0

标签: pandasnumpy

解决方案


(根据@tdy 下面的评论编辑)

pandas.merge允许您len(clusters)复制数据框,然后根据函数中的条件将其削减SoFunc

df交叉合并为 中的每条记录创建一个带有副本的数据框clusters_df。大数据帧的总体结果应该比基于循环的方法更快,前提是您有足够的内存来临时容纳合并的数据帧(如果没有,操作可能会溢出到页面/交换并大幅减慢)。

import numpy as np
import pandas as pd

def genMockDataFrame(days,startPrice,colName,startDate,seed=None): 
    ''' identical to the example provided '''

    periods = days*24
    np.random.seed(seed)
    steps = np.random.normal(loc=0, scale=0.0018, size=periods)
    steps[0]=0
    P = startPrice+np.cumsum(steps)
    P = [round(i,4) for i in P]

    fxDF = pd.DataFrame({ 
        'ticker':np.repeat( [colName], periods ),
        'date':np.tile( pd.date_range(startDate, periods=periods, freq='H'), 1 ),
        'price':(P)})
    fxDF.index = pd.to_datetime(fxDF.date)
    fxDF = fxDF.price.resample('D').ohlc()
    fxDF.columns = [i.title() for i in fxDF.columns]
    return fxDF
    
# create the base dataframe according to the former SoFunc
df = genMockDataFrame(10,1.1904,'eurusd','19/3/2020',seed=157)
df["Upper_Band"] = 1.1928
df.loc["2020-03-27"]["Upper_Band"] = 1.2118
df.loc["2020-03-26"]["Upper_Band"] = 1.2200

# create a df out of the cluster array
clusters = np.array([1.1929   , 1.2118 ])
clusters_df = pd.DataFrame({"clust": clusters})

# perform the merge, then filter and finally clean up
result_df = (
    pd
    .merge(df.reset_index(), clusters_df, how="cross") # for each entry in cluster, make a copy of df
    .loc[lambda z: (z.Low <= z.clust) & (z.High >= z.clust) & (z.High >= z.Upper_Band), :] # filter the copies down
    .drop(columns=["clust"]) # not needed in result
    .assign(Level=1.0) # to match your result; not really needed
    .set_index("date") # bring back the old index
)

print(result_df)

我建议只检查结果,pd.merge(df.reset_index(), clusters_df, how="cross")看看它是如何工作的。


推荐阅读