首页 > 技术文章 > 中文语义角色标注

little-horse 2021-12-31 21:37 原文

albert-crf for SRL(Semantic Role Labeling),中文语义角色标注

项目地址:https://github.com/jiangnanboy/albert_srl

概述

  自然语言的语义理解往往包括分析构成一个事件的行为、施事、受事等主要元素,以及其他附属元素(adjuncts),例如事件发生的时间、地点、方式等。在事件语义学(Event semantics)中,构成一个事件的各个元素被称为语义角色(Semantic Roles);而语义角色标注(Semantic Role Labeling)任务就是识别出一个句子中所有的事件及其组成元素的过程,例如:其行为(往往是句子中的谓词部分),施事,事件,地点等。下图中,例子中的一个事件“生产”被识别出来,而“生产”对应的施事“全行业”和受事“化肥二千七百二十万吨”以及附属的事件发生的时间“全年”被准确标注出来。语义角色标注可为许多下游任务提供支持,例如:更深层的语义分析(AMR Parsing,CCG Parsing等),任务型对话系统中的意图识别,事实类问答系统中的实体打分等。

 

语义角色标注标签集合:

Core semantic role Arg0 Agent, Experiencer
Arg1 Theme, Topic, Patient
Arg2 Recipient, Extent, Predicate
Arg3 Asset, Theme2, Recipient
Arg4 Beneficiary
Arg5 Destination
Adjunctive semantic role ArgM-ADV Adverbials(附加的)
ArgM-BNE Beneficiary(受益者)
ArgM-CND Condition(条件)
ArgM-DIR Direction(方向)
ArgM-DGR Degree(程度)
ArgM-EXT Extent(延展)
ArgM-TMP Temporal(时间)
ArgM-TPC Topic(主题)
ArgM-PRP Purpose or Reason(目的或原因)
ArgM-FRQ Frequency(频率)
ArgM-LOC Locative(方位)
ArgM-MNR Manner(方式)

方法

利用huggingface/transformers中的albert+crf进行中文语义角色标注

利用albert加载中文预训练模型,后接一个前馈分类网络,最后接一层crf。利用albert预训练模型进行fine-tune。

整个流程是:

  • 数据经albert后获取最后的隐层hidden_state=768
  • 将last hidden_state=768和谓词位置指示器(predicate_indicator)进行concatenation,最后维度是(768 + 1)经一层前馈网络进行分类
  • 将前馈网络的分类结果输入crf

 

数据说明

BIOES形式标注,见data/

训练数据示例如下,其中各列为是否语义谓词角色,每句仅有一个谓语动词为语义谓词,即每句中第二列取值为1的是谓词,其余都为0.

她 0 O
介 0 O
绍 0 O
说 0 O
, 0 O
全 0 B-ARG0
行 0 I-ARG0
业 0 E-ARG0
全 0 B-ARGM-TMP
年 0 E-ARGM-TMP
生 1 B-REL
产 1 E-REL
化 0 B-ARG1
肥 0 I-ARG1
二 0 I-ARG1
千 0 I-ARG1
七 0 I-ARG1
百 0 I-ARG1
二 0 I-ARG1
十 0 I-ARG1
万 0 I-ARG1
吨 0 E-ARG1

训练和预测见(examples/test_srl.py)

    srl = SRL(args)
    if train_bool(args.train):
        srl.train()
        '''
        epoch: 45, acc_loss: 46.13663248741068
        dev_score: 0.931195765914934
        val_loss: 50.842400789260864, best_val_loss: 50.84240078926086
        '''
    else:
        srl.load()
        # ner.test(args.test_path)
        text = '代表朝方对中国党政领导人和人民哀悼金日成主席逝世表示深切谢意'
        predicates_indicator = [0, 0, 0, 0, 0, 0, 0, 0, 0,0,0,0,0,0,0,1,1, 0, 0, 0, 0, 0, 0, 0,0,0,0,0,0,0]
        assert len(text) == len(predicates_indicator)
        pprint(srl.predict(text, predicates_indicator))    
        # tag_predict: ['O', 'O', 'O', 'O', 'O', 'B-ARG0', 'I-ARG0', 'I-ARG0', 'I-ARG0', 'I-ARG0', 'I-ARG0', 'I-ARG0', 'I-ARG0', 'I-ARG0', 'E-ARG0', 'B-REL', 'E-REL', 'B-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'E-ARG1', 'O', 'O', 'O', 'O', 'O', 'O']
        # {'ARG0': '中国党政领导人和人民', 'ARG1': '金日成主席逝世', 'REL': '哀悼'}

