首页 > 解决方案 > 无法从 tensorflow 检查点读取以进行微调

问题描述

我正在尝试使用预训练的 BERT 模型对 SST2 数据处理器进行微调。但是当我给出预训练模型的检查点时,它显示“在检查点中找不到关键输出偏差”。

我认为这可能是由于预训练的 BERT 模型检查点中的错误。所以我又做了一次预训练。但是,我仍然面临同样的问题。

TASK = 'STS' #@param {type:\"string\"}
TASK_DATA_DIR = 'glue_data/STS-B/'# + TASK

output_dir = 'trained_model/observation'
tf.gfile.MakeDirs(output_dir)

BERT_MODEL = path + 'multi_cased_L-12_H-768_A-12/' 
VOCAB_FILE = os.path.join(BERT_MODEL, 'vocab.txt')   
CONFIG_FILE = os.path.join(BERT_MODEL, 'bert_config.json')   
INIT_CHECKPOINT = os.path.join(BERT_MODEL, 'bert_model.ckpt')   
DO_LOWER_CASE = BERT_MODEL.startswith('cased')

tokenizer = tokenization.FullTokenizer(vocab_file=VOCAB_FILE, 
do_lower_case=DO_LOWER_CASE)

TRAIN_BATCH_SIZE = 1   
EVAL_BATCH_SIZE = 8   
PREDICT_BATCH_SIZE = 8   
LEARNING_RATE = 2e-5   
NUM_TRAIN_EPOCHS = 3.0   
MAX_SEQ_LENGTH = 128   

processors = {   
    "sts": run_classifier.StsProcessor,    
}   

processor = processors[TASK.lower()]()    
label_list = processor.get_labels()   

错误是:

NotFoundError:从检查点恢复失败。这很可能是由于检查点中缺少变量名称或其他图形键。请确保您没有根据检查点更改预期的图表。原始错误:在检查点 [[node save/RestoreV2(定义在 /home/subraas3/.conda/envs/tensorflow_13/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py: 1403)]] [[节点保存/RestoreV2(定义在/home/subraas3/.conda/envs/tensorflow_13/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py:1403)]]

标签: pythonpython-3.xtensorflownlp

解决方案


正如错误消息中指出的那样,如果发生此错误

  1. tf 计算图中的一个层被重命名。即预训练检查点中的层名称与提供的 API 中的名称不同
  2. API 中增加了一个新层,即改变了网络架构。或者,
  3. 预训练检查点中存在的层被删除(不相似)。

请检查 bert API 的版本是否与预训练的检查点版本相同。如果它们相同,您可能需要使用此工具手动检查检查点中的 tf 图是否与 API 一致。


推荐阅读