首页 > 解决方案 > 为每个用户组使用 for 循环时提高速度

问题描述

假设我们有以下带有输出 window_num 的数据集:

index    user1   date    different_months   org_different_months    window_num
1690289 2670088 2006-08-01  243.0                243.0                  1
1772121 2717874 2005-12-01  0.0                   0.0                   1
1772123 2717874 2005-12-01  0.0                   0.0                   1
1772125 2717874 2005-12-01  0.0                   0.0                   1
1772130 2717874 2005-12-01  0.0                   0.0                   1
1772136 2717874 2006-01-01  0.0                   0.0                   1
1772132 2717874 2006-02-01  0.0                  2099.0                 1
1772134 2717874 2020-08-27  0.0                   0.0                   4
1772117 2717874             0.0                   0.0                   4
1772118 2717874             0.0                   0.0                   4
1772128 2717874 2019-11-01  300.0                300.0                  3
1772127 2717874 2011-11-01  2922.0               2922.0                 2
1774815 2719456 2006-09-01  0.0                   0.0                   2
1774809 2719456 2006-10-01  0.0                  1949.0                 2
1774821 2719456 2020-05-20  0.0                   0.0                   7
1774803 2719456             0.0                   0.0                   7
1774806 2719456             0.0                   0.0                   7
1774819 2719456 2019-08-29  265.0                265.0                  6
1774825 2719456 2014-10-01  384.0                384.0                  4
1774812 2719456 2005-07-01  427.0                427.0                  1
1774816 2719456 2012-02-01  973.0                973.0                  3
1774824 2719456 2015-10-20  1409.0               1409.0                 5

用户编号由 user1 表示。输出是使用 different_months 和 orig_different_months 列生成的 window_num。different_months 列是 date[n] 和 date[n+1] 之间的差异。

以前,我使用 groupby.apply 来输出 window_num,但是当数据集增加时它变得非常慢。通过在整个数据集上使用移位函数来计算 different_months 和 orig_different_months 列,以及对整个数据集应用排序,代码得到了显着改进,如下所示:

  data = data.sort_values(by=['user','ContractInceptionDateClean'], ascending=[True,True])
  #data['user1'] =data['user']
  data['different_months'] = (abs((data['ContractInceptionDateClean'].shift(-1)-data['ContractInceptionDateClean'] ).dt.days)).fillna(0)
  data.different_months[data['different_months'] < 91] =0
  data['shift_different_months']=data['different_months'].shift(1)
  data['org_different_months']=data['different_months']
  
  data.loc[((data['different_months'] == 0) | (data['shift_different_months'] == 0)),'different_months']=0
  data = salesswindow_cal(data,list(data.user.unique()))

我目前正在努力提高速度的代码如下所示:

def salesswindow_cal(data_,users):
    temp = pd.DataFrame()
    for u in range(0,len(users)):
        df=data_[data_['user']==users[u]]      
        df['different_months'].values[0]= df['org_different_months'].values[0]          
        df['window_num']=(df['different_months'].diff() != 0).cumsum()
        temp= pd.concat([df,temp],axis=0)        
    return pd.DataFrame(temp)

标签: python-3.xpandasdataframe

解决方案


经验法则是不要遍历用户并提取df = data_[data_['user']==user]. 而是这样做groupby

for u, df in data_.gropuby('user'):
    do_some_stuff

另一个问题是不要迭代地连接数据

data_out = []
for user, df in data.groupby('user'):
    do_some_stuff
    data_out.append(sub_data)

out = pd.concat(data_out)

在您的情况下,您可以执行一个函数,groupby().apply()并且 pandas 将为您连接数据。

def group_func(df):
    d = df.copy()
    d['different_months'].values[0] = d['org_different_months'].value[0]
    d['window_num'] = (d['different_months'].diff().ne(0).cumsum()

    return d

data.groupby('user').apply(group_func)

更新

让我们试试这种矢量化方法,它可以就地修改您的数据

# update the first `different_months`
mask = ~data['user'].duplicated()
data.loc[mask, 'different_months'] == data.loc[mask, 'orginal_different_months']


groups = data.groupby('user')
data['diff'] = groups['different_months'].diff().ne(0)
data['window_num'] = groups['diff'].cumsum()

    

推荐阅读