python - 如何使用字典映射 tf.data.Dataset 中的值
问题描述
这是所需映射的简单用例。将整数标签映射到 one-hot 编码。我想提一下,对于这种特殊情况,应该使用tf.one_hot
. 但我想了解如何使用字典映射数据集。
import tensorflow as tf
import numpy as np
#CREATE A ONE-HOT ENCODING MAPPING
mike_labels = [164, 117, 132, 37, 66, 177, 225, 33, 28, 75, 7]
num_classes = len(mike_labels)
one_hots = np.eye(len(mike_labels))
one_hots = one_hots.tolist()
#used to convert labels to corresponding one-hot encoding
label_encoder = {orig: onehot for orig, onehot in zip(mike_labels,
one_hots)}
print (label_encoder[164])
print (label_encoder[28])
#CREATE A FAKE DATASET
raw_data = [[164],[28],[132],[7]]
dataset = tf.data.Dataset.from_generator(lambda: raw_data, tf.float32, output_shapes=[None])
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
print(sess.run(next_element))
print(sess.run(next_element))
该代码打印出 4 个值。第一个是直接从字典中获取的所需单热编码。后两个打印值是数据集中的前两个值。每个元素都显示为一个包含单个浮点数的列表。
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
[ 164.]
[ 28.]
理想的答案将显示如何将数据集中的所有值更改为字典中相应的 one-hot 编码,使用提供的字典并且不会使用tf.one_hot
.
解决方案
可以使用 lambda 函数映射标签。dataset.map 函数为数据集的每个元素调用该函数。映射中的 lambda 函数将使用 tf.py_func 调用另一个函数。
tf.py_func 允许将张量视为 np 数组,因为无法将张量提供给字典。该函数的返回值将是一个浮点数列表,tf.py_func 需要每个浮点数的数据类型,因此这是通过列表推导给出的:
dataset = dataset.map(lambda label: tf.py_func(practice_py_func, [label], [tf.float32 for i in range(num_classes)]))
将调用以下函数。首先,我们从接收到的 numpy 数组中获取一个列表。此列表包含单个元素(标签)。因此,我们获取位置 0 的元素并使用字典找到相应的 one-hot 编码。由于 tensorflow 似乎抛出了一个奇怪的错误,即接收到的值是双精度而不是预期的浮点数,因此我们将其转换为 float32。然后返回 one-hot 编码。
def practice_py_func(arg1):
temp = arg1.tolist() #convert the numpy array to a list
l = label_encoder[temp[0]] #look up the encoding in the dictionary
output = [np.float32(val) for val in l] #convert each value in the encoding to a float
return output
整个解决方案如下所示:
import tensorflow as tf
import numpy as np
#CREATE A ONE-HOT ENCODING MAPPING
mike_labels = [164, 117, 132, 37, 66, 177, 225, 33, 28, 75, 7]
num_classes = len(mike_labels)
one_hots = np.eye(len(mike_labels))
one_hots = one_hots.tolist()
#used to convert labels to corresponding one-hot encoding
label_encoder = {orig: onehot for orig, onehot in zip(mike_labels, one_hots)}
print (label_encoder[164])
print (label_encoder[28])
#CREATE A FAKE DATASET
raw_data = [[164],[28],[132],[7]]
dataset = tf.data.Dataset.from_generator(lambda: raw_data, tf.float32, output_shapes=[None])
def practice_py_func(arg1):
temp = arg1.tolist() #convert the numpy array to a list
l = label_encoder[temp[0]] #look up the encoding in the dictionary
output = [np.float32(val) for val in l] #convert each value in the encoding to a float
return output
dataset = dataset.map(lambda label: tf.py_func(practice_py_func, [label], [tf.float32 for i in range(num_classes)]))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
print(sess.run(next_element))
print(sess.run(next_element))
推荐阅读
- reactjs - MUI 5 主题 - 全局自定义排版和调色板
- javascript - 使用 Babel,我如何*不*编译掉类属性,因为浏览器现在原生支持它们?
- excel - 比较两个表并显示缺失的数据 - Power BI
- java - 如何从 Java 中的其他类调用变量?
- arrays - 合并 2 个 json inet 数组并选择 cidr 或等于的位置
- c++ - 二叉搜索树:克隆方法?
- java - 从 csv 文件中读取值并识别唯一名称
- c++ - 如何让铿锵声警告非常简单的缩小
- typescript - @mui/material 自动完成打字稿错误
- vba - POWERPOINT VBA 选择一组形状并更改颜色