python - 如何在 Tensorflow 中编写 LabelEncoder?
问题描述
我正在尝试将 Google Storage 上的目录解析为字符串,但我不断收到错误消息。我想找到每个文件的目录并将目录名称的数字编码作为数据集返回。这在使用 LabelEncoder 的 sklearn 中是微不足道的,但我在 Tensorflow 中遇到了麻烦。
CLASS_NAMES = [b'class_1', b'class_2', b'class_3']
labeler = tfds.features.ClassLabel(names=CLASS_NAMES)
def parse_filenames(filename):
label = tf.strings.split(tf.expand_dims(filename, axis=-1), sep='/')
label = label.values[-2]
# Problem is in the two lines below
position_feature = tf.feature_column.categorical_column_with_vocabulary_list('label_names', CLASS_NAMES)
label = tf.io.parse_example(label, features=position_feature)
return label
folder = b'gs://<bucket>/train/*/*.jpg'
filenames_dataset = tf.data.Dataset.list_files(folder)
label_dataset = filenames_dataset.map(parse_filenames)
next(iter(label_dataset))
我收到一个错误ValueError: dictionary update sequence element #0 has length 16; 2 is required
如果我取出“# Problem is here”注释下的两行,它工作正常,除了它返回一个字符串而不是一个整数。我尝试过其他非张量流选项,例如 <list_name>.index(label),但这些选项当然会失败,因为一切都是张量而不是字符串。还有另一种方法可以做到这一点吗?
解决方案
也许你可以试试这一行而不是这两行:
label = tf.argmax(tf.cast(parts[-2] == CLASS_NAMES, tf.int32))
你会得到类似的东西[0, 1, 0]
(标签的索引CLASS_NAMES
)。
功能和可重复的示例:
import tensorflow as tf
import numpy as np
from string import ascii_lowercase as letters
CLASS_NAMES = [b'class_1', b'class_2', b'class_3']
files = ['\\'.join([np.random.choice(CLASS_NAMES).decode(),
''.join(np.random.choice(list(letters), 5)) + '.jpg'])
for i in range(10)]
ds = tf.data.Dataset.from_tensor_slices(files)
这是我生成的假文件:
['class_3\\jrxog.jpg',
'class_1\\slfiq.jpg',
'class_2\\svldd.jpg',
'class_2\\avrgt.jpg',
'class_3\\wqwuv.jpg']
现在实现这个:
def get_label(file_path):
parts = tf.strings.split(file_path, '\\')
return file_path, tf.argmax(tf.cast(parts[-2] == CLASS_NAMES, tf.int32))
ds = ds.map(get_label)
next(iter(ds))
(<tf.Tensor: shape=(), dtype=string, numpy=b'class_1\\bbqrx.jpg'>,
<tf.Tensor: shape=(), dtype=int64, numpy=0>)
推荐阅读
- azure - Azure WebJob 可以仅包含控制台应用程序,还是应仅使用 JobHost 构建以按需运行?
- java - 在邮件通知 Spring Boot Admin 中显示环境
- php - 我的 var_dump 停止工作,我该如何解决?
- mocha.js - 单元测试 redux-saga 选择
- python-3.x - 如何在 python scrapy 中修复蜘蛛内部的回调?
- amazon-web-services - 使用 CLI 将对象与过期和缓存控制标头同步到 S3?
- python-3.x - 套接字服务器 - readlines() 函数导致我的程序停止/卡住
- javascript - 如何使用 d3 使用 javascript 可视化本地 csv 文件?
- php - Codeigniter 应用程序中大量 mpdf 事务的服务器超时
- java - 使用名称创建 bean 时出错:注入自动装配的依赖项失败,无法解析占位符