首页 > 解决方案 > 如何从 BertForSequenceClassification 获取 hidden_​​states?

问题描述

我阅读了官方教程(https://huggingface.co/transformers/model_doc/bert.html)并尝试设置配置,但它不起作用。

from transformers import PretrainedConfig
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.config.output_hidden_states = True
model.load_state_dict(torch.load('../parameter.pkl'))
model.cuda()
output = model(input)

标签: pythonpytorchbert-language-model

解决方案


输出应该是一个包含隐藏状态的列表。我希望因为您正在加载parameter.pkl默认情况下可能没有输出隐藏状态的内容,所以它会将您的内容覆盖config.output_hidden_states为 False?看看如果在加载 state_dict 后将其设置为 True 会发生什么?


推荐阅读