python - Numpy:将行值广播到频道
问题描述
我有一个数据集,其中前 48 个观察值是时间序列,其他 12 个是静态变量:
h1 h2 h3 h4 ... h48 v1 v2 v3 v4 v5 v6 .. vn
h1 h2 h3 h4 ... h48 v1 v2 v3 v4 v5 v6 .. vn
一件物品的形状是(367, 60)
。
我想将变量v1 v2 v3 v4 v5 v6 .. vn
作为附加通道传递给时间序列,即创建形状数组(367, 48, 13)
。我想即时进行,因为完全转换的数据集不适合我的 RAM。
我现在使用的代码效率很低(items
是批处理):
def preprocessor(items):
items_new = np.zeros(shape=(items.shape[0], 367, 48, 13), dtype=np.float32)
for idx_item, item in enumerate(items):
train_data = item[:,:48]
train_vars = item[:,48:]
train_new = np.zeros((train_data.shape[0], train_data.shape[1],(train_vars.shape[1]+1)))
for idx_row, row in enumerate(train_data):
for idx_col, elem in enumerate(row):
train_new[idx_row, idx_col, :] = np.concatenate([[elem], train_vars[idx_row]])
items_new[idx_item] = train_new
return items_new
我可以在没有循环的情况下更快地完成它吗?
编辑:
最小的可重现示例:
arr = np.random.randn(5,367,60)
arr2 = preprocessor(arr)
print(arr2.shape) # (5, 367, 48, 13)
解决方案
方法#1
我们可以将广播数组分配用于矢量化解决方案 -
def array_assign(items):
L = 48 # slice at this column ID
N = items.shape[-1]
out = np.empty(shape= items.shape[:2] + (L,N-L+1), dtype=np.float32)
out[...,1:] = items[...,None,L:]
out[...,0] = items[...,:L]
return out
方法#2
我们还可以使用广播视图,然后连接 -
def broadcast_concat(items):
L = 48 # slice at this column ID
N = items.shape[-1]
a = items[...,:L,None]
shp_b = items.shape[:2] + (L,N-L)
b = np.broadcast_to(items[...,None,L:],shp_b)
out = np.concatenate((a,b),axis=-1)
return out
计时 -
In [321]: items = np.random.rand(5,367,60)
In [322]: %timeit array_assign(items)
1000 loops, best of 3: 923 µs per loop
In [323]: %timeit broadcast_concat(items)
1000 loops, best of 3: 781 µs per loop
为了公平比较,我们应该让第二种方法也使用更有效的float32
dtype。让我们使用该 dtype 来设置输入数据并再次测试 -
In [335]: items = np.random.rand(5,367,60).astype(np.float32)
In [336]: %timeit array_assign(items)
1000 loops, best of 3: 897 µs per loop
In [337]: %timeit broadcast_concat(items)
1000 loops, best of 3: 348 µs per loop
因此,对于需要 dtype 转换的情况下的大多数性能,我们可以items = np.asarray(items, dtype=np.float32)
在方法 #2 开始时使用。
推荐阅读
- linux - 构建 docker:未找到 opt/conda/bin/conda
- javascript - 如何将我的值附加到谷歌电子表格 OAUTH/GAPI js
- amazon-web-services - AWS 中的密码格式的秘密
- python - KivyMD - MDRectangleFlatButton - 按钮阴影动画太慢
- css - 如何在不使用绝对定位的情况下使用内联伪元素将元素包裹在内联元素周围
- reactjs - 反应属性映射问题
- node.js - Azure 管道中的数据库迁移
- sql-server - 如何仅检索 SQL Server 数据库模型的用户定义属性
- angular - Angular 指令从 NgModel 或 FormControlName 获取值
- php - 如何获取数组的每个元素