首页 > 解决方案 > 使用 tf.data 使用 PIL 打开图像

问题描述

我目前正在尝试使用 tf.data 加载 VOC2012 数据集进行语义分割。VOC2012中的标签使用颜色图,如果我使用PIL库会自动转换。当我调用 tf.read_file 时,情况并非如此。

from PIL import Image

train_data = tf.data.Dataset.from_tensor_slices((img_filename_list, lbl_filename_list))

def preprocessing(img_filename, lbl_filename):
    # Load image
    train_img = tf.read_file(img_path + img_filename)
    train_img = tf.image.decode_jpeg(train_img, channels=3)
    train_img = train_img / 255.0  # Normalize

    return train_img, lbl_filename

train_data = train_data.map(preprocessing).shuffle(100).repeat().batch(2)
iterator = train_data.make_initializable_iterator()
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(train_data)

with tf.Session() as sess:
    sess.run(training_init_op)
    while True:
        train_images, lbl_filename = sess.run(next_element)

这就是我现在正在做的事情,尽管理想情况下,我希望预处理函数返回使用 PIL 加载的标签图像,这样我就可以创建单热向量。

def preprocessing(img_filename, lbl_filename):
    ...# Load train images
    train_lbl = Image.open(lbl_path + lbl_filename)
    ...# Do some other stuff
    return train_img, train_lbl

这会出错

AttributeError: 'Tensor' object has no attribute 'read'

有什么解决办法吗?

标签: pythontensorflow

解决方案


正如@GPhilo 所建议的,使用 tf.py_func 可以解决这个问题。这是我的解决方案代码

def read_labels(lbl_filename):
    train_lbl = Image.open(lbl_path + lbl_filename.decode("utf-8"))
    train_lbl = np.asarray(train_lbl)
    return train_lbl

def preprocessing(img_filename, lbl_filename):
    train_lbl = tf.py_func(read_labels, [lbl_filename], tf.uint8)

推荐阅读