首页 > 解决方案 > 如何为具有 yolo 输出层的多任务学习网络准备数据集?

问题描述

我有一个使用计算图的卷积神经网络,带有一个 Yolo 输出层和多个回归输出层(我只是将额外的输出层映射到一个典型的 Yolo CNN)。我遇到的问题是数据集,对于 Yolo 输出我有 pascal-voc.xml文件和回归输出.csv文件,即

物体检测数据

            InputSplit[] data = new FileSplit(image_dir, NativeImageLoader.ALLOWED_FORMATS, random).sample(path_filter, split_weight0,split_weight1);
            InputSplit trainData = data[0];
            InputSplit testData = data[1];

            ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels,gridH, gridW, new VocLabelProvider(DIR));
            ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels,gridH, gridW, new VocLabelProvider(DIR));
            recordReaderTrain.initialize(trainData);
            recordReaderTest.initialize(testData);

            //commented out since MultiDataSetIterator is meant to be used instead  
            //RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true);
            //RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
            //train.setPreProcessor(new ImagePreProcessingScaler(0, 1));
            //test.setPreProcessor(new ImagePreProcessingScaler(0, 1)); 

回归数据

 RecordReader r0= new CSVRecordReader(',');
 r0.initialize(new FileSplit(new File( DIR+"/r0.csv")));

 RecordReader r1=...
 ...

我尝试从RecordReaderMultiDataSetIterator 示例 2中实现多任务学习示例

MultiDataSetIterator train = new RecordReaderMultiDataSetIterator.Builder(batchSize)
                .addReader("rr", recordReaderTrain)
                .addReader("r0", r0)
                .addInput("???")//recordReaderTrain.getInputImagesData() ?
                .addOutput("rr")//recordReaderTrain.getVocLabelData() ?
                .addOutput("r0")
                .addOutput("r1")
                ... //addOutput -> r2,r3,...
                .build();

如何正确配置来自VocLabelProvider数据预处理的记录读取器的输入,或者是否可以环绕另一个数据集迭代器,即来自对象检测和回归文件MultiDataSetIterator的两个数据集迭代器的多数据集迭代器.xml.csv

标签: javayolodeeplearning4jdl4jcomputation-graph

解决方案


推荐阅读