python - 使用 numpy 创建具有任意形状的单位矩阵
问题描述
是否有更快/内置的方法来生成第一维中具有任意形状的单位矩阵和最后一m
维中的单位?
import numpy as np
base_shape = (10, 11, 12)
n_dim = 4
# m = 2
frames2d = np.zeros(base_shape + (n_dim, n_dim))
for i in range(n_dim):
frames2d[..., i, i] = 1
# m = 3
frames3d = np.zeros(base_shape + (n_dim, n_dim, n_dim))
for i in range(n_dim):
frames3d[..., i, i, i] = 1
解决方案
方法#1
我们可以利用np.einsum
受启发的对角线视图this post
,从而将其分配1s
给我们想要的输出。因此,对于这种m=3
情况,在用零初始化之后,我们可以简单地做 -
diag_view = np.einsum('...iii->...i',frames3d)
diag_view[:] = 1
概括包括这些输入参数,它将是 -
def ndeye_einsum(base_shape, n_dim, m):
out = np.zeros(list(base_shape) + [n_dim]*m)
diag_view = np.einsum('...'+'i'*m+'->...i',out)
diag_view[:] = 1
return out
因此,要重现这些相同的数组,它将是 -
frames2d = ndeye_einsum(base_shape, n_dim, m=2)
frames3d = ndeye_einsum(base_shape, n_dim, m=3)
方法#2
同样,从同一个链接的帖子中,我们也可以重塑为 2D 并沿着 cols 分配到步进大小的切片数组,就像这样 -
def ndeye_reshape(base_shape, n_dim, m):
N = (n_dim**np.arange(m)).sum()
out = np.zeros(list(base_shape) + [n_dim]*m)
out.reshape(-1,n_dim**m)[:,::N] = 1
return out
这再次适用于视图,因此应该与方法#1同样有效。
方法#3
另一种方法是使用基于整数的索引。因此,例如,对于frames3d
一次性分配,它将是 -
I = np.arange(n_dim)
frames3d[..., I, I, I] = 1
概括成为 -
def ndeye_ellipsis_indexer(base_shape, n_dim, m):
I = np.arange(n_dim)
indexer = tuple([Ellipsis]+[I]*m)
out = np.zeros(list(base_shape) + [n_dim]*m)
out[indexer] = 1
return out
扩展到更高的视野
沿 base_shape 的暗淡基本上是来自最后暗淡的元素的复制m
。因此,我们可以使用np.broadcast_to
. 我们将创建一个基本的 m-dim 标识数组,然后将视图广播到更高的维度。这将适用于之前发布的所有三种方法。为了演示如何在einsum
基于解决方案上使用它,我们将拥有 -
# Create m-dim "trailing-base" array, basically a m-dim identity array
def ndeye_einsum_trailingbase(n_dim, m):
out = np.zeros([n_dim]*m)
diag_view = np.einsum('i'*m+'->...i',out)
diag_view[:] = 1
return out
def ndeye_einsum_view(base_shape, n_dim, m):
trail_base = ndeye_einsum_trailingbase(n_dim, m)
return np.broadcast_to(trail_base, list(base_shape) + [n_dim]*m)
因此,我们将再次拥有,例如 -
frames3d = ndeye_einsum_view(base_shape, n_dim, m=3)
这将是一个 m-dim 数组的视图,因此在内存和性能上都很有效。
推荐阅读
- python - 每当尝试将 write_videofile 运行到moviepy中的剪辑时,都会出现“TypeError:必须是实数,而不是 NoneType”
- docker - 我的 package.json bundleDependencies 会自动安装到 docker 镜像中的最新版本
- javascript - 获取 div id 作为数组?
- java - 有没有办法通过 API 或 DB 区分 SYSML 序列图中的 Lifeline 和 Other 元素?
- reactjs - `无效的钩子调用反应`
- java - ModelCategory 模型 = dataSnapshot.getValue(ModelCategory.class); 应用获取 Crush
- python - 使用 asyncio/aiohttp 进行多个 Websocket 流式传输
- java - 当我在本地浏览器上请求时,如何将数据库作为 json 值访问?
- r - 不同 bin 分组的直方图
- node.js - NGINX 和 Vuejs 提供静态文件