首页 > 解决方案 > 简化 numpy argmin 的愚蠢循环

问题描述

如何简化我在这里以 numpy 风格编写的循环?

X     = np.random.random([10,15,20])
Y     = np.random.random([10,15,20,5])
Z     = np.zeros([10,15,5])

min_X = np.argmin(X,axis=2)
for i in range(10):
    for j in range(15):
        Z[i,j,:] = Y[i,j,min_X[i,j],:]

标签: pythonnumpy

解决方案


有 NumPy 内置 -np.take_along_axis为此(需要一些额外的步骤,因为它需要索引数组具有相同数量的暗淡) -

np.take_along_axis(Y,min_X[...,None,None],axis=2)[...,0,:]

推荐阅读