python - 使用取决于先前值的操作矢量化 numpy 代码
问题描述
以下代码模拟了一个可以随时采样 3 种不同状态的系统,这些状态之间的恒定转移概率由矩阵 给出prob_nor
。因此,每个点都trace
取决于先前的状态。
n_states, n_frames = 3, 1000
state_val = np.linspace(0, 1, n_states)
prob = np.random.randint(1, 10, size=(n_states,)*2)
prob[np.diag_indices(n_states)] += 50
prob_nor = prob/prob.sum(1)[:,None] # transition probability matrix,
# row sum normalized to 1.0
state_idx = range(n_states) # states is a list of integers 0, 1, 2...
current_state = np.random.choice(state_idx)
trace = []
sigma = 0.1
for _ in range(n_frames):
trace.append(np.random.normal(loc=state_val[current_state], scale=sigma))
current_state = np.random.choice(state_idx, p=prob_nor[current_state, :])
上面代码中的循环使它运行得很慢,特别是当我必须对数百万个数据点进行建模时。有没有办法矢量化/加速它?
解决方案
尽快卸载概率计算:
possible_paths = np.vstack(
np.random.choice(state_idx, p=prob_nor[curr_state, :], size=n_frames)
for curr_state in range(n_states)
)
然后,您可以简单地进行查找以遵循您的路径:
path_trace = [None]*n_frames
for step in range(n_frames):
path_trace[step] = possible_paths[current_state, step]
current_state = possible_paths[current_state, step]
一旦你有了你的路径,你就可以计算你的踪迹:
sigma = 0.1
trace = np.random.normal(loc=state_val[path_trace], scale=sigma, size=n_frames)
比较时间:
纯pythonfor
循环
%%timeit
trace_list = []
current_state = np.random.choice(state_idx)
for _ in range(n_frames):
trace_list.append(np.random.normal(loc=state_val[current_state], scale=sigma))
current_state = np.random.choice(state_idx, p=prob_nor[current_state, :])
结果:
30.1 ms ± 436 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
矢量化查找:
%%timeit
current_state = np.random.choice(state_idx)
path_trace = [None]*n_frames
possible_paths = np.vstack(
np.random.choice(state_idx, p=prob_nor[curr_state, :], size=n_frames)
for curr_state in range(n_states)
)
for step in range(n_frames):
path_trace[step] = possible_paths[current_state, step]
current_state = possible_paths[current_state, step]
trace = np.random.normal(loc=state_val[path_trace], scale=sigma, size=n_frames)
结果:
641 µs ± 6.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
加速约 50 倍。
推荐阅读
- javascript - 在数组上的映射中返回许多 .appendChild()
- ios - SwiftUI:ObservedObject 更新后的动画列表
- python-multiprocessing - ubuntu和centOS之间的Python多处理
- javascript - 在 Yarn 脚本中使用参数或环境变量?
- refactoring - 清理重复的打字稿代码(对象属性是否存在)
- c# - Http无法向服务器发送消息
- php - 将多个 POST 数据从 Android 发送到 API 级别 > 22 的 PHP 服务器?
- gradle - @Before 黄瓜套间
- java - 更新 @ManyToOne 关系中的拥有记录后,Hibernate 删除引用记录
- php - Laravel API 中的验证失败