首页 > 解决方案 > 如何在tensorflow2中改变像numpy和pytorch这样的张量?

问题描述

我想在 Tensorflow2.1 中实现 Informer(一种深度学习模型)。有一些这样的代码:

Q = tensor.shape(B,H,L,D) (for example, Q is a tensor and shape is (B,H,L,D) B is batch)
index = tensor.shape(B,C) (C < L)
Q_reduce = Q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :]
print(Q_reduce.shape)
# get shape is (B,B,C,D)

如何在 Tensorflow2 中获得相同的结果?我像这样使用 tf.gather:

Q_reduce = tf.gather(Q,index, axis=-2)
# get shape is (B,H,B,C,D)

标签: pythonnumpytensorflowpytorch

解决方案


推荐阅读