python - 如何获得微调的 TFBertModel 的隐藏状态?
问题描述
我首先在文本分类任务上微调 Bert 模型,然后我想在 TensorFlow 中获得微调模型的嵌入。不幸的是,我只能说output_hidden_states=True
,在我下载预训练的 Bert 模型的第一行中,而不是在我创建tf.Keras.Model
. 这是我如何制作和训练模型的代码:
max_len = 55
from transformers import BertConfig, BertTokenizer, TFBertModel
def build_custome_model():
bert_encoder = TFBertModel.from_pretrained(Base_BERT_Path)
input_word_ids = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_mask")
input_type_ids = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_type_ids")
embedding = bert_encoder([input_word_ids, input_mask, input_type_ids])[0]
clf_output = embedding[:,0,:]
net = tf.keras.layers.Dropout(0.4)(clf_output)
output = tf.keras.layers.Dense(5, activation='softmax')(net)
model = tf.keras.Model(inputs=[input_word_ids, input_mask, input_type_ids], outputs=output)
model.compile(tf.keras.optimizers.Adam(lr=1e-5), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
return model
然后我在一个包含 2 个句子的数据集上训练模型,并作为它们的相似度得分
#------Training with stratifiedkfold-------
k = 5
kfold = StratifiedKFold(n_splits = k, shuffle = True)
for i, (train_idx, val_idx) in enumerate(kfold.split(first_sentences, labels.score), 1):
epoch_evaluation = {}
train_input = create_input(np.array(first_sentences)[train_idx], np.array(second_sentences)[train_idx], tokenizer, max_len=max_seq_length)
validation_input = create_input(np.array(first_sentences)[val_idx], np.array(second_sentences)[val_idx], tokenizer, max_len=max_seq_length)
history = model.fit(x = train_input, y = labels.loc[train_idx, 'score'],
validation_data= (validation_input, labels.loc[val_idx, 'score']),
epochs = 5,
verbose = 1,
batch_size = 8)
我的目标是在最后有一个在这个数据集上训练的模型,并且只要我给它一个句子就可以输出嵌入(隐藏状态的第一层(输出 [2] [0])),这样我就可以得到句子的所有微调标记嵌入的平均值。
解决方案
您可以使用get_input_embeddings函数检索嵌入:
model = build_custome_model():
model.layers[3].get_input_embeddings()(input_ids)
推荐阅读
- ios - 在 Iphone 上禁用 Progressive Web App 上的后退按钮
- netbeans - 来自数据库的实体类不映射来自不同模式的实体
- recursion - 等差数列数
- python - 将文件上传到烧瓶应用程序会给出 NotADirectoryError: [Errno 20] Not a directory
- python - AttributeError:“BoundingBoxesOnImage”对象没有属性“项目”
- python - 服务器和更多客户端之间的 Python NAT 打孔
- docker - 是否可以从 dockerfile 中删除工作目录
- material-ui - reactjs material ui - 换行而不是水平滚动
- sql - SQL Case 语句条件
- python - 下载时文件类型错误