首页 > 解决方案 > 在 TensorFlow 中,如何断言列表的值在某个集合中?

问题描述

我有一个一维tf.uint8张量x,并想断言该张量内的所有值都在s我定义的集合中。s在图定义时是固定的,因此它不是动态计算的张量。

在普通的 Python 中,我想做某事。如下所示:

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
assert all(el in s for el in x), "This should fail, as 5 is not in s"

我知道我可以tf.Assert用于断言部分,但我正在努力定义条件部分(el in s)。最简单/最规范的方法是什么?

较旧的答案Determining if A Value is in a Set in TensorFlow对我来说是不够的:首先,写下来和理解很复杂,其次,它使用的是 broadcasted tf.equal,这比正确的集合更昂贵的计算基于检查。

标签: pythontensorflow

解决方案


一个简单的方法可能是这样的:

import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}

x_t = tf.constant(x, dtype=tf.uint8)
s_t = tf.constant(list(s), dtype=tf.uint8)
# Check every value in x against every value in s
xs_eq = tf.equal(x_t[:, tf.newaxis], s_t)
# Check every element in x is equal to at least one element in s
assert_op = tf.Assert(tf.reduce_all(tf.reduce_any(xs_eq, axis=1)), [x_t])
with tf.control_dependencies([assert_op]):
    # Use x_t...

这将创建一个大小为 的中间张量(len(x), len(s))。如果这有问题,您还可以将问题拆分为独立的张量,例如:

import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}

x_t = tf.constant(x, dtype=tf.uint8)
# Count where each x matches each s
x_in_s = [tf.cast(tf.equal(x_t, si), tf.int32) for si in s]
# Add matches and check there is at least one match per x
assert_op = tf.Assert(tf.reduce_all(tf.add_n(x_in_s) > 0), [x_t])

编辑:

实际上,既然您说您的值是tf.uint8,您可以使用布尔数组使事情变得更好:

import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}

x_t = tf.constant(x, dtype=tf.uint8)
s_t = tf.constant(list(s), dtype=tf.uint8)
# One-hot vectors of values included in x and s
x_bool = tf.scatter_nd(tf.cast(x_t[:, tf.newaxis], tf.int32),
                       tf.ones_like(x_t, dtype=tf.bool), [256])
s_bool = tf.scatter_nd(tf.cast(s_t[:, tf.newaxis], tf.int32),
                       tf.ones_like(s_t, dtype=tf.bool), [256])
# Check that all values in x are in s
assert_op = tf.Assert(tf.reduce_all(tf.equal(x_bool, x_bool & s_bool)), [x_t])

这需要线性时间和恒定内存。

编辑 2:虽然最后一种方法在这种情况下理论上是最好的,但进行几个快速基准测试时,我只能在达到数十万个元素时看到性能上的显着差异,并且无论如何这三个仍然相当快速与tf.uint8


推荐阅读