tensorflow - tensorflow 断言元素 tf.where
问题描述
我有2个形状的矩阵:
pob.shape = (2,49,20)
rob.shape = np.zeros((2,49,20))
并且我想获取具有值!= 0 的 pob 元素的索引。所以在 numpy 我可以这样做:
x,y,z = np.where(pob!=0)
例如:
x = [2,4,7]
y = [3,5,5]
z = [3,5,6]
我想改变 rob 的价值:
rob[x1,y1,:] = np.ones((20))
我怎么能用 tensorflow 对象做到这一点?我尝试使用 tf.where 但无法从张量 obj 中获取索引值
解决方案
您可以使用tf.range()
andtf.meshgrid()
创建索引矩阵,然后tf.where()
在它们上使用您的条件来获得满足它的索引。my_tensor[my_indices] = my_values
但是,接下来会出现棘手的部分:您不能轻松地根据 TF ( )中的索引为张量赋值。
您的问题的解决方法(“for all (i,j,k)
, if pob[i,j,k] != 0
then rob[i,j] = 1
”)可能如下:
import tensorflow as tf
# Example values for demonstration:
pob_val = [[[0, 0, 0], [1, 0, 0], [1, 0, 1]], [[1, 1, 1], [0, 0, 0], [0, 0, 0]]]
pob = tf.constant(pob_val)
pob_shape = tf.shape(pob)
rob = tf.zeros(pob_shape)
# Get the mask:
mask = tf.cast(tf.not_equal(pob, 0), tf.uint8)
# If there's at least one "True" in mask[i, j, :], make all mask[i, j, :] = True:
mask = tf.cast(tf.reduce_max(mask, axis=-1, keepdims=True), tf.bool)
mask = tf.tile(mask, [1, 1, pob_shape[-1]])
# Apply mask:
rob = tf.where(mask, tf.ones(pob_shape), rob)
with tf.Session() as sess:
rob_eval = sess.run(rob)
print(rob_eval)
# [[[0. 0. 0.]
# [1. 1. 1.]
# [1. 1. 1.]]
#
# [[1. 1. 1.]
# [0. 0. 0.]
# [0. 0. 0.]]]
推荐阅读
- c# - Rigidbody randomly speeds up when a list is modified (Unity)
- node.js - 如何让 VS Code 识别 tsconfig.json?
- html - 网格垂直居中文本 CSS
- java - 如何在 OkHttp 拦截器中获取方法的返回类型?
- c# - ARFoundation Button 在 Unity 上的哪些属性使其能够执行连续任务(与 onClick 不同)
- google-drive-api - (403) 放置时出现禁止错误,但重试时成功
- java - Jasper 报告在部署时无法使用 CJK 生成 PDF
- arrays - List.generate index null " 仅在 Textstyle, color: " 颤动
- flutter - 重新加载时,我应该如何更新无状态小部件?
- javascript - 密码错误时请求被卡住(Async/Await)