输入数据的处理

项目中srl/dataset主要是将数据处理成模型所需数据的输入格式:

这里需要注意的是predicates和label_mask的生产方式和作用。predicates主要是为后面albert的输出进行cat,而label_mask不同于attention_mask,它是将“谓词[SEP]”进行mask进入crf层。

    def __getitem__(self, idx):
        words, tags, predicates = self.data_list[idx]
        input = ''.join(words)
        predicates_non_index = np.nonzero(predicates)[0]
        predicate = words[predicates_non_index[0] : predicates_non_index[-1] + 1]
        predicate_tag = tags[predicates_non_index[0] : predicates_non_index[-1] + 1]
        predicate = ''.join(predicate)
        input = input + self.SPECIAL_TOKENS['sep_token'] + predicate # [cls] + input + [sep] + predicate + [sep]

        n_pad = self.max_length - (len(tags) + len(predicate_tag)) # padding or truncation for label
        if n_pad < 0:
            tags = tags[:len(tags) - (abs(n_pad) + 3)]
            tags = [self.SPECIAL_TOKENS['cls_token']] + tags + [self.SPECIAL_TOKENS['sep_token']] + predicate_tag + [
                self.SPECIAL_TOKENS['sep_token']]
            predicates = predicates[:len(predicates) - (abs(n_pad) + 3)]
            predicates = [0] + predicates + [0] + [0] * len(predicate_tag) + [0]
        elif n_pad == 0:
            tags = tags[:len(tags) - 3]
            tags = [self.SPECIAL_TOKENS['cls_token']] + tags + [self.SPECIAL_TOKENS['sep_token']] + predicate_tag + [self.SPECIAL_TOKENS['sep_token']]
            predicates = predicates[:len(predicates) - 3]
            predicates = [0] + predicates + [0] + [0] * len(predicate_tag) + [0]
        elif n_pad == 1:
            tags = tags[:len(tags) - 2]
            tags = [self.SPECIAL_TOKENS['cls_token']] + tags+ [self.SPECIAL_TOKENS['sep_token']] + predicate_tag + [self.SPECIAL_TOKENS['sep_token']]
            predicates = predicates[:len(predicates) - 2]
            predicates = [0] + predicates + [0] + [0] * len(predicate_tag) + [0]
        elif n_pad == 2:
            tags = tags[:len(tags) - 1]
            tags = [self.SPECIAL_TOKENS['cls_token']] + tags+ [self.SPECIAL_TOKENS['sep_token']] + predicate_tag + [self.SPECIAL_TOKENS['sep_token']]
            predicates = predicates[:len(predicates) - 1]
            predicates = [0] + predicates + [0] + [0] * len(predicate_tag) + [0]
        else:
            tags = [self.SPECIAL_TOKENS['cls_token']] + tags + [self.SPECIAL_TOKENS['sep_token']] + predicate_tag + [self.SPECIAL_TOKENS['sep_token']]
            tags.extend([self.SPECIAL_TOKENS['pad_token']] * (self.max_length - len(tags)))
            predicates = [0] + predicates + [0] + [0] * len(predicate_tag) + [0]
            predicates.extend([0] * (self.max_length - len(predicates)))

        label = [self.label2i[token] for token in tags]
        label_mask = [1] * len(tags[:tags.index('[SEP]') + 1])
        label_mask.extend([0] * len(tags[tags.index('[SEP]') + 1:]))
        encodings_dict = self.tokenizer(input,
                                        truncation='only_first',
                                        max_length=self.max_length,
                                        padding='max_length')

        input_ids = encodings_dict['input_ids']
        token_type_ids = encodings_dict['token_type_ids']
        attention_mask = encodings_dict['attention_mask']

        return {'label': torch.tensor(label),
                'predicates': torch.FloatTensor(predicates),
                'input_ids': torch.tensor(input_ids),
                'token_type_ids': torch.tensor(token_type_ids),
                'attention_mask': torch.tensor(attention_mask),
                'label_mask': torch.tensor(label_mask)} # mask loss

  






推荐阅读