python - 为 DocVQA 数据集应用 LayoutLMv2 提取特征时找到正确的开始和结束位置
问题描述
按照您的工作，我复现了代码，以在 DocVQA 数据集上训练 LayoutLMv2。但是我在编码数据集时遇到了问题。特别是，该实现无法在从图像中提取的标记中找到答案的开始和结束位置。
def subfinder(words_list, answer_list):
    """Locate ``answer_list`` as a contiguous sub-sequence of ``words_list``.

    Parameters
    ----------
    words_list : list of str
        Words extracted from the document image (OCR output).
    answer_list : list of str
        The answer text, split into words.

    Returns
    -------
    tuple
        ``(match, start_idx, end_idx)`` for the FIRST occurrence, where
        ``match`` is ``answer_list`` itself and the indices are inclusive
        word positions; ``(None, 0, 0)`` when the answer does not occur
        (or is empty).
    """
    # An empty answer can never be matched; guard it so answer_list[0]
    # below cannot raise IndexError.
    if not answer_list:
        return None, 0, 0
    span = len(answer_list)
    # Slide a window of `span` words over words_list. The cheap first-word
    # comparison short-circuits the slice comparison at most positions.
    # range is clamped so the slice is always full-length.
    for start in range(len(words_list) - span + 1):
        if words_list[start] == answer_list[0] and \
                words_list[start:start + span] == answer_list:
            # The original collected every match but only ever returned the
            # first one — return it immediately instead.
            return answer_list, start, start + span - 1
    return None, 0, 0
def read_ocr_annotation(file_path, shape):
    """Read OCR words and normalized bounding boxes from an annotation file.

    Two OCR JSON schemas are supported:

    * Azure Read API style: ``{"status": ..., "recognitionResults": [...]}``
      with absolute pixel coordinates given as an 8-value polygon
      ``[x1, y1, x2, y2, x3, y3, x4, y4]`` per word.
    * Textract-like style: ``{"WORD": [{"Text": ..., "Geometry": ...}]}``
      with coordinates relative to the image size.

    Parameters
    ----------
    file_path : str
        Path to the OCR JSON file.
    shape : tuple
        ``(width, height)`` of the original image; used to scale the
        relative coordinates of the second schema.

    Returns
    -------
    tuple
        ``(words, boxes)`` — parallel lists of word strings and normalized
        ``[x_min, y_min, x_max, y_max]`` boxes; ``([], [])`` when neither
        schema is present.
    """
    words_img = []
    boxes_img = []
    width, height = shape
    with open(file_path, 'r') as f:
        data = json.load(f)  # data = {"status": [], "recognitionResults": []}
    # Dispatch on the schema explicitly instead of the original bare
    # `except:`, which silently swallowed *every* error (including real
    # bugs) before falling through to the second schema.
    if 'recognitionResults' in data:
        for reg_result in data['recognitionResults']:
            for line in reg_result['lines']:
                for word_info in line['words']:
                    polygon = word_info['boundingBox']
                    # [0::2] / [1::2] select all x / all y coordinates of the
                    # polygon. The original used [1:-1:2], which excluded the
                    # final y value (index 7) and could under-estimate y_max.
                    xs = polygon[0::2]
                    ys = polygon[1::2]
                    words_img.append(word_info['text'])
                    boxes_img.append(normalize_bbox(
                        bbox=[np.min(xs), np.min(ys), np.max(xs), np.max(ys)],
                        width=reg_result['width'],
                        height=reg_result['height']))
    elif 'WORD' in data:
        for word in data['WORD']:
            text = word['Text']
            bbox = word['Geometry']['BoundingBox']
            # Convert relative (Left, Top, Width, Height) to absolute corners.
            bbox = [bbox['Left'] * width, bbox['Top'] * height,
                    (bbox['Left'] + bbox['Width']) * width,
                    (bbox['Top'] + bbox['Height']) * height]
            nl_bbox = normalize_bbox(bbox=bbox, width=width, height=height)
            words_img.append(text)
            boxes_img.append(nl_bbox)
    else:
        # Unknown schema: skip this file, mirroring the original fallback.
        print("! Ignore ", file_path)
        return [], []
    return (words_img, boxes_img)
