首页 > 解决方案 > 如何用给定的索引索引numpy数组?

问题描述

问题:如何用给定的索引索引 numpy 数组?

说明

在强化学习中,我得到了许多对应于不同状态的离散分布,如下所示:

import numpy as np
distributions = np.array([[0.1,0.2,0.7],[0.3,0.3,0.4],[0.2,0.2,0.6]])

# array([[0.1, 0.2, 0.7],  # \pi(s0)
#        [0.3, 0.3, 0.4],  # \pi(s1)
#        [0.2, 0.2, 0.6]]) # \pi(s2)

然后,我想分别获得在 state 中采取行动 0、在 states0中采取行动 2 和在 states1中采取行动 1的概率s2

所以我将索引值存储在一个数组中,如下所示:

actions = np.array([[0],[2],[1]])

# array([[0],  # taking action 0 in state s0
#        [2],  # taking action 2 in state s1
#        [1]]) # taking action 1 in state s2

我期望得到的。

我想使用 索引distributionsactions并期望得到如下结果:

# array([0.1,0.4,0.2])
# or 
# array([[0.1],
#        [0.4],
#        [0.2]])

我试过了。

我试过np.take(distributions, actions)了,但retunarray([0.1, 0.7, 0.2])显然是我想要的。我也试过distributions[:,actions]了,这给了我另一个错误的答案,如下所示:

array([[0.1, 0.7, 0.2],
       [0.3, 0.4, 0.3],
       [0.2, 0.6, 0.2]])         

问题

我能做些什么来解决这个问题?

标签: pythonnumpy

解决方案


In [614]: distributions = np.array([[0.1,0.2,0.7],[0.3,0.3,0.4],[0.2,0.2,0.6]]) 
     ...:                                                                       
In [615]: actions = np.array([[0],[2],[1]])  

使用 [0,1,2] 行索引:

In [616]: distributions[np.arange(3), actions]                                  
Out[616]: 
array([[0.1, 0.3, 0.2],
       [0.7, 0.4, 0.6],
       [0.2, 0.3, 0.2]])

哎呀,actions是 (3,1) 形状,它与 (3,) 一起广播以产生 (3,3) 选择。相反,我们想使用 (3,) 形状actions

In [617]: distributions[np.arange(3), actions.ravel()]                          
Out[617]: array([0.1, 0.4, 0.2])

或得到 (3,1) 结果:

In [619]: distributions[[[0],[1],[2]], actions]                                 
Out[619]: 
array([[0.1],
       [0.4],
       [0.2]])

推荐阅读