首页 > 解决方案 > 使用矩阵的张量索引

问题描述

我有矩阵(3 x 15)dummies,其中标记序列为行:

[[ 1 66 67 68  0  0  0  0  0  0  0  0  0  0  0]
[ 1 66 67 66 68 66 67 66  0  0  0  0  0  0  0]
[ 1 66 67 68 18 19 20 21 22 23 24 25 26 17  0]]

此外,还有一个probs形状张量 (3 x 15 x n_tokens) 和令牌概率。

probs我只需要选择dummies.

我认为,可以将矩阵用作张量的索引,但我还没有找到如何做到这一点。

标签: pythontensorflow

解决方案


你可以这样做:

import tensorflow as tf

dummies = ...
probs = ...
s = tf.shape(dummies)
i = tf.range(s[0])
j = tf.range(s[1])
ii, jj = tf.meshgrid(i, j, indexing='ij')
idx = tf.stack([ii, jj, dummies], axis=-1)
result = tf.gather_nd(probs, idx)

推荐阅读