def encode_dataset(examples, max_length=512):
    """Encode a batch of DocVQA examples for LayoutLMv2 extractive QA.

    Builds processor inputs (image + question + OCR words/boxes) and adds
    ``start_positions`` / ``end_positions`` token indices for the answer
    span. When no annotated answer can be matched in the OCR words, both
    positions fall back to the [CLS] token index (SQuAD convention).

    Parameters
    ----------
    examples : dict
        Batch with keys 'image', 'ocr_output_file', 'question', 'answers',
        'questionId' — presumably produced by datasets.map(batched=True);
        TODO confirm against the caller.
    max_length : int
        Maximum token sequence length for padding/truncation.
    """
    images = [Image.open(image_file).convert("RGB") for image_file in examples['image']]
    org_shapes = [img.size[0:2] for img in images]
    words = []
    bbox = []
    for i in range(len(images)):
        words_img, boxes_img = read_ocr_annotation(
            file_path=examples['ocr_output_file'][i], shape=org_shapes[i])
        words.append(words_img)
        bbox.append(boxes_img)
    questions = examples['question']
    encoding = processor(images, questions, words, bbox,
                         max_length=max_length, padding="max_length", truncation=True)
    # next, add start_positions and end_positions
    answers = examples['answers']
    # Size by the ACTUAL batch length: the last batch of a dataset is
    # usually smaller than BATCH_SIZE, so the original [0]*BATCH_SIZE
    # either over-allocated or raised IndexError on idx >= BATCH_SIZE.
    start_positions = [0] * len(answers)
    end_positions = [0] * len(answers)
    # for every example in the batch:
    for idx in range(len(answers)):
        cls_index = encoding.input_ids[idx].index(processor.tokenizer.cls_token_id)
        words_example = [word.lower() for word in words[idx]]
        # Try each annotated answer until one matches the OCR words.
        # Pre-initialize so an example with an EMPTY answers list does not
        # leave `match` unbound (the original raised NameError there).
        match, word_idx_start, word_idx_end = None, 0, 0
        for answer in answers[idx]:
            match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split())
            if match is not None:
                break
        if match is not None:
            sequence_ids = encoding.sequence_ids(idx)
            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1
            # End token index of the current span in the text.
            token_end_index = len(encoding.input_ids[idx]) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1
            word_ids = encoding.word_ids(idx)[token_start_index:token_end_index + 1]
            # Walk forward to the first token belonging to the start word.
            for id in word_ids:
                if id == word_idx_start:
                    start_positions[idx] = token_start_index
                    break
                else:
                    token_start_index += 1
            # Walk backward to the last token belonging to the end word.
            for id in word_ids[::-1]:
                if id == word_idx_end:
                    end_positions[idx] = token_end_index
                    break
                else:
                    token_end_index -= 1
        else:
            # Answer not found in OCR output: point the span at [CLS].
            start_positions[idx] = cls_index
            end_positions[idx] = cls_index
    encoding['start_positions'] = start_positions
    encoding['end_positions'] = end_positions
    encoding['question_id'] = examples['questionId']
    return encoding
您能否提供任何想法以尽可能地获得正确的注释?
解决方案
推荐阅读
- xml - 如何获取xml子节点中值的具体字数?
- github - 无法通过 Web UI 访问 github 版本的 repo
- windows - Flutter 运行时出错:java.lang.IllegalAccessError
- c++ - 为什么在将引用传递给类时这段代码会出现段错误?
- blueprint - 无法下载 mindmeld 蓝图“screening_app”
- reactjs - 在组合模型中将道具传递给孩子
- html - CSS:如何以一列覆盖另一列的方式组织两列?
- angular - 如何在茉莉花中对引导模式关闭功能进行单元测试 - Angular8
- javascript - 在 TypeScript 的过滤器和映射链中缩小对象属性的类型
- azure - 无法从 azure key vault 的 p12 文件中提取证书和密钥