Data conversion error when training on my own dataset (BDD100K) with the TensorFlow Object Detection API

Problem description

Problem in tensorflow/models/research/object_detection. The main code that converts the images and their annotations into a TFRecord file is as follows:

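import hashlib
import os

import ijson
import tensorflow as tf  # TF 1.x API (tf.python_io, tf.gfile)
# Assumed source of the feature helpers; they may instead be defined locally.
from object_detection.utils.dataset_util import (
    bytes_feature, bytes_list_feature, float_list_feature,
    int64_feature, int64_list_feature)

# json_path, output_path, img_folder, categories, label_id and nums are
# defined elsewhere in the original script.
writer = tf.python_io.TFRecordWriter(output_path)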
img_w, img_h = 1280, 720
with open(json_path, 'rb') as json_f:
    items = ijson.items(json_f, 'item')
    img_counter = 0
    skiped_img_counter = 0
    for item in items: # item is a dict, which contains a jpg image and its labels etc.
        img_counter += 1
        img_name = item['name']
        xmins = []
        ymins = []
        xmaxs = []
        ymaxs = []
        classes = []
        labels = []
        occluded = []
        truncated = []

        labels_ = item['labels']
        for label in labels_:
            category = label['category']
            if category in categories:
                nums[category] += 1
                labels.append(label_id[category])
                classes.append(category.encode('utf8'))
                att_ = label['attributes']
                occluded.append(int(att_['occluded'] == 'true'))
                truncated.append(int(att_['truncated'] == 'true'))
                box2d = label['box2d']
                xmins.append(float(box2d['x1'])/img_w) 
                ymins.append(float(box2d['y1'])/img_h) 
                xmaxs.append(float(box2d['x2'])/img_w)
                ymaxs.append(float(box2d['y2'])/img_h)
        difficult_obj = [0] * len(xmins)
        if 0 == len(xmins):
            skiped_img_counter += 1
            print("{0} has no object, skip it and continue.".format(img_name))
            continue
        assert len(xmins) == len(labels) == len(classes) == len(difficult_obj) == len(occluded) == len(truncated), 'not same list length'
        img_path = os.path.join(img_folder, img_name)
        with tf.gfile.GFile(img_path, 'rb') as fid:
            encoded_jpg = fid.read()
        key = hashlib.sha256(encoded_jpg).hexdigest()
        # att = item['attributes']
        # weather, scene, timeofday = att['weather'], att['scene'], att['timeofday']
        tf_example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': int64_feature(img_h),
            'image/width': int64_feature(img_w),
            'image/filename': bytes_feature(img_name.encode('utf8')),
            'image/source_id': bytes_feature(img_name.encode('utf8')),
            'image/key/sha256': bytes_feature(key.encode('utf8')),
            'image/encoded': bytes_feature(encoded_jpg),
            'image/format': bytes_feature('jpg'.encode('utf8')),
            'image/object/bbox/xmin': float_list_feature(xmins),
            'image/object/bbox/xmax': float_list_feature(xmaxs),
            'image/object/bbox/ymin': float_list_feature(ymins),
            'image/object/bbox/ymax': float_list_feature(ymaxs),
            'image/object/bbox/text': bytes_list_feature(classes),
            'image/object/bbox/label': int64_list_feature(labels),
            'image/object/bbox/difficult': int64_list_feature(difficult_obj),
            'image/object/bbox/occluded': int64_list_feature(occluded),
            'image/object/bbox/truncated': int64_list_feature(truncated),
        }))
        print(img_name, 'processed.')
        writer.write(tf_example.SerializeToString())
    print('{0} images were processed and {1} were skipped.'.format(img_counter, skiped_img_counter))
    print(nums)
    writer.close()
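Before writing each example, it can also be worth sanity-checking the per-object lists built for one image. The following is a minimal sketch of such a check and is not part of the original script; `check_object_lists` is a hypothetical helper. It assumes the boxes have already been normalized by the image width and height, as the loop above does, and it treats zero-width or zero-height boxes as suspicious.

```python
# Hypothetical helper, not part of the original script: checks the per-object
# lists built for one image before they are serialized.
def check_object_lists(xmins, ymins, xmaxs, ymaxs, labels, classes):
    """Return a list of human-readable problems (empty if none were found)."""
    problems = []
    lengths = {len(xmins), len(ymins), len(xmaxs), len(ymaxs),
               len(labels), len(classes)}
    if len(lengths) != 1:
        problems.append(f"per-object lists have different lengths: {sorted(lengths)}")
        return problems  # per-object checks below assume equal lengths
    for i, (x0, y0, x1, y1) in enumerate(zip(xmins, ymins, xmaxs, ymaxs)):
        # Coordinates are expected to be normalized to [0, 1].
        if not all(0.0 <= v <= 1.0 for v in (x0, y0, x1, y1)):
            problems.append(f"object {i}: coordinates outside [0, 1]")
        # Zero-area or inverted boxes are flagged as suspicious.
        if x0 >= x1 or y0 >= y1:
            problems.append(f"object {i}: min coordinate is not smaller than max")
    return problems
```

It could be called right before `writer.write(...)`, skipping or logging any image for which it returns a non-empty list.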

The error is as follows:

tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0] = 0 is not in [0, 0) [[{{node GatherV2_2}} = GatherV2[Taxis=DT_INT32, Tindices=DT_INT64, Tparams=DT_INT64, _device="/device:CPU:0"](cond_1/Merge, Reshape_8, GatherV2_1/axis)]] [[{{node IteratorGetNext}} = IteratorGetNextoutput_shapes=[[8], [8,300,300,3], [8,2], [8,3], [8,100], [8,100,4], [8,100,7], [8,100,7], [8,100], [8,100], [8,100], [8]], output_types=[DT_INT32, DT_FLOAT, DT_INT32, DT_INT32, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32, DT_BOOL, DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]
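The message says that a gather op tried to read index 0 from a tensor with zero elements. One plausible reading: the decoder looks for a key it expects (for example the class labels), does not find it, and yields an empty tensor, while the boxes are non-empty, so a later op that gathers one entry per box fails. The sketch below is a simplified pure-Python model of that failure mode under that assumption; `parse_int_list` and `gather` are hypothetical stand-ins, not TensorFlow functions.

```python
# Simplified pure-Python model (NOT TensorFlow's actual code) of how a missing
# feature key can end up as "indices[0] = 0 is not in [0, 0)".

def parse_int_list(features, key):
    # Like a variable-length feature absent from the record, a missing key
    # decodes to an empty list instead of raising an error.
    return features.get(key, [])

def gather(params, indices):
    # Bounds check modeled on the one reported by GatherV2 in the error above.
    for i, idx in enumerate(indices):
        if not 0 <= idx < len(params):
            raise IndexError(f"indices[{i}] = {idx} is not in [0, {len(params)})")
    return [params[idx] for idx in indices]

# Record written with the wrong key name ('bbox/label' instead of 'class/label').
record = {
    "image/object/bbox/xmin": [0.1],
    "image/object/bbox/label": [3],
}
classes = parse_int_list(record, "image/object/class/label")  # empty list
num_boxes = len(record["image/object/bbox/xmin"])             # one box
try:
    gather(classes, list(range(num_boxes)))
except IndexError as err:
    print(err)  # indices[0] = 0 is not in [0, 0)
```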

Tags: python-3.x, tensorflow, object-detection, tfrecord

Solution


These features in the tf.train.Example:

'image/object/bbox/text': bytes_list_feature(classes),
'image/object/bbox/label': int64_list_feature(labels),
'image/object/bbox/difficult': int64_list_feature(difficult_obj),
'image/object/bbox/occluded': int64_list_feature(occluded),
'image/object/bbox/truncated': int64_list_feature(truncated),

should be:

'image/object/class/text': bytes_list_feature(classes),
'image/object/class/label': int64_list_feature(labels),
'image/object/difficult': int64_list_feature(difficult_obj),
'image/object/occluded': int64_list_feature(occluded),
'image/object/truncated': int64_list_feature(truncated),   
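The same correction can be expressed as a key-renaming table applied to the Python feature dict before it is wrapped in `tf.train.Features`. This is a minimal sketch; `KEY_FIXES` and `rename_feature_keys` are hypothetical names, and the mapping only covers the five keys listed above.

```python
# Hypothetical mapping from the misspelled keys in the question to the
# corrected keys listed above.
KEY_FIXES = {
    "image/object/bbox/text": "image/object/class/text",
    "image/object/bbox/label": "image/object/class/label",
    "image/object/bbox/difficult": "image/object/difficult",
    "image/object/bbox/occluded": "image/object/occluded",
    "image/object/bbox/truncated": "image/object/truncated",
}

def rename_feature_keys(feature_dict, fixes=KEY_FIXES):
    """Return a new dict with misspelled keys replaced; values are untouched."""
    return {fixes.get(key, key): value for key, value in feature_dict.items()}
```

Usage would be along the lines of `tf.train.Features(feature=rename_feature_keys(feature_dict))`; any TFRecord files already written with the old keys still need to be regenerated.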

This error was caused by my own carelessness: every key in the example's features should be one of the keys defined in core.standard_fields.TfExampleFields.
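A cheap guard against this kind of typo is to compare the feature keys against the expected set before writing. The sketch below hard-codes a subset of the key strings from `TfExampleFields` so it runs without the Object Detection API installed; `EXPECTED_KEYS` and `unknown_keys` are hypothetical names, and the set should be checked against the version of `standard_fields.py` you actually use. Note that extra keys are not errors in themselves; the point is that a misspelled key shows up here while the correctly spelled one is silently missing.

```python
# Subset of the key strings defined in object_detection/core/standard_fields.py
# (class TfExampleFields), copied so the check runs standalone. With the API
# installed, a fuller set could be built from the class itself, e.g.
#   {v for k, v in vars(standard_fields.TfExampleFields).items()
#    if not k.startswith("_") and isinstance(v, str)}
EXPECTED_KEYS = {
    "image/encoded", "image/format", "image/filename", "image/source_id",
    "image/key/sha256", "image/height", "image/width",
    "image/object/bbox/xmin", "image/object/bbox/xmax",
    "image/object/bbox/ymin", "image/object/bbox/ymax",
    "image/object/class/text", "image/object/class/label",
    "image/object/difficult", "image/object/occluded", "image/object/truncated",
}

def unknown_keys(feature_dict, expected=EXPECTED_KEYS):
    """Return, sorted, the keys of feature_dict that are not in `expected`."""
    return sorted(set(feature_dict) - expected)
```

Calling `unknown_keys(...)` on the feature dict from the question would report the five `image/object/bbox/...` keys that the answer renames.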

