python - 用于图像和 numpy 文件的 TensorFlow 管道
问题描述
我正在使用 tensorflow 2.0.0 并尝试设置一个有效的管道来输入约 90,000 个大小为(256、256、3)的 png 图像及其标签,这些标签是用于图像分割问题的大小为 numpy 的数组(256,256)。这些图像和标签不会完全加载到内存中。
数据存储在这样的目录中:
'C:/Users/user/Documents/data/ims/' #png images
'C:/Users/user/Documents/data/masks/' #img labels/masks
保存扩展名的文件名相同,例如“test1.png”和“test1.npy”是图像/标签对。
数据尚未分为训练、验证和测试子集。
我需要将图像和标签分成训练、验证和测试子集,并且有一种方法可以将数据输入模型进行训练。
我在这里遵循本指南,但无法弄清楚如何处理 get_label 函数中的 numpy 文件。
我想我可以编写一个函数,仅通过文件名将数据拆分为子集,然后通过提供的文件名动态加载批处理,但我无法弄清楚如何有效地做到这一点。
我目前正在这样做,这要么不起作用,因为文件太大或太慢,因为有很多文件要加载到内存中,这两者都不是可行的解决方案。
import tensorflow as tf
import numpy as np
import glob2 as glob
from imageio import imread
base = '/mnt/projects/CNN_Data/clean_data/'
image_path = sorted(glob.glob(base + 'ims/*.png'))
label_path = sorted(glob.glob(base + 'masks/*.npy'))
images = [imread(img).astype(np.float32)/255.0 for img in image_path]
labels = [np.load(path) for path in label_path]
编辑添加:
这是我在上面链接的 tensorflow 示例之后的尝试。它运行,但我无法得到我想要的 get_label。
import tensorflow as tf
import numpy as np
import os
AUTOTUNE = tf.data.experimental.AUTOTUNE
base = '/mnt/projects/CNN_Data/clean_data/'
list_ds = tf.data.Dataset.list_files(base + 'ims/*')
def get_label(file_path):
parts = tf.strings.split(file_path, os.path.sep)
parts[-2] == 'masks'
fname = tf.strings.split(parts[-1], '.')[0]
fname = tf.strings.join([fname, '.npy'])
parts[-1] == fname
return parts
def decode_img(img):
img = tf.image.decode_png(img, channels = 3)
img = tf.image.convert_image_dtype(img, tf.float32)
return img
def process_path(file_path):
label = get_label(file_path)
img = tf.io.read_file(file_path)
img = decode_img(img)
return img, label
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
解决方案
推荐阅读
- server - 为什么 Xampp 服务器提供内部服务器错误
- javascript - 如何安排每个季度的 api 调用?
- python - 无法从 LXML 获取标签
- python - 如何避免/消除表单字段上的完整性错误
- django - Django.session:如何确定 Django 用户会话的开始?
- laravel - 如何在雄辩的关系中显示具有至少一个属性的属性族?
- javascript - 如何将字符串数组从反应状态映射到选择框
- macos - Tomcat - 未正确关闭 - vjava.net.SocketException:错误的文件描述符 - macbook
- javascript - Javascript遍历类onmouseover
- javascript - 自动更改电子中的键盘输入