python - 如何使用 tf.gather_nd 从图像中进行选择?
问题描述
我有一个X
形状为 (N,256,256,5) 的 CNN 输出张量,其中N
是批量维度。我有张量x
并y
包含 N 个索引(每个 0 到 255)。我想使用这些索引来形成一个 (N,5) 张量Y
,使得Y[n,:] = X[n, x[n], y[n], :]
. 如何才能做到这一点?
解决方案
我认为与此类似的事情可以为您解决问题(如果我正确理解了您的问题):
您的数据:
import tensorflow as tf
import numpy as np
batch_size = 5
D=2
data = tf.constant(np.array(range(batch_size * D * D * 5)).reshape([batch_size, D, D, 5]))
计算指数:
batches = tf.reshape(tf.range(batch_size, dtype=tf.int32), shape=[batch_size, 1])
random_x = tf.random.uniform([batch_size, 1], minval = 0, maxval = D, dtype = tf.int32)
random_y = tf.random.uniform([batch_size, 1], minval = 0, maxval = D, dtype = tf.int32)
indices = tf.concat([batches, random_x, random_y], axis=1)
请注意,random_x
andrandom_y
可以替换为您现有的张量x
和y
张量。然后使用该tf.gather_nd
功能将您的张量应用于您indices
的张量data
:
output = tf.gather_nd(data, indices)
print(batches, 'batches')
print(random_x, 'random_x')
print(random_y, 'random_y')
print(indices, 'indices')
print('Original tensor \n', data, '\n')
print('Updated tensor \n', output)
'''
tf.Tensor(
[[0]
[1]
[2]
[3]
[4]], shape=(5, 1), dtype=int32) batches
tf.Tensor(
[[0]
[1]
[1]
[0]
[1]], shape=(5, 1), dtype=int32) random_x
tf.Tensor(
[[0]
[1]
[0]
[0]
[0]], shape=(5, 1), dtype=int32) random_y
tf.Tensor(
[[0 0 0]
[1 1 1]
[2 1 0]
[3 0 0]
[4 1 0]], shape=(5, 3), dtype=int32) indices
Original tensor
tf.Tensor(
[[[[ 0 1 2 3 4]
[ 5 6 7 8 9]]
[[10 11 12 13 14]
[15 16 17 18 19]]]
[[[20 21 22 23 24]
[25 26 27 28 29]]
[[30 31 32 33 34]
[35 36 37 38 39]]]
[[[40 41 42 43 44]
[45 46 47 48 49]]
[[50 51 52 53 54]
[55 56 57 58 59]]]
[[[60 61 62 63 64]
[65 66 67 68 69]]
[[70 71 72 73 74]
[75 76 77 78 79]]]
[[[80 81 82 83 84]
[85 86 87 88 89]]
[[90 91 92 93 94]
[95 96 97 98 99]]]], shape=(5, 2, 2, 5), dtype=int32)
Updated tensor
tf.Tensor(
[[ 0 1 2 3 4]
[35 36 37 38 39]
[50 51 52 53 54]
[60 61 62 63 64]
[90 91 92 93 94]], shape=(5, 5), dtype=int32)
'''
张量output
的形状为(batch_size, 5)
。正如我所说,我不确定我是否理解了这个问题,所以请随时提供一些反馈。
推荐阅读
- image-processing - 我的测试损失达到数百万是否正常
- performance - Haskell 分析中的杂质或随机性
- python - 如何在 Python 中结合使用 Request 和 BeautifulSoup 来加速 Webscraping?
- python - Python如何找到椭圆周围每个点的坐标
- angular - Angular - map 将嵌套数组从初始对象转换为空数组
- android - 如何获取 JSON 嵌套对象
- reactjs - FlatList 不使用 React Hooks 渲染
- laravel - 检查“comment”是否是第一个,然后不要删除它
- c - 为什么我在分叉时会得到这个变量的两个值?
- python - 无法将张量添加到批次:元素数量不匹配。形状是:[张量]:[2],[批次]:[5]