BertForTokenClassification not loading

Problem description

I am trying to load a BERT model from a local directory, but it raises an error. I am using CUDA 10.0 and PyTorch 1.6.0.

Code used to load the model:

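import torch
from transformers import BertForTokenClassification, BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
output_dir = './ner_model/'
model = BertForTokenClassification.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir)
model.to(device)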

Any help would be appreciated. The full error output:

ReadError: invalid header

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
~\anaconda3\envs\env\lib\site-packages\transformers\modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    511             try:
--> 512                 state_dict = torch.load(resolved_archive_file, map_location="cpu")
    513             except Exception:

~\anaconda3\envs\env\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    385     try:
--> 386         return _load(f, map_location, pickle_module, **pickle_load_args)
    387     finally:

~\anaconda3\envs\env\lib\site-packages\torch\serialization.py in _load(f, map_location, pickle_module, **pickle_load_args)
    558                 # .zip is used for torch.jit.save and will throw an un-pickling error here
--> 559                 raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
    560             # if not a tarfile, reset file offset and proceed

RuntimeError: ./ner_model/pytorch_model.bin is a zip archive (did you mean to use torch.jit.load()?)

During handling of the above exception, another exception occurred:

OSError                                   Traceback (most recent call last)
<ipython-input-13-770da388c2c8> in <module>
     23 
     24 output_dir = './ner_model/'
---> 25 model = BertForTokenClassification.from_pretrained(output_dir)
     26 tokenizer = BertTokenizer.from_pretrained(output_dir)
     27 model.to(device)

~\anaconda3\envs\env\lib\site-packages\transformers\modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    513             except Exception:
    514                 raise OSError(
--> 515                     "Unable to load weights from pytorch checkpoint file. "
    516                     "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
    517                 )

OSError: Unable to load weights from pytorch checkpoint file. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.
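The RuntimeError in the middle of the traceback reports that pytorch_model.bin "is a zip archive". A likely explanation (an assumption, not confirmed in the question) is that the checkpoint was written by torch.save in PyTorch 1.6 or later, which uses a zip-based format by default, while the environment the notebook kernel actually runs has an older torch that cannot read that format; the stated 1.6.0 may not be the version the kernel imports. The following sketch uses only the standard library to check both conditions. The helper names parse_major_minor and diagnose_checkpoint are hypothetical.

```python
import zipfile


def parse_major_minor(version: str) -> tuple:
    """Turn a version string such as '1.5.1+cu101' into (1, 5)."""
    core = version.split("+")[0]  # drop a local build tag like '+cu101'
    parts = core.split(".")
    return int(parts[0]), int(parts[1])


def diagnose_checkpoint(path: str, torch_version: str) -> str:
    """Describe a likely cause of the 'is a zip archive' error.

    Assumption: torch.save in PyTorch >= 1.6 writes a zip-based file by
    default, and torch releases before 1.6 cannot load that format.
    """
    # is_zipfile returns False (instead of raising) if the file cannot be opened.
    is_zip = zipfile.is_zipfile(path)
    too_old = parse_major_minor(torch_version) < (1, 6)
    if is_zip and too_old:
        return "zip-format checkpoint, but torch is older than 1.6"
    if is_zip:
        return "zip-format checkpoint; this torch should be able to read it"
    return "legacy (non-zip) checkpoint"
```

In the notebook, running `import sys, torch; print(sys.executable, torch.__version__)` shows which Anaconda environment and torch version the kernel really uses; then `diagnose_checkpoint('./ner_model/pytorch_model.bin', torch.__version__)` reports whether that version is too old for the file. If it is, the usual remedies are upgrading torch in that environment, or re-saving the weights from the newer environment with `torch.save(..., _use_new_zipfile_serialization=False)`.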

Tags: python, tensorflow, pytorch, load, bert-language-model

Solution

