python - 2D numpy 数组组上最快的应用函数
问题描述
假设我有一个二维数据矩阵,我想将一个函数应用于该矩阵中的组。
例如:
对于每个唯一索引,我想应用一些功能f
。
例如,对于具有index = 1
函数的组,将f
应用于值0.556, 0.492, 0.148
(见第一列),对于组index = 2
,将函数应用于值0.043
。
此外:
- 该函数必须将结果广播到输入数据的原始大小。
- 这些组每列都是唯一的。您可以在上面的示例中看到这一点,其中每个组仅包含同一列中的值。
那么在 Python 中执行此操作的绝对最快的方法是什么?
我目前正在执行以下操作(随机数据 [2000x500] 和每列 5 个随机组):
import numpy as np
rows = 2000
cols = 500
ngroup = 5
data = np.random.rand(rows,cols)
groups = np.random.randint(ngroup, size=(rows,cols)) + 10*np.tile(np.arange(cols),(rows,1))
result = np.zeros(data.shape) # Pre-allocating the result
f = lambda x: (x-np.average(x))/np.std(x) # The function I want to apply
for group in np.unique(groups): # Loop over every unique group
location = np.where(groups == group) # Find the location of the data
group_data = data[location[0],location[1]] # Get the data
result[location[0],location[1]] = f(group_data) # Apply the function
使用我的硬件,这个计算大约需要 10 秒才能完成。有没有更快的方法来做到这一点?
解决方案
不确定是否是最快的,但这个矢量化解决方案要快得多:
import numpy as np
import time
np.random.seed(0)
rows = 2000
cols = 500
ngroup = 5
data = np.random.rand(rows,cols)
groups = np.random.randint(ngroup, size=(rows,cols)) + 10*np.tile(np.arange(cols),(rows,1))
t = time.perf_counter()
# Flatten the data
dataf = data.ravel()
groupsf = groups.ravel()
# Sort by group
idx_sort = groupsf.argsort()
datafs = dataf[idx_sort]
groupsfs = groupsf[idx_sort]
# Find group bounds
idx = np.nonzero(groupsfs[1:] > groupsfs[:-1])[0]
idx = np.concatenate([[0], idx + 1, [len(datafs)]])
# Sum by groups
a = np.add.reduceat(datafs, idx[:-1])
# Count group elements
c = np.diff(idx)
# Compute group means
m = a / c
# Repeat means and counts to match data shape
means = np.repeat(m, c)
counts = np.repeat(c, c)
# Compute variance and std
v = np.add.reduceat(np.square(datafs - means), idx[:-1]) / c
s = np.sqrt(v)
# Repeat stds
stds = np.repeat(s, c)
# Compute result values
resultfs = (datafs - means) / stds
# Undo sorting
idx_unsort = np.empty_like(idx_sort)
idx_unsort[idx_sort] = np.arange(len(idx_sort))
resultf = resultfs[idx_unsort]
# Reshape back
result = np.reshape(resultf, data.shape)
print(time.perf_counter() - t)
# 0.09932469999999999
# Previous method to check result
t = time.perf_counter()
result_orig= np.zeros(data.shape)
f = lambda x: (x-np.average(x))/np.std(x)
for group in np.unique(groups):
location = np.where(groups == group)
group_data = data[location[0],location[1]]
result_orig[location[0],location[1]] = f(group_data)
print(time.perf_counter() - t)
# 6.0592527
print(np.allclose(result, result_orig))
# True
编辑:要计算中位数,您可以执行以下操作:
# Flatten the data
dataf = data.ravel()
groupsf = groups.ravel()
# Sort by group and value
idx_sort = np.lexsort((dataf, groupsf))
datafs = dataf[idx_sort]
groupsfs = groupsf[idx_sort]
# Find group bounds
idx = np.nonzero(groupsfs[1:] > groupsfs[:-1])[0]
idx = np.concatenate([[0], idx + 1, [len(datafs)]])
# Count group elements
c = np.diff(idx)
# Meadian index
idx_median1 = c // 2
idx_median2 = idx_median1 + (c % 2) - 1
idx_median1 += idx[:-1]
idx_median2 += idx[:-1]
# Get medians
meds = 0.5 * (datafs[idx_median1] + datafs[idx_median2])
这里的技巧是使用np.lexsort
而不是仅仅np.argsort
按组和值排序。meds
将是一个包含每个组的中位数的数组,然后您可以使用np.repeat
它,就像使用手段一样,或者您想要的任何其他东西。
推荐阅读
- xml - 无法打开我用 protégé 手动实现的本体
- django-rest-framework - 如何将 Django REST Framework 路由器中的 url 参数限制为整数?
- windows - 一段时间后,pg_notify 不会向在 docker 容器中运行的应用程序发送通知
- node.js - 使用等待后哈希密码不起作用
- ontology - 如何将数据(实例)导入到 protoge 中的现有本体中
- java - 如何在 Spring Boot 应用程序启动期间将 SynchronizationCallbacks 添加到 @TransactionalEventListener?
- reactjs - 使用 useEffect 时,我在传递小于变量依赖项时收到警告
- math - 在 Maple 中,如何将数组(列表或向量或矩阵)的数量与行数和列数相等的数组的数量相除?
- javascript - firebase firestore 离线持久化 FirebaseError
- maven - 无法在 Maven 部署 github 操作上重命名战争或 jar 文件