首页 > 解决方案 > Tensorflow:使用 slim.dataset.Dataset 时,有没有办法将标签 ID 值映射到其他值?

问题描述

dataset = slim.dataset.Dataset(...)
provider = slim.dataset_data_provider.DatasetDataProvider(dataset, ..._
image, labels = provider.get(['image', 'label')

假设,对于数据集 A 中的示例,labels可能是[1, 2, 1, 3]. 但是,出于某种原因(例如,由于数据集 B),我想将标签 ID 映射到其他值。映射可能如下所示。

# {old_label: target_label}
mapping = {0: 0, 1: 2, 2: 2, 3: 2, 4: 2, 5: 3, 6: 1}

目前,我猜测有两种方法:

--tf.data.Dataset似乎有一个map(map_func)每个例子都应该通过的功能,这可能是解决方案。不过,我比较熟悉slim.dataset.Dataset。有类似的技巧slim.dataset.Dataset吗?

- 我想知道我是否可以简单地将一些映射函数应用于张量label,例如:

new_labels = tf.map_fn(lambda x: x+1, labels, dtype=tf.int32)
# labels = [1 2 1 3] --> new_labels = [2 3 2 4]. This works.

new_labels = tf.map_fn(lambda x: mapping[x], labels, dtype=tf.int32)
# I wished but this does not work!

但是,下面没有工作,这是我需要的。有人可以建议吗?

标签: tensorflowtensorflow-slim

解决方案


我认为您可以尝试tf.contrib.lookup

keys = list(mapping.keys())
values = [mapping[k] for k in keys]
table = tf.contrib.lookup.HashTable(
  tf.contrib.lookup.KeyValueTensorInitializer(keys, values, key_dtype=tf.int64, value_dtype=tf.int64), -1
)
new_labels = table.lookup(labels)
sess=tf.Session()
sess.run(table.init)
print(sess.run(new_labels))

推荐阅读