pytorch - PyTorch nn.Transformer 学习复制目标
问题描述
我正在尝试使用 nn.Transformer 类训练 Transformer Seq2Seq 模型。我相信我实现它是错误的,因为当我训练它时,它似乎太快了,并且在推理过程中它经常重复自己。这似乎是解码器中的掩码问题,当我移除目标掩码时,训练性能是相同的。这让我相信我做错了目标掩蔽。这是我的模型代码:
class TransformerModel(nn.Module):
def __init__(self,
vocab_size, input_dim, heads, feedforward_dim, encoder_layers, decoder_layers,
sos_token, eos_token, pad_token, max_len=200, dropout=0.5,
device=(torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))):
super(TransformerModel, self).__init__()
self.target_mask = None
self.embedding = nn.Embedding(vocab_size, input_dim, padding_idx=pad_token)
self.pos_embedding = nn.Embedding(max_len, input_dim, padding_idx=pad_token)
self.transformer = nn.Transformer(
d_model=input_dim, nhead=heads, num_encoder_layers=encoder_layers,
num_decoder_layers=decoder_layers, dim_feedforward=feedforward_dim,
dropout=dropout)
self.out = nn.Sequential(
nn.Linear(input_dim, feedforward_dim),
nn.ReLU(),
nn.Linear(feedforward_dim, vocab_size))
self.device = device
self.max_len = max_len
self.sos_token = sos_token
self.eos_token = eos_token
# Initialize all weights to be uniformly distributed between -initrange and initrange
def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
# Generate mask covering the top right triangle of a matrix
def generate_square_subsequent_mask(self, size):
mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
def forward(self, src, tgt):
# src: (Max source seq len, batch size, 1)
# tgt: (Max target seq len, batch size, 1)
# Embed source and target with normal and positional embeddings
embedded_src = (self.embedding(src) +
self.pos_embedding(
torch.arange(0, src.shape[1]).to(self.device).unsqueeze(0).repeat(src.shape[0], 1)))
# Generate target mask
target_mask = self.generate_square_subsequent_mask(size=tgt.shape[0]).to(self.device)
embedded_tgt = (self.embedding(tgt) +
self.pos_embedding(
torch.arange(0, tgt.shape[1]).to(self.device).unsqueeze(0).repeat(tgt.shape[0], 1)))
# Feed through model
outputs = self.transformer(src=embedded_src, tgt=embedded_tgt, tgt_mask=target_mask)
outputs = F.log_softmax(self.out(outputs), dim=-1)
return outputs
解决方案
对于那些有同样问题的人,我的问题是我没有正确地将 SOS 令牌添加到我提供模型的目标中,并将 EOS 令牌添加到我在损失函数中使用的目标中。
供参考:输入模型的目标应该是:[SOS] ....
并且用于损失的目标应该是:.... [EOS]
推荐阅读
- javascript - 文档完成(onLoad 事件)等到我网站的所有图像都加载完毕
- javascript - 'test/step_definitions/requester.js' 中的解析错误:(1:1):预期:#EOF、#Language、#TagLine、#FeatureLine、#Comment、#Empty
- mysql - 使用 Sequelize 和 mysql 加密密码列
- android - 服务屏幕尺寸
- listview - Xamarin Forms ListView 不显示内容
- ios - 存储的 UserDefaults 中的空 NSCFConstantString
- javascript - 在 div 加载上使用 jquery 格式化
- spring - Ribbon 如何检索服务的可用实例列表
- symfony - 树枝如何将树枝模板呈现为变量
- sql - 在循环中执行存储过程,直到使用 SSIS 满足条件