首页 > 技术文章 > 数据模块接口处理:读取不同的数据集

yuganwj 2020-09-24 14:57 原文

1.读取图片和XML文件

  和之前读取VOC_2007文件过程相同,只不过商品数据集对应只有8个类别,修改dataset.config中的参数就可以

# 商品数据集的类别
VOC_LABELS = {
     'none': (0, 'Background'),
     'clothes': (1, 'clothes'),
     'pants': (2, 'pants'),
     'shoes': (3, 'shoes'),
     'watch': (4, 'watch'),
     'phone': (5, 'phone'),
     'audio': (6, 'audio'),
     'computer': (7, 'computer'),
     'books': (8, 'books')
}

 

2.将图片和XML文件序列化后,读取该序列化文件

(1)设计一个读取数据的基类

dataset_utils.py
class TFRecordReaderBase(object):
'''
数据集读取基类,反序列化
'''
def __init__(self, param):
# param给不同数据集的属性配置
self.param = param
def get_data(self, train_or_test, dataset_dir):
'''
获取数据规范
:param train_or_test: train_or_test数据文件
:param dataset_dir:数据集目录
:return:
'''
return None

  (2)   继承基类,处理不同的数据集,其中的参数在dataset_config中进行设置

import os
import tensorflow as tf
slim = tf.contrib.slim
from datasets.utils import dataset_utils

class CommidityTFRecords(dataset_utils.TFRecordReaderBase):
'''
商品数据集读取类
'''
def __init__(self, param):
self.param = param
def get_data(self, train_or_test, dataset_dir):
'''
获取数据方法
'''
# 异常抛出
if train_or_test not in ['train', 'test']:
raise ValueError('训练/测试的名字 %s 指定错误'%train_or_test)
# 判断数据集目录
if not tf.gfile.Exists(dataset_dir):
raise ValueError('数据集目录不存在')


# 准备参数
# 第一个参数:数据目录+文件名
file_pattern = os.path.join(dataset_dir, self.param.FILE_PATTERN % train_or_test)
# 第二个参数 reader
reader = tf.TFRecordReader
# 第三个参数 decoder
# 反序列化
keys_to_features = {
'image/height': tf.FixedLenFeature([1], tf.int64),
'image/width': tf.FixedLenFeature([1], tf.int64),
'image/channels': tf.FixedLenFeature([1], tf.int64),
'image/shape': tf.FixedLenFeature([3], tf.int64),
'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/encoded': tf.FixedLenFeature((), tf.string, default_value='')
}
# 反序列化为高级的形式,用户可以直接使用的形式
items_to_handlers = {
'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
'shape': slim.tfexample_decoder.Tensor('image/shape'),
'object/bbox': slim.tfexample_decoder.BoundingBox(
['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
'object/labels': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated')
}
# 构造decoder
decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=self.param.SPLITS_TO_SIZES[train_or_test],
items_to_descriptions=self.param.ITEMS_TO_DESCRIPTIONS, # 数据集返回的格式描述字典
num_classes=self.param.NUM_CLASSES

)


'''
数据集读取配置
'''
# 创建命名字典
DataSetParams = namedtuple('DataSetParameters', ['FILE_PATTERN',
'NUM_CLASSES',
'SPLITS_TO_SIZES',
'ITEMS_TO_DESCRIPTIONS'])
# 定义commidity数据属性配置
Cm2018 = DataSetParams(
FILE_PATTERN = 'VOC_2007_%s_*.tfrecord',
NUM_CLASSES = 20,
SPLITS_TO_SIZES = {
'train': 100,
'test': 0
},
ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying height and width',
'shape': 'Shape of image',
'object/bbox': 'A list of bounding boxes, one per each object',
'object/label': 'A list of labels, one per each object'
}
)

(3)设置读取工厂,提供给外部人员调用

from datasets.dataset_init import commidity_2018
from datasets.dataset_config import Cm2018

datasetsmap = {
    'commidity_2018': commidity_2018.CommidityTFRecords
}
# 逻辑:
#1.数据集名称
#2.指定训练或者测试
#3.数据集目录指定
def get_dataset(dataset_name, train_or_test, dataset_dir):
    '''
    获取不同的数据
    :param dataset_name:数据集名称
    :param train_or_test: 训练或测试
    :param dataset_dir: 数据集目录
    :return: Dataset 数据规范
    '''
    if dataset_name not in datasetsmap:
        raise ValueError('输入的数据集名称%s不存在'%dataset_name)
    return datasetsmap[dataset_name](Cm2018).get_data(train_or_test, dataset_dir)

  (4)  通过读取工厂,即可读取不同的数据集

from datasets.dataset_init import passvalvoc_2007
import dataset_factory
import tensorflow as tf

slim = tf.contrib.slim

if __name__ == '__main__':
# 获取dataset
dataset = dataset_factory.get_dataset('commidity_2018', 'train', './DATA/tfrecord/commidity_TFRecords/')
# 通过provider取出数据
provider = slim.dataset_data_provider.DatasetDataProvider(
dataset,
num_readers = 3 # 读取线程数
)
# 通过get方法获取指定名称的数据(是在准备规范数据dataset时高级格式的名称)
[image, shape, bbox, label, diff, trunc] = provider.get(
['image', 'shape', 'object/bbox', 'object/labels', 'object/difficult', 'object/truncated']
)
print(image, shape, bbox, label, diff, trunc)









推荐阅读