首页 > 解决方案 > 在 tensorflow 中创建顶部“k”元素掩码的更有效方法

问题描述

我想创建一个函数,对于给定的1d-tensor 输出掩码,位置在哪里,对应于那里和其他地方的最高k值。即,我有例如:10

tensor = [1, 0, 7, 5, 2, 3] : get_largest_mask(tensor, 3) = [0, 0, 1, 1, 0, 1]

我创建了以下功能:

def get_largest_mask(tensor, n_to_keep):
    # tensor 1-d tensor
    values, indices = tf.math.top_k(tensor, k=n_to_keep)

    mask = tf.zeros(tf.size(tensor))
    mask = tf.tensor_scatter_nd_update(mask, [[idx] for idx in indices], tf.ones(n_to_keep))

    return mask

然而,对于感兴趣的情况,它的工作速度相当慢,而且正如我所测量的那样,大部分时间都是由tf.tensor_scatter_nd_update. 什么是更快的选择?

张量的典型大小是10^3-10^4元素k,顺序为 '10^2-10^3'

标签: pythonarraystensorflow

解决方案


我会去寻找前 K 值,然后是大小比较

import tensorflow as tf

tensor = tf.convert_to_tensor([1, 0, 7, 5, 2, 3])

mask = tf.cast(tensor >= tf.math.top_k(tensor, 3)[0][-1], tf.int32)
# mask = <tf.Tensor: shape=(6,), dtype=int32, numpy=array([0, 0, 1, 1, 0, 1], dtype=int32)>

解释

tf.math.top_k返回两个值,第一个是具有实际前 k 个值的张量,第二个是索引。我们取值,然后访问[-1]最小值。>=然后我们通过提问来创建面具。最后,我们根据您要求的输出转换为整数


推荐阅读