python - tensorflow 多维索引
问题描述
我有
4维浮点张量
y
和3 维整数张量
y_index
,包含要提取的 y 的第 4 维的索引
我想要做的很简单,用 3 个带有 numpy 的 for 循环:
y = np.random.randint(100,size=(5,10,20,3))
y_index= np.random.randint(3,size=(5,10,20))
y_slice = np.zeros_like(y_index)
for i in range(y.shape[0]):
for j in range(y.shape[1]):
for k in range(y.shape[2]):
y_slice[i,j,k] = y[i,j,k,y_index[i,j,k]]
y_slice
我怎样才能在张量流中有效地做到这一点?我猜我需要使用 tf.gether_nd ...
解决方案
您可以执行以下操作。基本上,首先您将除最后一个之外的所有维度展平,y
并为 flatten 创建一个索引y
。您进行索引,然后重塑为正确的形状。
y = tf.constant(np.random.normal(size=(5,10,20,3)), dtype='float32')
y_index = tf.constant(np.random.randint(3, size=(5,10,20)), dtype='int32')
# Creating an index like [(0,y_index[0]), (1, y_index[1]), ...]
inds = tf.stack([tf.range(5*10*20),tf.reshape(y_index,[-1])],axis=1)
y_slice = tf.reshape(tf.gather_nd(tf.reshape(y,[-1,3]),inds),[5,10,20])
推荐阅读
- excel - 在单元格中输入数字以选择一系列单元格并打印它们
- scala - 当价值已经是未来时,使用 EitherT.liftF 提升未来
- vue.js - 从命令行安装 vue-router 插件
- python - 循环遍历正则表达式模式的列表/字典并提取字符串的 Pythonic 方式
- flutter - 为什么变量值在颤动中不刷新?
- asp.net-mvc - ASP.NET MVC 应用程序正在尝试对仅 POST 方法的 GET 请求
- swift - 在没有 UITextField Swift 的情况下获取文本输入
- sql - 防止按字母顺序删除重复数据
- python - 是否可以在 repl.it 中进行文档测试?
- r - 如何使用R中的变量为无根树的分支着色