Training
To train a graph convolutional network (GCN) model, run:
```bash
bash train_gcn.sh 0

# train_gcn.sh expands to:
SAVE_ID=$1
python train.py --id $SAVE_ID --seed 0 --prune_k 1 --lr 0.3 --no-rnn --num_epoch 100 --pooling max --mlp_layers 2 --pooling_l2 0.003
```
Model checkpoints and logs will be saved to ./saved_models/00.
To train a Contextualized GCN (C-GCN) model, run:
```bash
bash train_cgcn.sh 1

# train_cgcn.sh expands to:
SAVE_ID=$1
python train.py --id $SAVE_ID --seed 0 --prune_k 1 --lr 0.3 --rnn_hidden 200 --num_epoch 100 --pooling max --mlp_layers 2 --pooling_l2 0.003
```
Model checkpoints and logs will be saved to ./saved_models/01.
For details on the use of other parameters, such as the pruning distance k, please refer to train.py.
As the two commands show, C-GCN and GCN differ only in the RNN-related arguments: C-GCN adds an RNN layer (`--rnn_hidden 200`), while plain GCN disables it (`--no-rnn`).
Now let's walk through the main script, train.py:
```python
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='dataset/tacred')
parser.add_argument('--vocab_dir', type=str, default='dataset/vocab')
parser.add_argument('--emb_dim', type=int, default=300, help='Word embedding dimension.')
parser.add_argument('--ner_dim', type=int, default=30, help='NER embedding dimension.')
parser.add_argument('--pos_dim', type=int, default=30, help='POS embedding dimension.')
parser.add_argument('--hidden_dim', type=int, default=200, help='RNN hidden state size.')
parser.add_argument('--num_layers', type=int, default=2, help='Num of RNN layers.')
parser.add_argument('--input_dropout', type=float, default=0.5, help='Input dropout rate.')
parser.add_argument('--gcn_dropout', type=float, default=0.5, help='GCN layer dropout rate.')
parser.add_argument('--word_dropout', type=float, default=0.04, help='The rate at which randomly set a word to UNK.')
parser.add_argument('--topn', type=int, default=1e10, help='Only finetune top N word embeddings.')
parser.add_argument('--lower', dest='lower', action='store_true', help='Lowercase all words.')
parser.add_argument('--no-lower', dest='lower', action='store_false')
parser.set_defaults(lower=False)  # set_defaults() sets default values for arguments

parser.add_argument('--prune_k', default=-1, type=int, help='Prune the dependency tree to <= K distance off the dependency path; set to -1 for no pruning.')
parser.add_argument('--conv_l2', type=float, default=0, help='L2-weight decay on conv layers only.')
parser.add_argument('--pooling', choices=['max', 'avg', 'sum'], default='max', help='Pooling function type. Default max.')
parser.add_argument('--pooling_l2', type=float, default=0, help='L2-penalty for all pooling output.')
parser.add_argument('--mlp_layers', type=int, default=2, help='Number of output mlp layers.')
parser.add_argument('--no_adj', dest='no_adj', action='store_true', help="Zero out adjacency matrix for ablation.")  # zero out the adjacency matrix for the ablation study

parser.add_argument('--no-rnn', dest='rnn', action='store_false', help='Do not use RNN layer.')  # dest stores the value under args.rnn rather than args.no_rnn
parser.add_argument('--rnn_hidden', type=int, default=200, help='RNN hidden state size.')
parser.add_argument('--rnn_layers', type=int, default=1, help='Number of RNN layers.')
parser.add_argument('--rnn_dropout', type=float, default=0.5, help='RNN dropout rate.')

parser.add_argument('--lr', type=float, default=1.0, help='Applies to sgd and adagrad.')
parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate decay rate.')
parser.add_argument('--decay_epoch', type=int, default=5, help='Decay learning rate after this epoch.')
parser.add_argument('--optim', choices=['sgd', 'adagrad', 'adam', 'adamax'], default='sgd', help='Optimizer: sgd, adagrad, adam or adamax.')
parser.add_argument('--num_epoch', type=int, default=100, help='Number of total training epochs.')
parser.add_argument('--batch_size', type=int, default=50, help='Training batch size.')
parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')  # prevents exploding gradients
parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
parser.add_argument('--log', type=str, default='logs.txt', help='Write training log to file.')
parser.add_argument('--save_epoch', type=int, default=100, help='Save model checkpoints every k epochs.')
parser.add_argument('--save_dir', type=str, default='./saved_models', help='Root dir for saving models.')
parser.add_argument('--id', type=str, default='00', help='Model ID under which to save models.')
parser.add_argument('--info', type=str, default='', help='Optional info for the experiment.')

parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available())
parser.add_argument('--cpu', action='store_true', help='Ignore CUDA.')

parser.add_argument('--load', dest='load', action='store_true', help='Load pretrained model.')
parser.add_argument('--model_file', type=str, help='Filename of the pretrained model.')

args = parser.parse_args()
```
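The `--no-rnn` flag above relies on argparse's `dest`/`action='store_false'` pattern: the option's *absence* leaves `args.rnn` at `True`, so C-GCN is the default behavior. A minimal standalone sketch of that pattern:

```python
import argparse

parser = argparse.ArgumentParser()
# action='store_false' makes the default True; passing --no-rnn flips it to False.
# dest='rnn' means the value lands in args.rnn, not args.no_rnn.
parser.add_argument('--no-rnn', dest='rnn', action='store_false')

assert parser.parse_args([]).rnn is True             # RNN layer kept (C-GCN)
assert parser.parse_args(['--no-rnn']).rnn is False  # RNN layer dropped (plain GCN)
```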
CPU/GPU setup:
```python
torch.manual_seed(args.seed)  # seed the CPU RNG so results are reproducible
np.random.seed(args.seed)
random.seed(1234)  # note: Python's random module is seeded with a fixed 1234, not args.seed
if args.cpu:
    args.cuda = False
elif args.cuda:
    torch.cuda.manual_seed(args.seed)  # seed the current GPU
init_time = time.time()

# make opt
opt = vars(args)  # vars() returns the Namespace's attribute dict, so options can be read and extended by key
label2id = constant.LABEL_TO_ID  # relation label -> id (no_relation + 41 relation types)
opt['num_class'] = len(label2id)
```
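The `vars(args)` trick deserves a note: it returns the `Namespace`'s underlying `__dict__`, which is a live view, so writing `opt['num_class'] = ...` also adds the attribute to `args`. A tiny sketch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=1.0)
args = parser.parse_args([])

# vars() exposes the Namespace's attribute dict
opt = vars(args)
opt['num_class'] = 42  # mirrors opt['num_class'] = len(label2id) in train.py

assert opt['lr'] == 1.0
assert args.num_class == 42  # the dict is a live view of the Namespace
```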
Load the vocab and pretrained embeddings:
```python
# load vocab
vocab_file = opt['vocab_dir'] + '/vocab.pkl'
vocab = Vocab(vocab_file, load=True)
opt['vocab_size'] = vocab.size
emb_file = opt['vocab_dir'] + '/embedding.npy'
emb_matrix = np.load(emb_file)
assert emb_matrix.shape[0] == vocab.size
assert emb_matrix.shape[1] == opt['emb_dim']
```
The Vocab class:
```python
class Vocab(object):
    def __init__(self, filename, load=False, word_counter=None, threshold=0):
        if load:
            assert os.path.exists(filename), "Vocab file does not exist at " + filename
            # load from file and ignore all other params
            self.id2word, self.word2id = self.load(filename)
            self.size = len(self.id2word)
            print("Vocab size {} loaded from file".format(self.size))
        else:
            print("Creating vocab from scratch...")
            assert word_counter is not None, "word_counter is not provided for vocab creation."
            self.word_counter = word_counter
            if threshold > 1:
                # remove words that occur less than thres
                self.word_counter = dict([(k,v) for k,v in self.word_counter.items() if v >= threshold])
            self.id2word = sorted(self.word_counter, key=lambda k:self.word_counter[k], reverse=True)
            # add special tokens to the beginning
            self.id2word = [constant.PAD_TOKEN, constant.UNK_TOKEN] + self.id2word
            self.word2id = dict([(self.id2word[idx],idx) for idx in range(len(self.id2word))])
            self.size = len(self.id2word)
            self.save(filename)
            print("Vocab size {} saved to file {}".format(self.size, filename))

    def load(self, filename):
        with open(filename, 'rb') as infile:  # 'rb': open for reading in binary mode; 'wb': open for writing in binary mode
            id2word = pickle.load(infile)
        word2id = dict([(id2word[idx], idx) for idx in range(len(id2word))])
        return id2word, word2id

    def save(self, filename):
        if os.path.exists(filename):
            print("Overwriting old vocab file at " + filename)
            os.remove(filename)
        with open(filename, 'wb') as outfile:
            pickle.dump(self.id2word, outfile)
        return
```
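The core of this class can be seen in a standalone sketch that avoids the `constant` module (the `'<PAD>'`/`'<UNK>'` strings here are assumptions; the real code takes them from `constant.PAD_TOKEN` and `constant.UNK_TOKEN`). Note that only `id2word` is pickled; `word2id` is rebuilt on load:

```python
import os
import pickle
import tempfile
from collections import Counter

PAD_TOKEN, UNK_TOKEN = '<PAD>', '<UNK>'  # assumed stand-ins for the constant module

# build: filter by frequency threshold, sort by count, prepend special tokens
counter = Counter({'the': 5, 'cat': 3, 'sat': 1})
threshold = 2
counter = {k: v for k, v in counter.items() if v >= threshold}
id2word = [PAD_TOKEN, UNK_TOKEN] + sorted(counter, key=counter.get, reverse=True)
word2id = {w: i for i, w in enumerate(id2word)}
assert id2word == ['<PAD>', '<UNK>', 'the', 'cat']
assert word2id['the'] == 2

# save/load round-trip: pickle id2word, rebuild word2id from it
path = os.path.join(tempfile.mkdtemp(), 'vocab.pkl')
with open(path, 'wb') as f:
    pickle.dump(id2word, f)
with open(path, 'rb') as f:
    loaded = pickle.load(f)
assert loaded == id2word
```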
Load the training and dev sets:
```python
# load data
print("Loading data from {} with batch size {}...".format(opt['data_dir'], opt['batch_size']))
train_batch = DataLoader(opt['data_dir'] + '/train.json', opt['batch_size'], opt, vocab, evaluation=False)
dev_batch = DataLoader(opt['data_dir'] + '/dev.json', opt['batch_size'], opt, vocab, evaluation=True)
```
The DataLoader class:
```python
class DataLoader(object):
    """
    Load data from json files, preprocess and prepare batches.
    """
    def __init__(self, filename, batch_size, opt, vocab, evaluation=False):
        self.batch_size = batch_size
        self.opt = opt
        self.vocab = vocab
        self.eval = evaluation
        self.label2id = constant.LABEL_TO_ID

        with open(filename) as infile:
            data = json.load(infile)
        self.raw_data = data  # the raw json data
        data = self.preprocess(data, vocab, opt)

        # shuffle for training
        if not evaluation:
            indices = list(range(len(data)))
            random.shuffle(indices)
            data = [data[i] for i in indices]
        self.id2label = dict([(v,k) for k,v in self.label2id.items()])
        self.labels = [self.id2label[d[-1]] for d in data]
        # d[-1] of each preprocessed example is the relation id; map each back to its label string
        self.num_examples = len(data)

        # chunk into batches
        data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
        self.data = data
        print("{} batches created for {}".format(len(data), filename))

    def preprocess(self, data, vocab, opt):
        """ Preprocess the data and convert to ids. """
        processed = []
        for d in data:
            tokens = list(d['token'])
            if opt['lower']:
                tokens = [t.lower() for t in tokens]
            # anonymize tokens
            ss, se = d['subj_start'], d['subj_end']
            os, oe = d['obj_start'], d['obj_end']
            tokens[ss:se+1] = ['SUBJ-'+d['subj_type']] * (se-ss+1)  # replace entity words with typed placeholder tokens
            tokens[os:oe+1] = ['OBJ-'+d['obj_type']] * (oe-os+1)
            tokens = map_to_ids(tokens, vocab.word2id)
            pos = map_to_ids(d['stanford_pos'], constant.POS_TO_ID)  # POS tag -> id
            ner = map_to_ids(d['stanford_ner'], constant.NER_TO_ID)  # NER tag -> id
            deprel = map_to_ids(d['stanford_deprel'], constant.DEPREL_TO_ID)  # dependency relation -> id
            head = [int(x) for x in d['stanford_head']]  # 1-based index of each token's head in the dependency tree; 0 marks the root
            assert any([x == 0 for x in head])  # every sentence must contain the root
            l = len(tokens)
            subj_positions = get_positions(d['subj_start'], d['subj_end'], l)  # relative-position list, e.g. [-3,-2,-1,0,0,0,1,2,3]
            obj_positions = get_positions(d['obj_start'], d['obj_end'], l)  # the zeros mark the entity span
            subj_type = [constant.SUBJ_NER_TO_ID[d['subj_type']]]  # subject entity type -> id
            obj_type = [constant.OBJ_NER_TO_ID[d['obj_type']]]
            relation = self.label2id[d['relation']]  # relation label -> id
            processed += [(tokens, pos, ner, deprel, head, subj_positions, obj_positions, subj_type, obj_type, relation)]
        return processed

    def gold(self):
        """ Return gold labels as a list. """
        return self.labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index. """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)  # the number of examples in this chunk
        batch = list(zip(*batch))
        assert len(batch) == 10

        # sort all fields by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]  # token count of each example in the batch
        batch, orig_idx = sort_all(batch, lens)  # reorder examples by descending length

        # word dropout
        if not self.eval:
            words = [word_dropout(sent, self.opt['word_dropout']) for sent in batch[0]]
        else:
            words = batch[0]

        # convert to tensors
        words = get_long_tensor(words, batch_size)
        masks = torch.eq(words, 0)  # PAD positions hold id 0, so exactly those positions become 1 in the mask
        pos = get_long_tensor(batch[1], batch_size)
        ner = get_long_tensor(batch[2], batch_size)
        deprel = get_long_tensor(batch[3], batch_size)
        head = get_long_tensor(batch[4], batch_size)
        subj_positions = get_long_tensor(batch[5], batch_size)
        obj_positions = get_long_tensor(batch[6], batch_size)
        subj_type = get_long_tensor(batch[7], batch_size)
        obj_type = get_long_tensor(batch[8], batch_size)

        rels = torch.LongTensor(batch[9])

        return (words, masks, pos, ner, deprel, head, subj_positions,
                obj_positions, subj_type, obj_type, rels, orig_idx)

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

def map_to_ids(tokens, vocab):
    ids = [vocab[t] if t in vocab else constant.UNK_ID for t in tokens]
    return ids

def get_positions(start_idx, end_idx, length):
    """ Get subj/obj position sequence. """
    return list(range(-start_idx, 0)) + [0]*(end_idx - start_idx + 1) + \
        list(range(1, length-end_idx))

def get_long_tensor(tokens_list, batch_size):
    """ Convert list of list of tokens to a padded LongTensor. """
    token_len = max(len(x) for x in tokens_list)
    tokens = torch.LongTensor(batch_size, token_len).fill_(constant.PAD_ID)
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s)] = torch.LongTensor(s)
    return tokens

def sort_all(batch, lens):
    """ Sort all fields by descending order of lens, and return the original indices. """
    unsorted_all = [lens] + [range(len(lens))] + list(batch)
    sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
    # sorted() without a key compares the zipped tuples element-wise, so it sorts by the lengths first
    return sorted_all[2:], sorted_all[1]

def word_dropout(tokens, dropout):
    """ Randomly dropout tokens (IDs) and replace them with <UNK> tokens. """
    return [constant.UNK_ID if x != constant.UNK_ID and np.random.random() < dropout \
        else x for x in tokens]
```
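The helper functions above are easy to exercise in isolation. The sketch below copies `get_positions` and `get_long_tensor` verbatim (with `PAD_ID = 0` assumed, as in the TACRED constants) and shows how `torch.eq(words, 0)` recovers exactly the padded cells:

```python
import torch

PAD_ID = 0  # assumed value of constant.PAD_ID

def get_positions(start_idx, end_idx, length):
    """ Get subj/obj position sequence. """
    return list(range(-start_idx, 0)) + [0]*(end_idx - start_idx + 1) + \
        list(range(1, length-end_idx))

def get_long_tensor(tokens_list, batch_size):
    """ Convert list of list of tokens to a padded LongTensor. """
    token_len = max(len(x) for x in tokens_list)
    tokens = torch.LongTensor(batch_size, token_len).fill_(PAD_ID)
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s)] = torch.LongTensor(s)
    return tokens

# an entity spanning tokens 2..4 of an 8-token sentence: zeros mark the span
assert get_positions(2, 4, 8) == [-2, -1, 0, 0, 0, 1, 2, 3]

# shorter sentences are right-padded with PAD_ID
t = get_long_tensor([[4, 5, 6], [7, 8]], 2)
assert t.tolist() == [[4, 5, 6], [7, 8, 0]]
# torch.eq(t, 0) then marks exactly the padded cells
assert torch.eq(t, 0).long().tolist() == [[0, 0, 0], [0, 0, 1]]
```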
The POS tags mapped via `constant.POS_TO_ID` in `preprocess` have the following meanings:
```text
POS tag list:

CC    coordinating conjunction
CD    cardinal digit
DT    determiner
EX    existential there ("there is" ... think of it like "there exists")
FW    foreign word
IN    preposition/subordinating conjunction
JJ    adjective                        'big'
JJR   adjective, comparative           'bigger'
JJS   adjective, superlative           'biggest'
LS    list marker                      1)
MD    modal                            could, will
NN    noun, singular                   'desk'
NNS   noun, plural                     'desks'
NNP   proper noun, singular            'Harrison'
NNPS  proper noun, plural              'Americans'
PDT   predeterminer                    'all the kids'
POS   possessive ending                parent's
PRP   personal pronoun                 I, he, she
PRP$  possessive pronoun               my, his, hers
RB    adverb                           very, silently
RBR   adverb, comparative              better
RBS   adverb, superlative              best
RP    particle                         give up
TO    to                               go 'to' the store
UH    interjection                     errrrrrrrm
VB    verb, base form                  take
VBD   verb, past tense                 took
VBG   verb, gerund/present participle  taking
VBN   verb, past participle            taken
VBP   verb, sing. present, non-3rd     take
VBZ   verb, 3rd person sing. present   takes
WDT   wh-determiner                    which
WP    wh-pronoun                       who, what
WP$   possessive wh-pronoun            whose
WRB   wh-adverb                        where, when
```
An example of how `sort_all` works:
```python
c = [([4,5],[5],[6],[7],[8]), ([3,6],[3],[7],[9],[56])]  # suppose the batch holds two examples
batch = list(zip(*c))  # len(batch): 5, one tuple per field across examples
lens = [len(x) for x in batch[0]]  # [2, 2]
unsorted_all = [lens] + [range(len(lens))] + list(batch)
# [[2, 2], range(0, 2), ([4, 5], [3, 6]), ([5], [3]), ([6], [7]), ([7], [9]), ([8], [56])]
sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
# [[2, 2], [1, 0], [[3, 6], [4, 5]], [[3], [5]], [[7], [6]], [[9], [7]], [[56], [8]]]
return sorted_all[2:], sorted_all[1]  # (sorted fields, original indices)
```
Creating ./saved_models/00 or ./saved_models/01:
```python
model_id = opt['id'] if len(opt['id']) > 1 else '0' + opt['id']  # zero-pad single-character ids
model_save_dir = opt['save_dir'] + '/' + model_id
opt['model_save_dir'] = model_save_dir
helper.ensure_dir(model_save_dir, verbose=True)  # create the directory if it does not exist

def ensure_dir(d, verbose=True):
    if not os.path.exists(d):
        if verbose:
            print("Directory {} do not exist; creating...".format(d))
        os.makedirs(d)
```
Save the config and print the settings:
```python
# save config
helper.save_config(opt, model_save_dir + '/config.json', verbose=True)  # verbose=True prints progress details
vocab.save(model_save_dir + '/vocab.pkl')
file_logger = helper.FileLogger(model_save_dir + '/' + opt['log'], header="# epoch\ttrain_loss\tdev_loss\tdev_score\tbest_dev_score")

# print model info
helper.print_config(opt)
```
If `--load` was never passed, `opt['load']` is False and a GCN model is trained from scratch; otherwise a pretrained model is loaded directly.
```python
# model
if not opt['load']:
    trainer = GCNTrainer(opt, emb_matrix=emb_matrix)
else:
    # load pretrained model
    model_file = opt['model_file']
    print("Loading model from {}".format(model_file))
    model_opt = torch_utils.load_config(model_file)
    model_opt['optim'] = opt['optim']
    trainer = GCNTrainer(model_opt)
    trainer.load(model_file)
```
GCNTrainer inherits from the Trainer base class:
```python
class Trainer(object):
    def __init__(self, opt, emb_matrix=None):
        raise NotImplementedError

    def update(self, batch):
        raise NotImplementedError

    def predict(self, batch):
        raise NotImplementedError

    def update_lr(self, new_lr):
        torch_utils.change_lr(self.optimizer, new_lr)

    def load(self, filename):
        try:
            checkpoint = torch.load(filename)
        except BaseException:
            print("Cannot load model from {}".format(filename))
            exit()
        self.model.load_state_dict(checkpoint['model'])
        self.opt = checkpoint['config']

    def save(self, filename, epoch):
        params = {
            'model': self.model.state_dict(),
            'config': self.opt,
        }
        try:
            torch.save(params, filename)
            print("model saved to {}".format(filename))
        except BaseException:
            print("[Warning: Saving failed... continuing anyway.]")


def unpack_batch(batch, cuda):
    if cuda:
        inputs = [Variable(b.cuda()) for b in batch[:10]]  # the first ten fields of data[i:i+batch_size]
        labels = Variable(batch[10].cuda())
    else:
        inputs = [Variable(b) for b in batch[:10]]
        labels = Variable(batch[10])
    tokens = batch[0]
    head = batch[5]
    subj_pos = batch[6]
    obj_pos = batch[7]
    lens = batch[1].eq(0).long().sum(1).squeeze()
    # masks (batch[1]) are 1 at PAD positions, so counting the zeros per row gives each sentence's true length, e.g. [3, 5, 7, ..., 12]
    return inputs, labels, tokens, head, subj_pos, obj_pos, lens

class GCNTrainer(Trainer):
    def __init__(self, opt, emb_matrix=None):
        self.opt = opt
        self.emb_matrix = emb_matrix
        self.model = GCNClassifier(opt, emb_matrix=emb_matrix)
        self.criterion = nn.CrossEntropyLoss()
        self.parameters = [p for p in self.model.parameters() if p.requires_grad]
        if opt['cuda']:
            self.model.cuda()
            self.criterion.cuda()
        self.optimizer = torch_utils.get_optimizer(opt['optim'], self.parameters, opt['lr'])

    def update(self, batch):
        inputs, labels, tokens, head, subj_pos, obj_pos, lens = unpack_batch(batch, self.opt['cuda'])

        # step forward
        self.model.train()
        self.optimizer.zero_grad()
        logits, pooling_output = self.model(inputs)
        loss = self.criterion(logits, labels)
        # labels: the relation ids of the batch_size examples, e.g. [0, 1, 3, ..., 40]
        # logits: [batch_size, num_class]
        # l2 decay on all conv layers
        if self.opt.get('conv_l2', 0) > 0:
            loss += self.model.conv_l2() * self.opt['conv_l2']
        # l2 penalty on output representations
        if self.opt.get('pooling_l2', 0) > 0:
            loss += self.opt['pooling_l2'] * (pooling_output ** 2).sum(1).mean()
        loss_val = loss.item()
        # backward
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.opt['max_grad_norm'])  # gradient clipping
        self.optimizer.step()
        return loss_val

    def predict(self, batch, unsort=True):
        inputs, labels, tokens, head, subj_pos, obj_pos, lens = unpack_batch(batch, self.opt['cuda'])
        orig_idx = batch[11]
        # forward
        self.model.eval()
        logits, _ = self.model(inputs)
        loss = self.criterion(logits, labels)
        probs = F.softmax(logits, 1).data.cpu().numpy().tolist()
        predictions = np.argmax(logits.data.cpu().numpy(), axis=1).tolist()
        if unsort:
            _, predictions, probs = [list(t) for t in zip(*sorted(zip(orig_idx,\
                predictions, probs)))]
        return predictions, probs, loss.item()
```
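The `pooling_l2` term in `update` penalizes the squared L2 norm of the pooled sentence representation, averaged over the batch. A small numeric sketch of just that term (the tensor values are made up for illustration):

```python
import torch

# [batch_size, hidden_dim] pooled sentence representations (toy values)
pooling_output = torch.tensor([[1.0, 2.0],
                               [3.0, 0.0]])
pooling_l2 = 0.003  # the --pooling_l2 value used in the training commands

# per-example squared norms: [1+4, 9+0] = [5, 9]; batch mean = 7
penalty = pooling_l2 * (pooling_output ** 2).sum(1).mean()
assert abs(penalty.item() - 0.021) < 1e-6
```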
Some preparation before training:
```python
id2label = dict([(v,k) for k,v in label2id.items()])
dev_score_history = []
current_lr = opt['lr']

global_step = 0
global_start_time = time.time()
format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'
max_steps = len(train_batch) * opt['num_epoch']  # len(train_batch) is the number of training batches
```
Start training:
```python
# start training
for epoch in range(1, opt['num_epoch']+1):
    train_loss = 0
    for i, batch in enumerate(train_batch):  # batch: data[i:i+batch_size]
        start_time = time.time()
        global_step += 1
        loss = trainer.update(batch)
        train_loss += loss
        if global_step % opt['log_step'] == 0:
            duration = time.time() - start_time
            print(format_str.format(datetime.now(), global_step, max_steps, epoch,\
                opt['num_epoch'], loss, duration, current_lr))

    # eval on dev
    print("Evaluating on dev set...")
    predictions = []
    dev_loss = 0
    for i, batch in enumerate(dev_batch):
        preds, _, loss = trainer.predict(batch)
        predictions += preds
        dev_loss += loss
    predictions = [id2label[p] for p in predictions]
    train_loss = train_loss / train_batch.num_examples * opt['batch_size']  # avg loss per batch
    dev_loss = dev_loss / dev_batch.num_examples * opt['batch_size']

    dev_p, dev_r, dev_f1 = scorer.score(dev_batch.gold(), predictions)
    print("epoch {}: train_loss = {:.6f}, dev_loss = {:.6f}, dev_f1 = {:.4f}".format(epoch,\
        train_loss, dev_loss, dev_f1))
    dev_score = dev_f1
    file_logger.log("{}\t{:.6f}\t{:.6f}\t{:.4f}\t{:.4f}".format(epoch, train_loss, dev_loss, dev_score, max([dev_score] + dev_score_history)))

    # save
    model_file = model_save_dir + '/checkpoint_epoch_{}.pt'.format(epoch)
    trainer.save(model_file, epoch)
    if epoch == 1 or dev_score > max(dev_score_history):
        copyfile(model_file, model_save_dir + '/best_model.pt')
        print("new best model saved.")
        file_logger.log("new best model saved at epoch {}: {:.2f}\t{:.2f}\t{:.2f}"\
            .format(epoch, dev_p*100, dev_r*100, dev_score*100))
    if epoch % opt['save_epoch'] != 0:
        os.remove(model_file)

    # lr schedule
    if len(dev_score_history) > opt['decay_epoch'] and dev_score <= dev_score_history[-1] and \
            opt['optim'] in ['sgd', 'adagrad', 'adadelta']:
        current_lr *= opt['lr_decay']
        trainer.update_lr(current_lr)

    dev_score_history += [dev_score]
    print("")

print("Training ended with {} epochs.".format(epoch))
```
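The learning-rate schedule at the end of each epoch can be simulated in isolation: after `decay_epoch` epochs have passed, the LR is multiplied by `lr_decay` whenever the dev score fails to improve over the previous epoch (for sgd/adagrad/adadelta only). A sketch with a made-up dev-score trajectory:

```python
lr, lr_decay, decay_epoch = 0.3, 0.9, 5  # values from the training commands / defaults
dev_scores = [0.50, 0.55, 0.58, 0.60, 0.61, 0.61, 0.60, 0.62]  # hypothetical per-epoch F1

history = []
for score in dev_scores:
    # mirrors: len(dev_score_history) > opt['decay_epoch'] and dev_score <= dev_score_history[-1]
    if len(history) > decay_epoch and score <= history[-1]:
        lr *= lr_decay
    history.append(score)

# decayed once: at epoch 7 (0.60 <= 0.61); epoch 6's plateau is still within the grace period
assert abs(lr - 0.27) < 1e-9
```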
References:
pickle.dump and pickle.load: https://blog.csdn.net/weixin_38278334/article/details/82967813
Stanford dependency relation abbreviations: https://www.cnblogs.com/webbery/p/11357196.html
torch.eq: https://blog.csdn.net/gyt15663668337/article/details/95882646