python - 在 TensorFlow 中,如何断言列表的值在某个集合中?
问题描述
我有一个一维tf.uint8
张量x
,并想断言该张量内的所有值都在s
我定义的集合中。s
在图定义时是固定的,因此它不是动态计算的张量。
在普通的 Python 中,我想做某事。如下所示:
x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
assert all(el in s for el in x), "This should fail, as 5 is not in s"
我知道我可以tf.Assert
用于断言部分,但我正在努力定义条件部分(el in s
)。最简单/最规范的方法是什么?
较旧的答案Determining if A Value is in a Set in TensorFlow对我来说是不够的:首先,写下来和理解很复杂,其次,它使用的是 broadcasted tf.equal
,这比正确的集合更昂贵的计算基于检查。
解决方案
一个简单的方法可能是这样的:
import tensorflow as tf
x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
x_t = tf.constant(x, dtype=tf.uint8)
s_t = tf.constant(list(s), dtype=tf.uint8)
# Check every value in x against every value in s
xs_eq = tf.equal(x_t[:, tf.newaxis], s_t)
# Check every element in x is equal to at least one element in s
assert_op = tf.Assert(tf.reduce_all(tf.reduce_any(xs_eq, axis=1)), [x_t])
with tf.control_dependencies([assert_op]):
# Use x_t...
这将创建一个大小为 的中间张量(len(x), len(s))
。如果这有问题,您还可以将问题拆分为独立的张量,例如:
import tensorflow as tf
x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
x_t = tf.constant(x, dtype=tf.uint8)
# Count where each x matches each s
x_in_s = [tf.cast(tf.equal(x_t, si), tf.int32) for si in s]
# Add matches and check there is at least one match per x
assert_op = tf.Assert(tf.reduce_all(tf.add_n(x_in_s) > 0), [x_t])
编辑:
实际上,既然您说您的值是tf.uint8
,您可以使用布尔数组使事情变得更好:
import tensorflow as tf
x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
x_t = tf.constant(x, dtype=tf.uint8)
s_t = tf.constant(list(s), dtype=tf.uint8)
# One-hot vectors of values included in x and s
x_bool = tf.scatter_nd(tf.cast(x_t[:, tf.newaxis], tf.int32),
tf.ones_like(x_t, dtype=tf.bool), [256])
s_bool = tf.scatter_nd(tf.cast(s_t[:, tf.newaxis], tf.int32),
tf.ones_like(s_t, dtype=tf.bool), [256])
# Check that all values in x are in s
assert_op = tf.Assert(tf.reduce_all(tf.equal(x_bool, x_bool & s_bool)), [x_t])
这需要线性时间和恒定内存。
编辑 2:虽然最后一种方法在这种情况下理论上是最好的,但进行几个快速基准测试时,我只能在达到数十万个元素时看到性能上的显着差异,并且无论如何这三个仍然相当快速与tf.uint8
。
推荐阅读
- postgresql - 如果字符串是 PostgreSQL 11.0 中另一列的子字符串,则选择行
- java - 如何在有状态 EJB 中获取 UISessionID?
- linux - 删除具有特定名称的所有文件夹的内容
- python - Torch 为多个 GPU 并行化任务
- javascript - 是否可以在 HTML/JS/Fluid 的最小/最大范围之外添加附加值?
- ios - Vapor 代码解释有困难,执行时出现401错误
- angular - 添加 Angular SSR 引发错误发生未处理的异常:npm ERR!代码 ERESOLVE
- c++ - 你能找出一对结构吗?
- jquery - 使用循环使用 jQuery 为单个列表项设置动画
- java - Java如何获取用户数据目录