
Problem description

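I am trying to build a filter (a boolean mask for tf.boolean_mask) that removes duplicate indices from a tensor according to the values associated with those indices. Among duplicates, the one with the greatest value should be kept and the others discarded. If both the index and the value are identical, only one of them should be kept: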

[Pseudocode]
for index in indices:
    if index is unique:
        keep index = True
    else:
        if val[index] > val[all other duplicates of index]:
            keep index = True
        elif val[index] < val[any other duplicate of index]:
            keep index = False
        elif val[index] == val[any other duplicate of index]:
            keep exactly one of the tied duplicates (which one doesn't matter)
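For reference, the pseudocode can be written as a plain Python loop. This is a minimal sketch: the function name `keep_mask` is hypothetical, and ties are resolved by keeping the first occurrence, which is one of the choices the question allows.

```python
def keep_mask(indices, values):
    """Return one boolean per element: for each distinct index, only the
    element with the largest value survives; on ties, the earliest one."""
    best = {}  # distinct index -> position of the best element seen so far
    for pos, (idx, val) in enumerate(zip(indices, values)):
        # Strict '>' means a later element with an equal value never replaces
        # the earlier one, so exactly one of the tied duplicates is kept.
        if idx not in best or val > values[best[idx]]:
            best[idx] = pos
    keep = set(best.values())
    return [pos in keep for pos in range(len(indices))]


print(keep_mask([10, 5, 20, 20, 30, 30], [1., 0., 2., 0., 0., 0.]))
# [True, True, True, False, True, False]
```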

A short example of the problem:

import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x only; eager execution is the default in TF 2.x

index = tf.convert_to_tensor([  10,    5,   20,    20,    30,    30])
value = tf.convert_to_tensor([  1.,   0.,   2.,    0.,    0.,    0.])
# bool_mask =                [True, True, True, False,  True, False]
# or                         [True, True, True, False, False,  True]
# Position 3 is filtered out because position 2 has the same index (20)
# and a greater value (2 compared to 0).
# Positions 4 and 5 share index 30 and have equal values, so either one
# may be kept, but not both.


...
bool_mask = ?

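My current approach correctly removes duplicates whose values differ, but fails on duplicates with equal values. Unfortunately, this edge case does occur in my data: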

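import tensorflow as tf

# index, value as defined in the example above
y, idx = tf.unique(index)  # distinct indices, and the segment id of every element
num_segments = tf.shape(y)[0]
maximum_vals = tf.math.unsorted_segment_max(value, idx, num_segments)  # max value per distinct index

# Pair each distinct index with its maximum, and each original element with its value
fused_filt = tf.stack([tf.cast(y, tf.float32), maximum_vals], axis=1)
fused_orig = tf.stack([tf.cast(index, tf.float32), value], axis=1)

# Repeat every original pair once per distinct index so they can be compared element-wise
fused_orig_tiled = tf.tile(fused_orig, [1, tf.shape(fused_filt)[0]])
fused_orig_res = tf.reshape(fused_orig_tiled, [-1, tf.shape(fused_filt)[0], 2])

# Keep an element if its (index, value) pair equals some (index, max value) pair
comp_1 = tf.equal(fused_orig_res, fused_filt)
comp_2 = tf.reduce_all(comp_1, -1)
comp_3 = tf.reduce_any(comp_2, -1)
# comp_3 = [True, True, True, False, True, True]  <- positions 4 and 5 are both kept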
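The comparison above keeps both positions 4 and 5 because it only asks whether an element's (index, value) pair matches its segment maximum, and tied duplicates all pass that test. One possible way to break the tie in pure TensorFlow is sketched below, assuming TF 2.x with eager execution: after marking the elements that reach their segment maximum, take the smallest position among them per segment with `tf.math.unsorted_segment_min`, and keep an element only if it is that position. The helper name `dedup_keep_max` is hypothetical, and this is an illustrative approach rather than a confirmed answer to the question.

```python
import tensorflow as tf  # assumes TF 2.x (eager execution by default)


def dedup_keep_max(index, value):
    """Hypothetical helper: boolean mask keeping one element per distinct
    index, the one with the largest value; on ties, the earliest position."""
    # Segment id of every element (one segment per distinct index).
    unique_index, segment_ids = tf.unique(index)
    num_segments = tf.shape(unique_index)[0]

    # Per-segment maximum value, gathered back to every element.
    segment_max = tf.math.unsorted_segment_max(value, segment_ids, num_segments)
    is_max = tf.equal(value, tf.gather(segment_max, segment_ids))

    # Among elements reaching the maximum, pick the smallest position.
    # Non-maximal elements get the sentinel n, which can never be the minimum
    # because every segment contains at least one maximal element (< n).
    n = tf.shape(index)[0]
    positions = tf.range(n)
    candidates = tf.where(is_max, positions, tf.fill(tf.shape(positions), n))
    winner = tf.math.unsorted_segment_min(candidates, segment_ids, num_segments)

    # Keep an element only if it is the chosen position of its own segment.
    return tf.equal(positions, tf.gather(winner, segment_ids))


index = tf.constant([10, 5, 20, 20, 30, 30])
value = tf.constant([1., 0., 2., 0., 0., 0.])
bool_mask = dedup_keep_max(index, value)
print(bool_mask.numpy().tolist())                           # [True, True, True, False, True, False]
print(tf.boolean_mask(index, bool_mask).numpy().tolist())   # [10, 5, 20, 30]
```

Keeping the earliest position is an arbitrary choice that the question permits; using `tf.math.unsorted_segment_max` on the candidates with `-1` as the sentinel would keep the last tied element instead. Unlike the tiled comparison, this sketch never builds an `[n, num_segments, 2]` tensor, so its memory use stays linear in the number of elements.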

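A pure TensorFlow solution would be preferable; a Python for loop over the indices is easy enough to write, so I am looking for something better than that. Thank you.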

Tags: python, tensorflow
