[Code Walkthrough] Graph Convolution over Pruned Dependency Trees for Relation Extraction (2)

Harukaze 2021-01-12 23:34

Training

To train a graph convolutional neural network (GCN) model, run:

bash train_gcn.sh 0

# train_gcn.sh:
SAVE_ID=$1
python train.py --id $SAVE_ID --seed 0 --prune_k 1 --lr 0.3 --no-rnn --num_epoch 100 --pooling max --mlp_layers 2 --pooling_l2 0.003

Model checkpoints and logs will be saved to ./saved_models/00.

To train a Contextualized GCN (C-GCN) model, run:

bash train_cgcn.sh 1

# train_cgcn.sh:
SAVE_ID=$1
python train.py --id $SAVE_ID --seed 0 --prune_k 1 --lr 0.3 --rnn_hidden 200 --num_epoch 100 --pooling max --mlp_layers 2 --pooling_l2 0.003

Model checkpoints and logs will be saved to ./saved_models/01.

For details on the use of other parameters, such as the pruning distance k, please refer to train.py.
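
The pruning itself happens in the model's tree-handling code rather than in train.py. As a rough, self-contained illustration of what "prune the tree to <= K distance off the dependency path" means (this is only a sketch of the idea with a made-up sentence and head array, not the repo's actual implementation):

from collections import deque

def pruned_token_set(head, subj_idx, obj_idx, k):
    """Keep the tokens within k hops of the dependency path between the
    subject and the object. head[i] is the 1-based index of token i's parent
    in the dependency tree, 0 for the root (Stanford convention)."""
    n = len(head)
    adj = [[] for _ in range(n)]          # undirected adjacency list of the tree
    for i, h in enumerate(head):
        if h > 0:
            adj[i].append(h - 1)
            adj[h - 1].append(i)

    def bfs_dist(sources):
        dist = {s: 0 for s in sources}
        q = deque(sources)
        while q:
            u = q.popleft()
            for v in adj[u]:
                if v not in dist:
                    dist[v] = dist[u] + 1
                    q.append(v)
        return dist

    if k < 0:                              # --prune_k -1 means no pruning
        return set(range(n))
    d_subj = bfs_dist([subj_idx])
    path, cur = {obj_idx}, obj_idx         # walk from the object back towards the subject
    while cur != subj_idx:
        cur = min(adj[cur], key=lambda v: d_subj.get(v, n))
        path.add(cur)
    d_path = bfs_dist(list(path))          # distance of every token to the path
    return {i for i in range(n) if d_path.get(i, n) <= k}

# toy example: "He was born in Hawaii ."  (made-up parse)
head = [3, 3, 0, 5, 3, 3]                  # "born" (index 2) is the root
print(sorted(pruned_token_set(head, subj_idx=0, obj_idx=4, k=0)))   # [0, 2, 4] -> "He", "born", "Hawaii"

With k = 0 only the dependency path between the two entities survives; k = 1 additionally keeps tokens one hop off that path; --prune_k -1 keeps the full tree.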

As you can see, the C-GCN and GCN commands differ only in the RNN-related arguments: GCN is trained with --no-rnn, while C-GCN adds an RNN (LSTM) encoding layer via --rnn_hidden 200.

Now let's read through the main training script, train.py:

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='dataset/tacred')
parser.add_argument('--vocab_dir', type=str, default='dataset/vocab')
parser.add_argument('--emb_dim', type=int, default=300, help='Word embedding dimension.')
parser.add_argument('--ner_dim', type=int, default=30, help='NER embedding dimension.')
parser.add_argument('--pos_dim', type=int, default=30, help='POS embedding dimension.')
parser.add_argument('--hidden_dim', type=int, default=200, help='RNN hidden state size.')
parser.add_argument('--num_layers', type=int, default=2, help='Num of RNN layers.')
parser.add_argument('--input_dropout', type=float, default=0.5, help='Input dropout rate.')
parser.add_argument('--gcn_dropout', type=float, default=0.5, help='GCN layer dropout rate.')
parser.add_argument('--word_dropout', type=float, default=0.04, help='The rate at which randomly set a word to UNK.')
parser.add_argument('--topn', type=int, default=1e10, help='Only finetune top N word embeddings.')
parser.add_argument('--lower', dest='lower', action='store_true', help='Lowercase all words.')
parser.add_argument('--no-lower', dest='lower', action='store_false')
parser.set_defaults(lower=False)  # set_defaults() sets default values for arguments

parser.add_argument('--prune_k', default=-1, type=int, help='Prune the dependency tree to <= K distance off the dependency path; set to -1 for no pruning.')
parser.add_argument('--conv_l2', type=float, default=0, help='L2-weight decay on conv layers only.')
parser.add_argument('--pooling', choices=['max', 'avg', 'sum'], default='max', help='Pooling function type. Default max.')
parser.add_argument('--pooling_l2', type=float, default=0, help='L2-penalty for all pooling output.')
parser.add_argument('--mlp_layers', type=int, default=2, help='Number of output mlp layers.')
parser.add_argument('--no_adj', dest='no_adj', action='store_true', help="Zero out adjacency matrix for ablation.")  # ablation: zero out the adjacency matrix

parser.add_argument('--no-rnn', dest='rnn', action='store_false', help='Do not use RNN layer.')  # dest stores the flag under a different attribute name on args
parser.add_argument('--rnn_hidden', type=int, default=200, help='RNN hidden state size.')
parser.add_argument('--rnn_layers', type=int, default=1, help='Number of RNN layers.')
parser.add_argument('--rnn_dropout', type=float, default=0.5, help='RNN dropout rate.')

parser.add_argument('--lr', type=float, default=1.0, help='Applies to sgd and adagrad.')
parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate decay rate.')
parser.add_argument('--decay_epoch', type=int, default=5, help='Decay learning rate after this epoch.')
parser.add_argument('--optim', choices=['sgd', 'adagrad', 'adam', 'adamax'], default='sgd', help='Optimizer: sgd, adagrad, adam or adamax.')
parser.add_argument('--num_epoch', type=int, default=100, help='Number of total training epochs.')
parser.add_argument('--batch_size', type=int, default=50, help='Training batch size.')
parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')  # prevents exploding gradients
parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
parser.add_argument('--log', type=str, default='logs.txt', help='Write training log to file.')
parser.add_argument('--save_epoch', type=int, default=100, help='Save model checkpoints every k epochs.')
parser.add_argument('--save_dir', type=str, default='./saved_models', help='Root dir for saving models.')
parser.add_argument('--id', type=str, default='00', help='Model ID under which to save models.')
parser.add_argument('--info', type=str, default='', help='Optional info for the experiment.')

parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available())
parser.add_argument('--cpu', action='store_true', help='Ignore CUDA.')

parser.add_argument('--load', dest='load', action='store_true', help='Load pretrained model.')
parser.add_argument('--model_file', type=str, help='Filename of the pretrained model.')

args = parser.parse_args()
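
A quick reminder (a toy example, independent of the repo) of how dest, store_true/store_false and set_defaults interact in the flags above:

import argparse

p = argparse.ArgumentParser()
# the --lower / --no-lower pair writes to the same attribute, default False
p.add_argument('--lower', dest='lower', action='store_true')
p.add_argument('--no-lower', dest='lower', action='store_false')
p.set_defaults(lower=False)
# --no-rnn is stored as args.rnn; with store_false its default is True,
# so simply omitting the flag keeps the RNN enabled (which is why C-GCN needs no extra switch)
p.add_argument('--no-rnn', dest='rnn', action='store_false')

print(p.parse_args([]))              # lower=False, rnn=True
print(p.parse_args(['--lower']))     # lower=True,  rnn=True
print(p.parse_args(['--no-rnn']))    # lower=False, rnn=False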

Some CPU/GPU setup:

torch.manual_seed(args.seed)  # seed the CPU RNG so that results are reproducible
np.random.seed(args.seed)
random.seed(1234)
if args.cpu:
    args.cuda = False
elif args.cuda:
    torch.cuda.manual_seed(args.seed)  # seed the current GPU
init_time = time.time()

# make opt
opt = vars(args)  # vars() exposes the parsed arguments as a dict, so options can be read and written by key
label2id = constant.LABEL_TO_ID   # relation label -> id (no_relation + 41 relations)
opt['num_class'] = len(label2id)

Load the vocabulary:

# load vocab
vocab_file = opt['vocab_dir'] + '/vocab.pkl'
vocab = Vocab(vocab_file, load=True)
opt['vocab_size'] = vocab.size
emb_file = opt['vocab_dir'] + '/embedding.npy'
emb_matrix = np.load(emb_file)
assert emb_matrix.shape[0] == vocab.size
assert emb_matrix.shape[1] == opt['emb_dim']
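
vocab.pkl and embedding.npy are aligned row by row: word id i in the vocab corresponds to row i of the embedding matrix, with PAD and UNK occupying rows 0 and 1 (see the Vocab class below). A tiny stand-alone sketch of that lookup with made-up data:

import numpy as np

word2id = {'<PAD>': 0, '<UNK>': 1, 'hawaii': 2, 'obama': 3}    # toy stand-in for vocab.word2id
emb_matrix = np.random.randn(len(word2id), 300)                # in the repo this comes from embedding.npy

def word_vector(word):
    idx = word2id.get(word.lower(), word2id['<UNK>'])          # OOV words fall back to UNK, like map_to_ids
    return emb_matrix[idx]

print(word_vector('Hawaii').shape)                             # (300,)
print(np.array_equal(word_vector('xyzzy'), emb_matrix[1]))     # True: unseen word -> UNK row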

The Vocab() class:

class Vocab(object):
    def __init__(self, filename, load=False, word_counter=None, threshold=0):
        if load:
            assert os.path.exists(filename), "Vocab file does not exist at " + filename
            # load from file and ignore all other params
            self.id2word, self.word2id = self.load(filename)
            self.size = len(self.id2word)
            print("Vocab size {} loaded from file".format(self.size))
        else:
            print("Creating vocab from scratch...")
            assert word_counter is not None, "word_counter is not provided for vocab creation."
            self.word_counter = word_counter
            if threshold > 1:
                # remove words that occur less than thres
                self.word_counter = dict([(k,v) for k,v in self.word_counter.items() if v >= threshold])
            self.id2word = sorted(self.word_counter, key=lambda k:self.word_counter[k], reverse=True)
            # add special tokens to the beginning
            self.id2word = [constant.PAD_TOKEN, constant.UNK_TOKEN] + self.id2word
            self.word2id = dict([(self.id2word[idx],idx) for idx in range(len(self.id2word))])
            self.size = len(self.id2word)
            self.save(filename)
            print("Vocab size {} saved to file {}".format(self.size, filename))

    def load(self, filename):
        with open(filename, 'rb') as infile:  # 'rb' = read binary, 'wb' = write binary
            id2word = pickle.load(infile)
            word2id = dict([(id2word[idx], idx) for idx in range(len(id2word))])
        return id2word, word2id

    def save(self, filename):
        if os.path.exists(filename):
            print("Overwriting old vocab file at " + filename)
            os.remove(filename)
        with open(filename, 'wb') as outfile:
            pickle.dump(self.id2word, outfile)
        return
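
A small usage sketch for the two Vocab modes (build-from-counter vs. load-from-disk); the counts, the toy_vocab.pkl filename and the example words are made up for illustration:

from collections import Counter

counter = Counter({'the': 120, 'born': 40, 'hawaii': 3, 'rare': 1})
# words are sorted by frequency and PAD/UNK are prepended, so PAD -> id 0, UNK -> id 1
vocab = Vocab('toy_vocab.pkl', load=False, word_counter=counter, threshold=2)
print(vocab.id2word)            # [PAD_TOKEN, UNK_TOKEN, 'the', 'born', 'hawaii'] ('rare' dropped by the threshold)
print(vocab.word2id['hawaii'])  # 4

# later runs simply reload the pickled id2word list
vocab2 = Vocab('toy_vocab.pkl', load=True)
assert vocab2.word2id == vocab.word2id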

Loading the training and dev sets:

# load data
print("Loading data from {} with batch size {}...".format(opt['data_dir'], opt['batch_size']))
train_batch = DataLoader(opt['data_dir'] + '/train.json', opt['batch_size'], opt, vocab, evaluation=False)
dev_batch = DataLoader(opt['data_dir'] + '/dev.json', opt['batch_size'], opt, vocab, evaluation=True)

The DataLoader() class:

  1 class DataLoader(object):
  2     """
  3     Load data from json files, preprocess and prepare batches.
  4     """
  5     def __init__(self, filename, batch_size, opt, vocab, evaluation=False):
  6         self.batch_size = batch_size
  7         self.opt = opt
  8         self.vocab = vocab
  9         self.eval = evaluation
 10         self.label2id = constant.LABEL_TO_ID
 11 
 12         with open(filename) as infile:
 13             data = json.load(infile)
 14         self.raw_data = data   # keep the raw json data, unprocessed
 15         data = self.preprocess(data, vocab, opt)
 16 
 17         # shuffle for training
 18         if not evaluation:
 19             indices = list(range(len(data)))
 20             random.shuffle(indices)
 21             data = [data[i] for i in indices]
 22         self.id2label = dict([(v,k) for k,v in self.label2id.items()])
 23         self.labels = [self.id2label[d[-1]] for d in data]
 24         # d[-1] of every preprocessed example is the relation id; map them back to relation names to get the gold label list
 25         self.num_examples = len(data)
 26 
 27         # chunk into batches
 28         data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
 29         self.data = data
 30         print("{} batches created for {}".format(len(data), filename))
 31 
 32     def preprocess(self, data, vocab, opt):
 33         """ Preprocess the data and convert to ids. """
 34         processed = []
 35         for d in data:
 36             tokens = list(d['token'])
 37             if opt['lower']:
 38                 tokens = [t.lower() for t in tokens]
 39             # anonymize tokens
 40             ss, se = d['subj_start'], d['subj_end']
 41             os, oe = d['obj_start'], d['obj_end']
 42             tokens[ss:se+1] = ['SUBJ-'+d['subj_type']] * (se-ss+1)   # replace the entity tokens with their type label (anonymization)
 43             tokens[os:oe+1] = ['OBJ-'+d['obj_type']] * (oe-os+1)
 44             tokens = map_to_ids(tokens, vocab.word2id)
 45             pos = map_to_ids(d['stanford_pos'], constant.POS_TO_ID)   # POS tag -> id
 46             ner = map_to_ids(d['stanford_ner'], constant.NER_TO_ID)   # NER tag -> id
 47             deprel = map_to_ids(d['stanford_deprel'], constant.DEPREL_TO_ID)  # dependency relation label -> id
 48             head = [int(x) for x in d['stanford_head']]                 # head[i]: 1-based index of token i's parent in the dependency tree, 0 for the root
 49             assert any([x == 0 for x in head])                          # the sentence must contain a root (head == 0)
 50             l = len(tokens)
 51             subj_positions = get_positions(d['subj_start'], d['subj_end'], l)  # relative distance list, e.g. [-3,-2,-1,0,0,0,1,2,3]
 52             obj_positions = get_positions(d['obj_start'], d['obj_end'], l)     # the 0,0,0 span marks the entity itself
 53             subj_type = [constant.SUBJ_NER_TO_ID[d['subj_type']]]   # subject entity type -> id
 54             obj_type = [constant.OBJ_NER_TO_ID[d['obj_type']]]      # object entity type -> id
 55             relation = self.label2id[d['relation']]             # relation label -> id
 56             processed += [(tokens, pos, ner, deprel, head, subj_positions, obj_positions, subj_type, obj_type, relation)]
 57         return processed
 58 
 59     def gold(self):
 60         """ Return gold labels as a list. """
 61         return self.labels
 62 
 63     def __len__(self):
 64         return len(self.data)
 65 
 66     def __getitem__(self, key):
 67         """ Get a batch with index. """
 68         if not isinstance(key, int):
 69             raise TypeError
 70         if key < 0 or key >= len(self.data):
 71             raise IndexError
 72         batch = self.data[key]
 73         batch_size = len(batch)   # number of examples in this batch (the last batch may be smaller than opt['batch_size'])
 74         batch = list(zip(*batch))
 75         assert len(batch) == 10
 76 
 77         # sort all fields by lens for easy RNN operations
 78         lens = [len(x) for x in batch[0]]    # token count of every example in the batch
 79         batch, orig_idx = sort_all(batch, lens)  # sort the examples by length, longest first; orig_idx remembers the original order
 80 
 81         # word dropout
 82         if not self.eval:
 83             words = [word_dropout(sent, self.opt['word_dropout']) for sent in batch[0]]
 84         else:
 85             words = batch[0]
 86 
 87         # convert to tensors
 88         words = get_long_tensor(words, batch_size)
 89         masks = torch.eq(words, 0)  # PAD positions in the words matrix are 0, so the mask is 1 exactly at the padding
 90         pos = get_long_tensor(batch[1], batch_size)
 91         ner = get_long_tensor(batch[2], batch_size)
 92         deprel = get_long_tensor(batch[3], batch_size)
 93         head = get_long_tensor(batch[4], batch_size)
 94         subj_positions = get_long_tensor(batch[5], batch_size)
 95         obj_positions = get_long_tensor(batch[6], batch_size)
 96         subj_type = get_long_tensor(batch[7], batch_size)
 97         obj_type = get_long_tensor(batch[8], batch_size)
 98 
 99         rels = torch.LongTensor(batch[9])
100 
101         return (words, masks, pos, ner, deprel, head, subj_positions, obj_positions, subj_type, obj_type, rels, orig_idx)
102 
103     def __iter__(self):
104         for i in range(self.__len__()):
105             yield self.__getitem__(i)
106 
107 def map_to_ids(tokens, vocab):
108     ids = [vocab[t] if t in vocab else constant.UNK_ID for t in tokens]
109     return ids
110 
111 def get_positions(start_idx, end_idx, length):
112     """ Get subj/obj position sequence. """
113     return list(range(-start_idx, 0)) + [0]*(end_idx - start_idx + 1) + \
114             list(range(1, length-end_idx))
115 
116 def get_long_tensor(tokens_list, batch_size):
117     """ Convert list of list of tokens to a padded LongTensor. """
118     token_len = max(len(x) for x in tokens_list)
119     tokens = torch.LongTensor(batch_size, token_len).fill_(constant.PAD_ID)
120     for i, s in enumerate(tokens_list):
121         tokens[i, :len(s)] = torch.LongTensor(s)
122     return tokens
123 
124 def sort_all(batch, lens):
125     """ Sort all fields by descending order of lens, and return the original indices. """
126     unsorted_all = [lens] + [range(len(lens))] + list(batch)
127     sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
128     # sorted() without a key compares the zipped tuples element-wise, so it effectively sorts by lens first (then by the original index)
129     return sorted_all[2:], sorted_all[1]
130 
131 def word_dropout(tokens, dropout):
132     """ Randomly dropout tokens (IDs) and replace them with <UNK> tokens. """
133     return [constant.UNK_ID if x != constant.UNK_ID and np.random.random() < dropout \
134             else x for x in tokens]
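
To make the anonymization and the position features in preprocess() concrete, here is a tiny hand-built example (the sentence, spans and entity types are made up; only the helpers above are used):

tokens = ['Obama', 'was', 'born', 'in', 'Hawaii', '.']
ss, se = 0, 0        # subject span: 'Obama', type PERSON
os_, oe = 4, 4       # object span: 'Hawaii', type LOCATION

tokens[ss:se+1] = ['SUBJ-PERSON'] * (se-ss+1)     # same replacement as in preprocess()
tokens[os_:oe+1] = ['OBJ-LOCATION'] * (oe-os_+1)
print(tokens)
# ['SUBJ-PERSON', 'was', 'born', 'in', 'OBJ-LOCATION', '.']

print(get_positions(0, 0, len(tokens)))   # [0, 1, 2, 3, 4, 5]      0 at the subject
print(get_positions(4, 4, len(tokens)))   # [-4, -3, -2, -1, 0, 1]  0 at the object, negative to its left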

Line 45 of the listing above maps POS tags to ids; the tags have the following meanings:

POS tag list:

CC  coordinating conjunction
CD  cardinal digit
DT  determiner
EX  existential there (like: "there is" ... think of it like "there exists")
FW  foreign word
IN  preposition/subordinating conjunction
JJ  adjective   'big'
JJR adjective, comparative  'bigger'
JJS adjective, superlative  'biggest'
LS  list marker 1)
MD  modal   could, will
NN  noun, singular 'desk'
NNS noun plural 'desks'
NNP proper noun, singular   'Harrison'
NNPS    proper noun, plural 'Americans'
PDT predeterminer   'all the kids'
POS possessive ending   parent's
PRP personal pronoun    I, he, she
PRP$    possessive pronoun  my, his, hers
RB  adverb  very, silently,
RBR adverb, comparative better
RBS adverb, superlative best
RP  particle    give up
TO  to  go 'to' the store.
UH  interjection    errrrrrrrm
VB  verb, base form take
VBD verb, past tense    took
VBG verb, gerund/present participle taking
VBN verb, past participle   taken
VBP verb, sing. present, non-3d take
VBZ verb, 3rd person sing. present  takes
WDT wh-determiner   which
WP  wh-pronoun  who, what
WP$ possessive wh-pronoun   whose
WRB wh-adverb   where, when

A worked example for sort_all (lines 124-129 above):

c = [([4,5],[5],[6],[7],[8]), ([3,6],[3],[7],[9],[56])]  # suppose the batch holds two examples, each a tuple of 5 fields
batch = list(zip(*c))   # len(batch) == 5: each field is now grouped across the two examples
lens = [len(x) for x in batch[0]]    # [2, 2]
unsorted_all = [lens] + [range(len(lens))] + list(batch)
# [[2, 2], range(0, 2), ([4, 5], [3, 6]), ([5], [3]), ([6], [7]), ([7], [9]), ([8], [56])]
tmp = list(zip(*sorted(zip(*unsorted_all), reverse=True)))
# [(2, 2), (1, 0), ([3, 6], [4, 5]), ([3], [5]), ([7], [6]), ([9], [7]), ([56], [8])]
sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
# [[2, 2], [1, 0], [[3, 6], [4, 5]], [[3], [5]], [[7], [6]], [[9], [7]], [[56], [8]]]
return sorted_all[2:], sorted_all[1]   # inside sort_all this returns (the sorted fields, the original indices)

Creating ./saved_models/00 and ./saved_models/01:

model_id = opt['id'] if len(opt['id']) > 1 else '0' + opt['id']
model_save_dir = opt['save_dir'] + '/' + model_id
opt['model_save_dir'] = model_save_dir
helper.ensure_dir(model_save_dir, verbose=True)  # create the directory if it does not exist

# helper.ensure_dir:
def ensure_dir(d, verbose=True):
    if not os.path.exists(d):
        if verbose:
            print("Directory {} do not exist; creating...".format(d))
        os.makedirs(d)

Save the config and print the settings:

# save config
helper.save_config(opt, model_save_dir + '/config.json', verbose=True)  # verbose=True prints more detail
vocab.save(model_save_dir + '/vocab.pkl')
file_logger = helper.FileLogger(model_save_dir + '/' + opt['log'], header="# epoch\ttrain_loss\tdev_loss\tdev_score\tbest_dev_score")

# print model info
helper.print_config(opt)
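
helper.FileLogger is just a thin append-only logger that writes the header once and then one line per log() call; a minimal sketch of such a logger (an illustrative assumption, not necessarily the repo's exact code):

class FileLogger(object):
    def __init__(self, filename, header=None):
        self.filename = filename
        if header is not None:
            with open(filename, 'w') as out:      # start a fresh log with the header line
                print(header, file=out)

    def log(self, message):
        with open(self.filename, 'a') as out:     # append one line per call
            print(message, file=out)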

If --load was not passed, opt['load'] is False and a fresh GCN model is trained from scratch; otherwise a pretrained model is loaded directly.

# model
if not opt['load']:
    trainer = GCNTrainer(opt, emb_matrix=emb_matrix)
else:
    # load pretrained model
    model_file = opt['model_file']
    print("Loading model from {}".format(model_file))
    model_opt = torch_utils.load_config(model_file)
    model_opt['optim'] = opt['optim']
    trainer = GCNTrainer(model_opt)
    trainer.load(model_file)
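
Since Trainer.save() below stores a dict with 'model' and 'config' keys, torch_utils.load_config presumably just reads the saved opt back out of the checkpoint; a hedged sketch of that behaviour:

import torch

def load_config(filename):
    checkpoint = torch.load(filename)     # the dict written by Trainer.save()
    return checkpoint['config']           # the opt dict the model was trained with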

GCNTrainer() inherits from the Trainer base class:

class Trainer(object):
    def __init__(self, opt, emb_matrix=None):
        raise NotImplementedError

    def update(self, batch):
        raise NotImplementedError

    def predict(self, batch):
        raise NotImplementedError

    def update_lr(self, new_lr):
        torch_utils.change_lr(self.optimizer, new_lr)

    def load(self, filename):
        try:
            checkpoint = torch.load(filename)
        except BaseException:
            print("Cannot load model from {}".format(filename))
            exit()
        self.model.load_state_dict(checkpoint['model'])
        self.opt = checkpoint['config']

    def save(self, filename, epoch):
        params = {
                'model': self.model.state_dict(),
                'config': self.opt,
                }
        try:
            torch.save(params, filename)
            print("model saved to {}".format(filename))
        except BaseException:
            print("[Warning: Saving failed... continuing anyway.]")


def unpack_batch(batch, cuda):
    if cuda:
        inputs = [Variable(b.cuda()) for b in batch[:10]]  # the first 10 fields returned by DataLoader.__getitem__
        labels = Variable(batch[10].cuda())
    else:
        inputs = [Variable(b) for b in batch[:10]]
        labels = Variable(batch[10])
    tokens = batch[0]
    head = batch[5]
    subj_pos = batch[6]
    obj_pos = batch[7]
    lens = batch[1].eq(0).long().sum(1).squeeze()  # batch[1] is the padding mask (1 at PAD positions),
    # so counting its zeros row by row gives the true length of every sentence, e.g. [3, 5, 7, ..., 12]
    return inputs, labels, tokens, head, subj_pos, obj_pos, lens

class GCNTrainer(Trainer):
    def __init__(self, opt, emb_matrix=None):
        self.opt = opt
        self.emb_matrix = emb_matrix
        self.model = GCNClassifier(opt, emb_matrix=emb_matrix)
        self.criterion = nn.CrossEntropyLoss()
        self.parameters = [p for p in self.model.parameters() if p.requires_grad]
        if opt['cuda']:
            self.model.cuda()
            self.criterion.cuda()
        self.optimizer = torch_utils.get_optimizer(opt['optim'], self.parameters, opt['lr'])

    def update(self, batch):
        inputs, labels, tokens, head, subj_pos, obj_pos, lens = unpack_batch(batch, self.opt['cuda'])

        # step forward
        self.model.train()
        self.optimizer.zero_grad()
        logits, pooling_output = self.model(inputs)
        loss = self.criterion(logits, labels)
        # labels: [batch_size] relation ids for the examples in this batch
        # logits: [batch_size, num_class] classification scores
        # l2 decay on all conv layers
        if self.opt.get('conv_l2', 0) > 0:
            loss += self.model.conv_l2() * self.opt['conv_l2']
        # l2 penalty on output representations
        if self.opt.get('pooling_l2', 0) > 0:
            loss += self.opt['pooling_l2'] * (pooling_output ** 2).sum(1).mean()
        loss_val = loss.item()
        # backward
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.opt['max_grad_norm'])
        self.optimizer.step()
        return loss_val

    def predict(self, batch, unsort=True):
        inputs, labels, tokens, head, subj_pos, obj_pos, lens = unpack_batch(batch, self.opt['cuda'])
        orig_idx = batch[11]
        # forward
        self.model.eval()
        logits, _ = self.model(inputs)
        loss = self.criterion(logits, labels)
        probs = F.softmax(logits, 1).data.cpu().numpy().tolist()
        predictions = np.argmax(logits.data.cpu().numpy(), axis=1).tolist()
        if unsort:
            _, predictions, probs = [list(t) for t in zip(*sorted(zip(orig_idx,\
                    predictions, probs)))]
        return predictions, probs, loss.item()
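
The unsort step at the end of predict() undoes the length-sorting that sort_all() applied inside DataLoader.__getitem__, so that predictions line up with dev_batch.gold() again. A tiny demonstration of the zip/sort trick with made-up values:

orig_idx    = [2, 0, 1]                                        # original position of each sorted example
predictions = ['per:title', 'no_relation', 'org:founded']
probs       = [0.9, 0.8, 0.7]

_, predictions, probs = [list(t) for t in zip(*sorted(zip(orig_idx, predictions, probs)))]
print(predictions)   # ['no_relation', 'org:founded', 'per:title']  -- back in the original order
print(probs)         # [0.8, 0.7, 0.9]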

Some preparation before training:

id2label = dict([(v,k) for k,v in label2id.items()])
dev_score_history = []
current_lr = opt['lr']

global_step = 0
global_start_time = time.time()
format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'
max_steps = len(train_batch) * opt['num_epoch']    # len(train_batch) = number of training batches per epoch

Start training:

# start training
for epoch in range(1, opt['num_epoch']+1):
    train_loss = 0
    for i, batch in enumerate(train_batch):  # each batch is data[i:i+batch_size]
        start_time = time.time()
        global_step += 1
        loss = trainer.update(batch)
        train_loss += loss
        if global_step % opt['log_step'] == 0:
            duration = time.time() - start_time
            print(format_str.format(datetime.now(), global_step, max_steps, epoch,\
                    opt['num_epoch'], loss, duration, current_lr))

    # eval on dev
    print("Evaluating on dev set...")
    predictions = []
    dev_loss = 0
    for i, batch in enumerate(dev_batch):
        preds, _, loss = trainer.predict(batch)
        predictions += preds
        dev_loss += loss
    predictions = [id2label[p] for p in predictions]
    train_loss = train_loss / train_batch.num_examples * opt['batch_size'] # avg loss per batch
    dev_loss = dev_loss / dev_batch.num_examples * opt['batch_size']

    dev_p, dev_r, dev_f1 = scorer.score(dev_batch.gold(), predictions)
    print("epoch {}: train_loss = {:.6f}, dev_loss = {:.6f}, dev_f1 = {:.4f}".format(epoch,\
        train_loss, dev_loss, dev_f1))
    dev_score = dev_f1
    file_logger.log("{}\t{:.6f}\t{:.6f}\t{:.4f}\t{:.4f}".format(epoch, train_loss, dev_loss, dev_score, max([dev_score] + dev_score_history)))

    # save
    model_file = model_save_dir + '/checkpoint_epoch_{}.pt'.format(epoch)
    trainer.save(model_file, epoch)
    if epoch == 1 or dev_score > max(dev_score_history):
        copyfile(model_file, model_save_dir + '/best_model.pt')
        print("new best model saved.")
        file_logger.log("new best model saved at epoch {}: {:.2f}\t{:.2f}\t{:.2f}"\
            .format(epoch, dev_p*100, dev_r*100, dev_score*100))
    if epoch % opt['save_epoch'] != 0:
        os.remove(model_file)

    # lr schedule
    if len(dev_score_history) > opt['decay_epoch'] and dev_score <= dev_score_history[-1] and \
            opt['optim'] in ['sgd', 'adagrad', 'adadelta']:
        current_lr *= opt['lr_decay']
        trainer.update_lr(current_lr)

    dev_score_history += [dev_score]
    print("")

print("Training ended with {} epochs.".format(epoch))
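
scorer.score computes the micro-averaged precision/recall/F1 used for TACRED, where the no_relation class does not count as a positive prediction. A simplified sketch of that metric (the repo's scorer also prints per-relation statistics; this only shows the core computation under that assumption):

def micro_prf(gold, pred, negative_label='no_relation'):
    guessed = sum(1 for p in pred if p != negative_label)            # predicted positives
    actual  = sum(1 for g in gold if g != negative_label)            # gold positives
    correct = sum(1 for g, p in zip(gold, pred) if g == p and p != negative_label)
    prec = correct / guessed if guessed > 0 else 1.0
    rec  = correct / actual if actual > 0 else 0.0
    f1   = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0
    return prec, rec, f1

gold = ['no_relation', 'per:title', 'org:founded', 'per:title']
pred = ['per:title',   'per:title', 'no_relation', 'per:title']
print(micro_prf(gold, pred))   # (0.666..., 0.666..., 0.666...)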

References:

pickle.dump and pickle.load: https://blog.csdn.net/weixin_38278334/article/details/82967813

Stanford dependency relation abbreviations: https://www.cnblogs.com/webbery/p/11357196.html

torch.eq: https://blog.csdn.net/gyt15663668337/article/details/95882646
