首页 > 解决方案 > 用 tf.gather 或 tf.gather_nd 切片

问题描述

我有一个大小为 [batch_size x actions_space x N_quantiles] 的张量。为了这个例子,假设维度是 2、3 和 4。

x_test = 
 <tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[-0.71722096, -0.36535808, -0.00286232,  0.37722322],
        [ 0.93776643, -1.146626  ,  0.1840729 , -1.427474  ],
        [ 0.47025302, -0.92792755, -0.1490136 ,  1.495174  ]],

       [[-1.3838278 , -0.54772085, -0.14298695,  0.39195213],
        [-0.7986407 ,  0.6419045 , -0.8136323 ,  0.9346474 ],
        [ 0.96690583, -0.82267016, -0.51641494,  0.6930123 ]]],
      dtype=float32)>

对于每个批次,我都有一个动作的索引,我想减去这个动作的分位数。所以我想最终得到一个大小为 [Batch_size x N_Quantiles] = [2 x 4] 的数组。

如果我的动作索引是 [2,0],那么我想以数组结尾:

[[ 0.47025302, -0.92792755, -0.1490136 ,  1.495174  ],
[-1.3838278 , -0.54772085, -0.14298695,  0.39195213 ]].

如何使用 tf.gather 或 tf.gather_nd 解决此问题。这应该很简单,但我真的很难提取正确的数组。我尝试过类似的东西:

tf.gather(x_test, actions, axis=1) 

但没有什么能正常工作

标签: pythontensorflow

解决方案


尝试tf.gather(x_test, actions, batch_dims=1)


推荐阅读