python - replacement of "tf.gather_nd"
Problem description
I am working on a project whose TensorFlow version does not support tf.gather_nd. Is it possible to reimplement tf.gather_nd using tf.gather, tf.slice, or tf.strided_slice?
tf.gather_nd gathers slices from a tensor into a new tensor whose shape is determined by indices. Details can be found at https://www.tensorflow.org/api_docs/python/tf/gather_nd
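To make the expected behavior concrete, here is a minimal pure-Python sketch of the indexing rule on nested lists. The name gather_nd_reference is hypothetical and not part of TensorFlow. It assumes every index tuple is non-empty and in range, and it only illustrates the semantics, not an efficient implementation.

```python
def gather_nd_reference(params, indices):
    """Illustrative reference for gather_nd semantics on nested lists.

    The innermost lists of `indices` are index tuples. Each tuple addresses
    the leading dimensions of `params`; whatever remains is returned as a slice.
    """
    # An innermost list of ints is a single index tuple: walk into params.
    if indices and isinstance(indices[0], int):
        out = params
        for i in indices:
            out = out[i]
        return out
    # Otherwise recurse over the outer dimensions of indices.
    return [gather_nd_reference(params, idx) for idx in indices]


params = [[1, 2], [3, 4]]
# Full index tuples pick single elements.
print(gather_nd_reference(params, [[0, 0], [1, 1]]))
# Partial index tuples pick whole rows.
print(gather_nd_reference(params, [[1], [0]]))
```

With full index tuples the result has the shape of indices minus its last dimension; with partial tuples the remaining dimensions of params are appended to that shape.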
Thanks,
Solution
This function should do the same job:
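import tensorflow as tf
import numpy as np

def my_gather_nd(params, indices):
    idx_shape = tf.shape(indices)
    params_shape = tf.shape(params)
    # Number of leading dimensions of params addressed by each index tuple
    idx_dims = idx_shape[-1]
    # Shape of the slice gathered for each index tuple
    gather_shape = params_shape[idx_dims:]
    # Collapse the indexed dimensions into a single leading dimension
    params_flat = tf.reshape(params, tf.concat([[-1], gather_shape], axis=0))
    # Row-major step of each indexed dimension
    axis_step = tf.cumprod(params_shape[:idx_dims], exclusive=True, reverse=True)
    # tf.shape returns int32; cast so the product also works for int64 indices
    # (assumes indices is already a tf.Tensor)
    axis_step = tf.cast(axis_step, indices.dtype)
    indices_flat = tf.reduce_sum(indices * axis_step, axis=-1)
    result_flat = tf.gather(params_flat, indices_flat)
    return tf.reshape(result_flat, tf.concat([idx_shape[:-1], gather_shape], axis=0))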

# Test
np.random.seed(0)
with tf.Graph().as_default(), tf.Session() as sess:
    params = tf.constant(np.random.rand(10, 20, 30).astype(np.float32))
    indices = tf.constant(np.stack([np.random.randint(10, size=(5, 8)),
                                    np.random.randint(20, size=(5, 8))], axis=-1))
    result1, result2 = sess.run((tf.gather_nd(params, indices),
                                 my_gather_nd(params, indices)))
print(np.allclose(result1, result2))
# True
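The key step in the function above is converting each multi-dimensional index tuple into one flat index, using row-major steps obtained from an exclusive, reversed cumulative product of the indexed dimensions. The following pure-Python sketch mirrors that arithmetic; the helper names axis_steps and flatten_index are hypothetical and exist only for illustration.

```python
def axis_steps(leading_shape):
    """Exclusive, reversed cumulative product of the indexed dimensions.

    For leading_shape [10, 20] this yields [20, 1]: moving one step along
    the first axis skips 20 flattened entries, along the second axis 1.
    """
    steps = []
    running = 1
    for size in reversed(leading_shape):
        steps.append(running)
        running *= size
    return steps[::-1]


def flatten_index(index, leading_shape):
    """Map an index tuple to its position in the flattened leading dimension."""
    return sum(i * s for i, s in zip(index, axis_steps(leading_shape)))


# params of shape (10, 20, 30) indexed with 2-element tuples:
# the leading shape is (10, 20) and each gathered slice has shape (30,).
print(axis_steps([10, 20]))
print(flatten_index((3, 7), [10, 20]))
```

Here index (3, 7) maps to 3 * 20 + 7 = 67, which is the row that tf.gather picks from params reshaped to (200, 30).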