首页 > 解决方案 > 如何索引具有形状 (batch_size, 200, 256) 的张量以获得 (batch_size, 1, 256) 给定长度 = batch_size 的索引张量列表?

问题描述

我有形状为 (batch_size, 200, 256) 的 LSTM 层的输出,其中 200 是标记序列的长度,256 是 LSTM 输出维度。我还有另一个形状为 (batch_size) 的张量,它是我想从批次中的每个样本序列中切出的标记的索引列表。

如果令牌索引不是 -1,我将切出一个令牌向量表示(长度 = 256)。如果令牌索引为 -1,我将给出零向量(长度 = 256)。

预期的输出结果具有形状 (batch_size, 1, 256)。我该怎么做?

谢谢

这是我到目前为止尝试过的

bidir = concatenate([forward, backward]) # shape = (batch_size, 200, 256) 
dropout = Dropout(params['dropout_rate'])(bidir)
def slice_by_tensor(x):
    matrix_to_slice = x[0]
    index_tensor = x[1]


    out_tensor = tf.where(index_tensor == -1, 
                          tf.zeros(tf.shape(tf.gather(matrix_to_slice, 
                                                      index_tensor, axis=1))), 
                          tf.gather(matrix_to_slice, index_tensor, axis=1))



    return out_tensor


representation_stack0 = Lambda(lambda x: slice_by_tensor(x))([dropout,stack_idx0]) 
# stack_idx0 shape is (batch_size) 
# I got output with shape (batch_size, batch_size, 256) with this code

标签: pythonpython-3.xtensorflowkeras

解决方案


a=tf.reshape(tf.range(2*3*4),shape=(2,3,4))
#     [[[ 0,  1,  2,  3],
#        [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]],

#      [[12, 13, 14, 15],
#      [16, 17, 18, 19],
#       [20, 21, 22, 23]]]

b=tf.constant([-1,2]) 

aa=tf.pad(a,[[0,0],[1,0],[0,0]]) 

bb=b+1 

index=tf.stack([tf.range(tf.size(b)),bb],axis=-1) 
res=tf.expand_dims(tf.gather_nd(aa, index),axis=1)
#[[[ 0,  0,  0,  0]],
#[[20, 21, 22, 23]]]

当 index 为 -1 时,我们需要像张量这样的零。所以我们可以先沿第二个轴填充原始张量。然后将索引增加 1。在此之后,使用tf.gather_nd将返回答案。


推荐阅读