python - 将列表项映射到 tensorflow 数据集字典
问题描述
我正在尝试将图像信息映射到由图像和标签字典组成的数据集。
parse_function()
应该只从 2 个文件名路径和标签列表中解码。
def parse_function(filename, label):
image_string = tf.io.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string)
image_resized = tf.image.resize(image_decoded, [4, 4])
return image_resized, label
def dataset_maker(list_sample_paths, list_labels):
filenames = tf.constant(list_sample_paths)
labels = tf.constant(list_labels)
dataset = tf.data.Dataset.from_tensor_slices({"image": filenames, "label": labels})
dataset = dataset.map(parse_function)
training_dataset = dataset_maker(list_training_sample_paths, list_training_sample_labels)
但我收到此错误消息
TypeError: tf__parse_function() missing 1 required positional argument: 'label'
在这种情况下我需要使用字典理解吗?非常感谢解决此问题的任何帮助。谢谢!
在 Srihari Humbarwadi 回复后添加此信息以使用元组解决它: 我想获得一个字典结构,因为我用 Mnist 为我的模型下雨了。
一个随机的 Mnist 示例具有以下结构:
{'image': <tf.Tensor: id=140275, shape=(28, 28, 1), dtype=uint8, numpy=array([[[ 0],[ 0],[ 0]],dtype=uint8)>, 'label': <tf.Tensor: id=140276, shape=(), dtype=int64, numpy=6>}
解决方案
您不需要以字典的形式传递文件名和标签列表。你可以通过传递一个元组来让它工作,即。(filenames, labels)
. 这是我使用的完整代码:
from glob import glob
import numpy as np
import tensorflow as tf
print('TensorFlow:', tf.__version__)
list_training_sample_paths = sorted(glob('images/*'))
# random integer labels
list_training_sample_labels = np.random.randint(low=0, high=5, size=[len(list_training_sample_paths)])
def parse_function(filename, label):
image_string = tf.io.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string)
image_resized = tf.image.resize(image_decoded, [4, 4])
return image_resized, label
def dataset_maker(list_sample_paths, list_labels):
filenames = tf.constant(list_sample_paths)
labels = tf.constant(list_labels)
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(parse_function)
return dataset
training_dataset = dataset_maker(list_training_sample_paths, list_training_sample_labels)
tf.data.experimental.get_structure(training_dataset)
输出
TensorFlow: 2.2.0-rc2
(TensorSpec(shape=(4, 4, None), dtype=tf.float32, name=None), TensorSpec(shape=(),dtype=tf.int64, name=None))
推荐阅读
- c# - 在 .net 5.0 中使用用户名和密码向 Microsoft Dynamics/Azure AD 进行身份验证
- pandas - 使用两个 for 循环将字符串解析为数据帧
- python - 除了日志消息中的@everyone 角色(discord.py)
- grafana - 使用 Grafana 在 opentsdb 上使用直方图
- c# - 当您在asp net core mvc中有多合一列表时如何将特定用户添加到表中
- typescript - 如何在 amazon-cognito-identity-js 和 nest.js 中发送代码确认
- linux - 安装过程中 Docker 权限被拒绝
- flutter - Flutter pub 在客户端获取握手错误(操作系统错误:CERTIFICATE_VERIFY_FAILED:应用程序验证失败(handshake.cc:359))
- machine-learning - 将多个时间序列信号转换为一个频谱图
- angular-cli - 带有 eslint 的 Angular 项目超级慢