首页 > 解决方案 > tensorflow - 如何选择数组中除索引序列之外的所有元素?

问题描述

可以使用此处np.delete指定的方式完成等效的 numpy 操作。由于没有,我不确定如何在.tf.deletetensorflow

标签: pythontensorflow

解决方案


我想你可能想使用 tf.boolean_mask。例如,

labels = tf.Variable([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
a = tf.Variable([1, 0, 0])
a1 = tf.cast(a, dtype=tf.bool)
print(a1)    
mask = tf.math.logical_not(a1)
print(mask)
print(tf.boolean_mask(labels, mask))

输出是,

tf.Tensor([ True False False], shape=(3,), dtype=bool)
tf.Tensor([False  True  True], shape=(3,), dtype=bool)
tf.Tensor(
[[0 1 0]
 [0 0 1]], shape=(2, 3), dtype=int32)

因此,您可以定义一个掩码来删除一维张量的特定向量。


推荐阅读