Finding the correct start and end positions when extracting features for LayoutLMv2 on the DocVQA dataset

Problem description

Following your work, I reproduced the code to train LayoutLMv2 on the DocVQA dataset, but I ran into a problem while encoding the dataset. Specifically, the implementation fails to find the exact start and end positions of the answer among the words extracted from the image by OCR.

import json

import numpy as np
from PIL import Image


def subfinder(words_list, answer_list):
    """Return the first exact occurrence of answer_list inside words_list,
    together with its start and end word indices."""
    matches = []
    start_indices = []
    end_indices = []
    for idx in range(len(words_list)):
        if words_list[idx] == answer_list[0] and words_list[idx:idx + len(answer_list)] == answer_list:
            matches.append(answer_list)
            start_indices.append(idx)
            end_indices.append(idx + len(answer_list) - 1)
    if matches:
        return matches[0], start_indices[0], end_indices[0]
    else:
        return None, 0, 0
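
For reference, a quick check of subfinder with made-up inputs shows the failure mode: the comparison is an exact string match, so any mismatch between the OCR tokens and the answer text (punctuation, tokenization, attached symbols) returns no match at all.

ocr_words = ["the", "total", "is", "$12.50", "usd"]
print(subfinder(ocr_words, "$12.50".split()))   # (['$12.50'], 3, 3)
print(subfinder(ocr_words, "12.50".split()))    # (None, 0, 0): "$12.50" != "12.50"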


def read_ocr_annotation(file_path, shape):
    """Read words and normalized bounding boxes from an OCR output file.

    Supports two formats: line/word results under 'recognitionResults'
    (Microsoft Read-style, with pixel-coordinate polygons) and word-level
    results under 'WORD' (Textract-style, with relative geometry).
    """
    words_img = []
    boxes_img = []
    width, height = shape
    with open(file_path, 'r') as f:
        data = json.load(f)   # data = {"status": [], "recognitionResults": []}

    if 'recognitionResults' in data:
        for reg_result in data['recognitionResults']:
            for line in reg_result['lines']:
                for word_info in line['words']:
                    # boundingBox is a flat polygon [x1, y1, ..., x4, y4]:
                    # even indices are x coordinates, odd indices are y.
                    polygon = word_info['boundingBox']
                    x_min = np.min(polygon[0::2])
                    y_min = np.min(polygon[1::2])
                    x_max = np.max(polygon[0::2])
                    y_max = np.max(polygon[1::2])
                    words_img.append(word_info['text'])
                    boxes_img.append(normalize_bbox(bbox=[x_min, y_min, x_max, y_max],
                                                    width=reg_result['width'],
                                                    height=reg_result['height']))
    elif 'WORD' in data:
        for word in data['WORD']:
            bbox = word['Geometry']['BoundingBox']
            # Geometry is relative to the page, so scale by the image size.
            bbox = [bbox['Left'] * width, bbox['Top'] * height,
                    (bbox['Left'] + bbox['Width']) * width,
                    (bbox['Top'] + bbox['Height']) * height]
            words_img.append(word['Text'])
            boxes_img.append(normalize_bbox(bbox=bbox, width=width, height=height))
    else:
        print("! Ignore", file_path)
        return [], []

    return words_img, boxes_img
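
The snippet calls normalize_bbox, which is not shown above. For completeness, here is the usual definition from the LayoutLM examples, which scales pixel coordinates into the 0-1000 range LayoutLMv2 expects; this is an assumption about the missing helper, not part of the original code.

def normalize_bbox(bbox, width, height):
    # Assumed helper, following the LayoutLM convention of
    # 0-1000 normalized coordinates.
    return [
        int(1000 * bbox[0] / width),
        int(1000 * bbox[1] / height),
        int(1000 * bbox[2] / width),
        int(1000 * bbox[3] / height),
    ]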


def encode_dataset(examples, max_length=512):

    images     = [Image.open(image_file).convert("RGB") for image_file in examples['image']]
    org_shapes = [img.size for img in images]   # PIL's size is already (width, height)

    words = []
    bbox  = []
    for i in range(len(images)):
        words_img, boxes_img = read_ocr_annotation(file_path=examples['ocr_output_file'][i], shape=org_shapes[i])
        words.append(words_img)
        bbox.append(boxes_img)

    questions = examples['question']
    encoding  = processor(images, text=questions, text_pair=words, boxes=bbox,
                          max_length=max_length, padding="max_length", truncation=True)

    # Next, add start_positions and end_positions. Size the lists by the
    # actual number of examples: the last batch can be smaller than BATCH_SIZE.
    answers = examples['answers']
    start_positions = [0] * len(answers)
    end_positions   = [0] * len(answers)

    # for every example in the batch:
    for idx in range(len(answers)):
        cls_index = encoding.input_ids[idx].index(processor.tokenizer.cls_token_id)

        words_example = [word.lower() for word in words[idx]]

        # Try each gold answer until one matches the OCR words exactly.
        match, word_idx_start, word_idx_end = None, 0, 0
        for answer in answers[idx]:
            match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split())
            if match is not None:
                break

        if match is not None:
            # sequence_ids is 0 for question tokens, 1 for document (OCR)
            # tokens, and None for special tokens.
            sequence_ids = encoding.sequence_ids(idx)

            # Start token index of the document span in the encoding.
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1

            # End token index of the document span in the encoding.
            token_end_index = len(encoding.input_ids[idx]) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            # word_ids maps each token back to the index of the OCR word it came from.
            word_ids = encoding.word_ids(idx)[token_start_index:token_end_index + 1]

            # Walk forward until the first token of the answer's start word.
            for word_id in word_ids:
                if word_id == word_idx_start:
                    start_positions[idx] = token_start_index
                    break
                else:
                    token_start_index += 1

            # Walk backward until the last token of the answer's end word.
            for word_id in word_ids[::-1]:
                if word_id == word_idx_end:
                    end_positions[idx] = token_end_index
                    break
                else:
                    token_end_index -= 1
        else:
            # No answer found among the OCR words: point to the [CLS] token.
            start_positions[idx] = cls_index
            end_positions[idx] = cls_index

    encoding['start_positions'] = start_positions
    encoding['end_positions']   = end_positions
    encoding['question_id']     = examples['questionId']

    return encoding
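
For context, this is roughly how the function is applied in batches; a minimal sketch assuming `dataset` is a Hugging Face Datasets split with the columns referenced above (image, ocr_output_file, question, answers, questionId):

# Usage sketch (assumed setup); batch_size should match BATCH_SIZE above.
encoded_dataset = dataset.map(
    encode_dataset,
    batched=True,
    batch_size=BATCH_SIZE,
    remove_columns=dataset.column_names,
)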

Can you offer any ideas for getting the correct annotations as reliably as possible?

Tags: python, pytorch, transformer, question-answering
