首页 > 解决方案 > 在编码和解码中只有一个嵌入层的语言模型仅预测

问题描述

我试图让模型使用预训练的 Huggingface 的 BERT 作为特征提取器从句子中预测一个单词。模型看起来像这样

class BertAutoEncoder(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        decoder_layer = nn.TransformerDecoderLayer(768, 2, 1024, dropout=0.1)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, 2)
        self.fc = nn.Linear(768, vocab_size)

    def forward(self, memory, embedded_word):
        output = self.transformer_decoder(embedded_word, memory)
        output = self.fc(output)
        return output

当训练/评估时,我这样称呼模型

bert = BertModel.from_pretrained('bert-base-uncased')
bert.requires_grad_(False)
...
memory = bert(**src).last_hidden_state.transpose(0, 1)
embeded_word = bert.embeddings(trg.data['input_ids'][:, :-1], token_type_ids=trg.data['token_type_ids'][:, :-1]).transpose(0, 1)
output = model(memory, embeded_word)

损失很好地减少了,但结果模型只预测了<eos>令牌。

我尝试用 1 批 32 个样本训练模型,当损失减少通过时它确实有效,8e-6但是当我用所有数据训练它时,损失可能会超出这个范围,但保存的模型都不起作用。即使是 eval 或 train loss 在4e-6-附近的那个8e-6

令人惊讶的是,如果我像这样使用单独的解码器嵌入,该模型会起作用

class BertAutoEncoderOld(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        decoder_layer = nn.TransformerDecoderLayer(768, 2, 1024, dropout=0.1)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, 2)
        self.decoder = nn.Embedding(vocab_size, 768)
        self.pos_decoder = PositionalEncoding(768, 0.5)
        self.fc = nn.Linear(768, vocab_size)

    def forward(self, memory, word):
        tgt = self.decoder(word.data['input_ids'][:, :-1].transpose(0, 1))
        tgt = self.pos_decoder(tgt)
        output = self.transformer_decoder(tgt, memory)
        output = self.fc(output)
        return output

但我被要求让它与一个嵌入一起工作,我不知道如何。

我试过了

但这些都不起作用。

我做错了什么以及如何解决?

谢谢

根据@emily qeustion 进行编辑

我在整理功能中更改数据本身

text.data['attention_mask'][text.data['input_ids'] == 102] = 0
text.data['input_ids'][text.data['input_ids'] == 102] = 0
word.data['attention_mask'][word.data['input_ids'] == 102] = 0
word.data['input_ids'][word.data['input_ids'] == 102] = 0

它只在 Bert 中使用。

标签: nlppytorchhuggingface-transformerstransformer

解决方案


推荐阅读