首页 > 解决方案 > 如何使用字典映射 tf.data.Dataset 中的值

问题描述

这是所需映射的简单用例。将整数标签映射到 one-hot 编码。我想提一下,对于这种特殊情况,应该使用tf.one_hot. 但我想了解如何使用字典映射数据集。

import tensorflow as tf
import numpy as np

#CREATE A ONE-HOT ENCODING MAPPING
mike_labels = [164, 117, 132, 37, 66, 177, 225, 33, 28, 75, 7]
num_classes = len(mike_labels)
one_hots = np.eye(len(mike_labels))
one_hots = one_hots.tolist()
#used to convert labels to corresponding one-hot encoding
label_encoder = {orig: onehot for orig, onehot in zip(mike_labels, 
one_hots)}
print (label_encoder[164])
print (label_encoder[28])

#CREATE A FAKE DATASET
raw_data = [[164],[28],[132],[7]]
dataset = tf.data.Dataset.from_generator(lambda: raw_data, tf.float32, output_shapes=[None])

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    print(sess.run(next_element))
    print(sess.run(next_element))

该代码打印出 4 个值。第一个是直接从字典中获取的所需单热编码。后两个打印值是数据集中的前两个值。每个元素都显示为一个包含单个浮点数的列表。

[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
[ 164.]
[ 28.]

理想的答案将显示如何将数据集中的所有值更改为字典中相应的 one-hot 编码,使用提供的字典并且不会使用tf.one_hot.

标签: pythonpython-3.xtensorflow

解决方案


可以使用 lambda 函数映射标签。dataset.map 函数为数据集的每个元素调用该函数。映射中的 lambda 函数将使用 tf.py_func 调用另一个函数。

tf.py_func 允许将张量视为 np 数组,因为无法将张量提供给字典。该函数的返回值将是一个浮点数列表,tf.py_func 需要每个浮点数的数据类型,因此这是通过列表推导给出的:

dataset = dataset.map(lambda label: tf.py_func(practice_py_func, [label], [tf.float32 for i in range(num_classes)]))

将调用以下函数。首先,我们从接收到的 numpy 数组中获取一个列表。此列表包含单个元素(标签)。因此,我们获取位置 0 的元素并使用字典找到相应的 one-hot 编码。由于 tensorflow 似乎抛出了一个奇怪的错误,即接收到的值是双精度而不是预期的浮点数,因此我们将其转换为 float32。然后返回 one-hot 编码。

def practice_py_func(arg1):
    temp = arg1.tolist() #convert the numpy array to a list
    l = label_encoder[temp[0]] #look up the encoding in the dictionary
    output = [np.float32(val) for val in l] #convert each value in the encoding to a float
    return output

整个解决方案如下所示:

import tensorflow as tf
import numpy as np

#CREATE A ONE-HOT ENCODING MAPPING
mike_labels = [164, 117, 132, 37, 66, 177, 225, 33, 28, 75, 7]
num_classes = len(mike_labels)
one_hots = np.eye(len(mike_labels))
one_hots = one_hots.tolist()
#used to convert labels to corresponding one-hot encoding
label_encoder = {orig: onehot for orig, onehot in zip(mike_labels, one_hots)}
print (label_encoder[164])
print (label_encoder[28])

#CREATE A FAKE DATASET
raw_data = [[164],[28],[132],[7]]
dataset = tf.data.Dataset.from_generator(lambda: raw_data, tf.float32, output_shapes=[None])


def practice_py_func(arg1):
    temp = arg1.tolist() #convert the numpy array to a list
    l = label_encoder[temp[0]] #look up the encoding in the dictionary
    output = [np.float32(val) for val in l] #convert each value in the encoding to a float
    return output

dataset = dataset.map(lambda label: tf.py_func(practice_py_func, [label], [tf.float32 for i in range(num_classes)]))


iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    print(sess.run(next_element))
    print(sess.run(next_element))

推荐阅读