首页 > 解决方案 > 从特定轴嵌入查找

问题描述

我有两个张量。

v, shape=(50, 64, 128), dtype=float32
m, shape=(64, 50, 1), dtype=int32

m 中的值是介于 0 和 50 (<=49) 之间的整数r: shape=(64, 50, 128), dtype=float32

例如值r(i, j, 0-128) = v(m(i, j), i, 0-128)

我看到的最接近的是 tf.nn.embedding_lookup 但我不确定如何在这个用例中使用它

标签: pythontensorflow

解决方案


您可以使用以下tf.nn.embedding_lookuptf.gather_nd方法来实现您的目标。

import tensorflow as tf
import numpy as np

m_np = np.random.randint(0,50,(64, 50, 1))
m = tf.constant(m_np)
n = tf.random.normal((50, 64, 128))

# Method 1
tmp = tf.nn.embedding_lookup(n,m[:,:,0]) # shape=(64,50,64,128)
tmp = tf.transpose(tmp,[1,3,0,2]) # shape=(50,128,64,64)
result1 = tf.transpose(tf.matrix_diag_part(tmp),[2,0,1]) # shape=(64,50,128)

# Method 2
indices = tf.tile(tf.reshape(tf.range(64),(-1,1,1)),(1,50,1)) # shape=(64,50,1)
indices = tf.concat([m,indices],axis=-1) # shape=(64,50,2)
result2 = tf.gather_nd(n,indices) # shape=(64,50,128)

with tf.Session() as sess:
    # Randomly select a location for test
    n_value,result_value = sess.run([n,result1])
    print((n_value[m_np[5,4],5,:]==result_value[5,4]).all())

# True

推荐阅读