tensorflow - Tensorflow 数据集地图功能问题
问题描述
我正在使用 tensorflow 2.1 来构建数据管道。我写了一个函数来做数据预处理:
def preprocessing(path):
path = str(path.numpy(), 'utf-8')
label = Path(path).parent.name
image = tf.io.read_file(path)
image = tf.image.decode_image(image)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.image.central_crop(image, central_fraction=0.5)
image = tf.image.resize(image, size=[224, 224])
image = tf.image.random_flip_left_right(image)
image = tf.image.random_brightness(image, max_delta=0.2)
return image, label
当我使用以下代码验证处理功能时,它可以工作。
ds = tf.data.Dataset.list_files('../datasets/hymenoptera_data/train/ants/*.jpg')
path = next(iter(ds))
image, label = preprocessing(path)
plt.imshow(image)
plt.show()
print(path) 的结果是 tf.Tensor(b'..\datasets\hymenoptera_data\train\ants\886401651_f878e888cd.jpg', shape=(), dtype=string) 但是如果我使用 map() 来处理生成ds,错误出来了:
ds_new = ds.map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
for i in ds_new.take(1):
plt.imshow(i)
plt.show()
AttributeError: 'Tensor' 对象没有属性 'numpy',由于预处理函数中的 path = str(path.numpy(), 'utf-8') 发生此错误。
我不明白为什么,谁能帮助解决这个问题,非常感谢!
解决方案
试试这个函数进行预处理:
def preprocessing(path):
label = tf.strings.split(path, os.path.sep)[-2]
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.image.central_crop(image, central_fraction=0.5)
image = tf.image.resize(image, size=[224, 224])
image = tf.image.random_flip_left_right(image)
image = tf.image.random_brightness(image, max_delta=0.2)
return image, label
适用于普通加载和tf.data
:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
paths = tf.data.Dataset.list_files('images/*.jpg')
path = next(iter(paths))
image, label = preprocessing(path)
plt.imshow(image)
plt.show()
filenames = tf.data.Dataset.list_files('images/*.jpg')
ds = filenames.map(preprocessing)
for image, label in ds.take(1):
plt.imshow(image)
plt.show()
推荐阅读
- javascript - 为什么我的 AJAX 请求会打开我的 PHP 文件,而不是只返回应有的响应文本?
- java - Kafka消费者 - 为什么在心跳线程中记录偏移重置?
- react-native - 当 Modal 显示时,React-Native 标题分别动画到 View
- php - 在 Woocommerce 我的帐户下载部分显示产品图片
- c# - 评估引用外部数据的布尔表达式
- google-apps-script - 如何通过从谷歌表格单元格中插入的链接获取其 URL,将谷歌驱动器文件附加到电子邮件
- javascript - 如何使用 jquery 旋钮通过 laravel 中的 pusher 绑定来自输入元素的更新值?
- node.js - verdaccio 错误:413 Payload Too Large - PUT 请求实体太大
- hibernate - 当策略为 IDENTITY 时,休眠 RX 插入和刷新和刷新返回异常
- machine-learning - 通常如何对 RNN/LSTM 的序列数据执行批处理