首页 > 解决方案 > 如何在 Tensorflow 2 中进行索引查找(word2idx)?

问题描述

假设我有一个包含 1000 个单词的 word2idx 字典:

{'a': 0, 'b': 1, ..., 'zzz': 999}

并将 OOV 映射到 1000。

我想在 Tensorflow 2 中进行此查找,例如,给定

text = tf.ragged.constant([["a", "b"], ["zzz"], ["abcd"]])

结果将是

indices = tf.ragged.constant([[0, 1], [999], [1000]])

我怎么能这样做?

标签: pythontensorflowkerastensorflow2.0

解决方案


使用tf.ragged.map_flat_values

text = tf.ragged.constant([[["a", "b"], ["zzz"], ["c",'N']]])
d = {'a': 0, 'b': 1, 'c': 2, 'zzz': 999, 'z':8}
OOV_CODE = 1000
lookup = lambda x: [d.get(key.numpy().decode("utf-8"), OOV_CODE) for key in x]
indices = tf.ragged.map_flat_values(lookup, text)
    
<tf.RaggedTensor [[[0, 1], [999], [2, 1000]]]>

对于仅矢量化 TF 方式,请考虑tf.lookup工具:

    d = {'a': 0, 'b': 1, 'c': 2, 'zzz': 999, 'z':8}
    t = tf.ragged.constant([["a", "b"], ["zzz"], ["c"], ['N','M']])
    
    ti = tf.lookup.KeyValueTensorInitializer(
                list(d.keys()), list(d.values()), key_dtype=None, value_dtype=tf.int64, name=None)
    
    lu = tf.lookup.StaticVocabularyTable(ti, num_oov_buckets = 1 )
    
    tf.ragged.map_flat_values(lu.lookup, t)

<tf.RaggedTensor [[0, 1], [999], [2], [5, 5]]>

请注意,它5用作 OOV 值 - 显然它是最小的自由正值。如果您在示例中输入值 0-999,OOV 代码自然会是 1000。


推荐阅读