python - 如何在 Tensorflow 2 中进行索引查找(word2idx)?
问题描述
假设我有一个包含 1000 个单词的 word2idx 字典:
{'a': 0, 'b': 1, ..., 'zzz': 999}
并将 OOV 映射到 1000。
我想在 Tensorflow 2 中进行此查找,例如,给定
text = tf.ragged.constant([["a", "b"], ["zzz"], ["abcd"]])
结果将是
indices = tf.ragged.constant([[0, 1], [999], [1000]])
我怎么能这样做?
解决方案
使用tf.ragged.map_flat_values
:
text = tf.ragged.constant([[["a", "b"], ["zzz"], ["c",'N']]])
d = {'a': 0, 'b': 1, 'c': 2, 'zzz': 999, 'z':8}
OOV_CODE = 1000
lookup = lambda x: [d.get(key.numpy().decode("utf-8"), OOV_CODE) for key in x]
indices = tf.ragged.map_flat_values(lookup, text)
<tf.RaggedTensor [[[0, 1], [999], [2, 1000]]]>
对于仅矢量化 TF 方式,请考虑tf.lookup
工具:
d = {'a': 0, 'b': 1, 'c': 2, 'zzz': 999, 'z':8}
t = tf.ragged.constant([["a", "b"], ["zzz"], ["c"], ['N','M']])
ti = tf.lookup.KeyValueTensorInitializer(
list(d.keys()), list(d.values()), key_dtype=None, value_dtype=tf.int64, name=None)
lu = tf.lookup.StaticVocabularyTable(ti, num_oov_buckets = 1 )
tf.ragged.map_flat_values(lu.lookup, t)
<tf.RaggedTensor [[0, 1], [999], [2], [5, 5]]>
请注意,它5
用作 OOV 值 - 显然它是最小的自由正值。如果您在示例中输入值 0-999,OOV 代码自然会是 1000。
推荐阅读
- android-studio - 将包导入flutter(Andriod Studio)时面临一个大问题。我已经尝试了其他解决方案的所有方法
- laravel - Lumen 主页在 AWS 上有效,但其他主页无效
- machine-learning - tensorflow.js 模型不学习
- c# - 谁能建议一种简单的方法来使用没有 DataAdapter 的 Microsoft.Data.Sqlite 使用更改的数据更新 SQLite 数据库?
- highcharts - 如果文本太长,则从头开始显示 highcharts 节点文本
- flutter - Flutter 获取文件名
- c++ - 一种对齐排序结构成员的方法 - 月份和降雨量
- python - django查询日期字段的年份未被提取
- css - div#sidebar 和#sidebar div 有什么区别?
- reactjs - 将所有模块组合在一页上时“无法获取/”?