python - 从特定轴嵌入查找
问题描述
我有两个张量。
v, shape=(50, 64, 128), dtype=float32
m, shape=(64, 50, 1), dtype=int32
m 中的值是介于 0 和 50 (<=49) 之间的整数r: shape=(64, 50, 128), dtype=float32
例如值r(i, j, 0-128) = v(m(i, j), i, 0-128)
我看到的最接近的是 tf.nn.embedding_lookup 但我不确定如何在这个用例中使用它
解决方案
您可以使用以下tf.nn.embedding_lookup
或tf.gather_nd
方法来实现您的目标。
import tensorflow as tf
import numpy as np
m_np = np.random.randint(0,50,(64, 50, 1))
m = tf.constant(m_np)
n = tf.random.normal((50, 64, 128))
# Method 1
tmp = tf.nn.embedding_lookup(n,m[:,:,0]) # shape=(64,50,64,128)
tmp = tf.transpose(tmp,[1,3,0,2]) # shape=(50,128,64,64)
result1 = tf.transpose(tf.matrix_diag_part(tmp),[2,0,1]) # shape=(64,50,128)
# Method 2
indices = tf.tile(tf.reshape(tf.range(64),(-1,1,1)),(1,50,1)) # shape=(64,50,1)
indices = tf.concat([m,indices],axis=-1) # shape=(64,50,2)
result2 = tf.gather_nd(n,indices) # shape=(64,50,128)
with tf.Session() as sess:
# Randomly select a location for test
n_value,result_value = sess.run([n,result1])
print((n_value[m_np[5,4],5,:]==result_value[5,4]).all())
# True
推荐阅读
- tensorflow - 来自 TensorFlow 的这条红色消息是错误的吗?
- java - 计算属性的对象 ArrayList 中的出现次数
- android - Paytm 集成问题:android 和 nodejs 上的校验和无效
- javascript - Vue:按住键盘修饰符时如何更改按钮的文本?
- angular - 如何修复错误:类型 FroalaEditorModule 没有“ngModuleDef”属性使用 Froala Wysiwyg
- gulp - GUlp - 以下任务未完成
- java - 需要使用 64 位深度的 BufferedImages 创建验证码图像运行时?java支持吗?
- python - Pandas 将部分列索引转换为日期时间
- django - 在具有 2 个或更多副本的容器化 django 应用程序中运行 CronJob 的最佳方法是什么
- google-analytics - 使用用户会话属性设置 Google Analytics(分析)目标