首页 > 解决方案 > tensorflow 断言元素 tf.where

问题描述

我有2个形状的矩阵:

pob.shape = (2,49,20)  
rob.shape = np.zeros((2,49,20))  

并且我想获取具有值!= 0 的 pob 元素的索引。所以在 numpy 我可以这样做:

x,y,z = np.where(pob!=0)

例如:

x = [2,4,7]  
y = [3,5,5]  
z = [3,5,6]

我想改变 rob 的价值:

rob[x1,y1,:] = np.ones((20))

我怎么能用 tensorflow 对象做到这一点?我尝试使用 tf.where 但无法从张量 obj 中获取索引值

标签: tensorflowz-index

解决方案


您可以使用tf.range()andtf.meshgrid()创建索引矩阵,然后tf.where()在它们上使用您的条件来获得满足它的索引。my_tensor[my_indices] = my_values但是,接下来会出现棘手的部分:您不能轻松地根据 TF ( )中的索引为张量赋值。

您的问题的解决方法(“for all (i,j,k), if pob[i,j,k] != 0then rob[i,j] = 1”)可能如下:

import tensorflow as tf

# Example values for demonstration:
pob_val = [[[0, 0, 0], [1, 0, 0], [1, 0, 1]], [[1, 1, 1], [0, 0, 0], [0, 0, 0]]]
pob = tf.constant(pob_val)
pob_shape = tf.shape(pob)
rob = tf.zeros(pob_shape)

# Get the mask:
mask = tf.cast(tf.not_equal(pob, 0), tf.uint8)

# If there's at least one "True" in mask[i, j, :], make all mask[i, j, :] = True:
mask = tf.cast(tf.reduce_max(mask, axis=-1, keepdims=True), tf.bool)
mask = tf.tile(mask, [1, 1, pob_shape[-1]])

# Apply mask:
rob = tf.where(mask, tf.ones(pob_shape), rob)

with tf.Session() as sess:
    rob_eval = sess.run(rob)
    print(rob_eval)
    # [[[0. 0. 0.]
    #   [1. 1. 1.]
    #   [1. 1. 1.]]
    #
    #  [[1. 1. 1.]
    #   [0. 0. 0.]
    #   [0. 0. 0.]]]

推荐阅读