pytorch - 当我让 pytorch 关注时,你能给我一个解决损失的想法吗
问题描述
这是我的模型
class Build_Model(nn.Module):
def __init__(self,args) :
super(Build_Model, self).__init__()
self.hidden_size = args.dec_size
self.embedding = nn.Embedding(args.n_vocab, args.d_model)
self.enc_lstm = nn.LSTM(input_size =args.d_model, hidden_size=args.d_model,batch_first=True)
self.dec_lstm = nn.LSTM(input_size =args.d_model, hidden_size=args.d_model,batch_first=True)
self.soft_prob = nn.Softmax(dim=-1)
self.softmax_linear = nn.Linear(args.d_model*2,len(vocab))
self.softmax_linear_function = nn.Softmax(dim = -1)
def forward(self, enc_inputs, dec_inputs) :
enc_hidden = self.embedding(enc_inputs)
dec_hidden = self.embedding(dec_inputs)
enc_hidden , (enc_h_state,enc_c_state) = self.enc_lstm(enc_hidden)
dec_hidden,(dec_h_state,dec_c_state) = self.dec_lstm(dec_hidden,(enc_h_state,enc_c_state))
attn_score = torch.matmul(dec_hidden, torch.transpose(enc_hidden,2,1))
attn_prob = self.soft_prob(attn_score)
attn_out = torch.matmul(attn_prob,enc_hidden)
cat_hidden = torch.cat((attn_out, dec_hidden),-1)
y_pred = self.softmax_linear_function(self.softmax_linear(cat_hidden))
y_pred = torch.argmax(y_pred,dim =-1)
print('y_pred = ',y_pred.shape)
y_pred = y_pred.view(-1, 150)
print('2y_pred = ',y_pred.shape)
return y_pred
这是损失函数
def lm_loss(y_true, y_pred):
print(y_pred.shape)
y_pred_argmax = y_pred
#y_pred_argmax = y_pred_argmax.view(-1,150)
print(y_true.shape, y_pred_argmax.shape)
criterion = nn.CrossEntropyLoss(reduction="none")
loss = criterion(y_true.float(), y_pred_argmax.float()[0])
#mask = tf.not_equal(y_true, 0)
mask = torch.not_equal(y_pred_argmax,0)
#mask = tf.cast(mask, tf.float32)
mask = mask.type(torch.FloatTensor).to(device)
loss *= mask
#loss = tf.reduce_sum(loss) / tf.maximum(tf.reduce_sum(mask), 1)
loss = torch.sum(loss) / torch.maximum(torch.sum(mask),1)
return loss
最后是评价
optimizer.zero_grad()
print(train_enc_inputs.shape,train_dec_inputs.shape, train_dec_labels.shape )
y_pred = model(train_enc_inputs,train_dec_inputs)
#y_pred = torch.argmax(y_pred,dim =-1)
print(y_pred.shape )
loss = lm_loss(train_dec_labels, y_pred)
输出在这里:
torch.Size([32, 120]) torch.Size([32, 150]) torch.Size([32, 150])
y_pred = torch.Size([32, 150])
2y_pred = torch.Size([32, 150])
torch.Size([32, 150])
torch.Size([32, 150])
torch.Size([32, 150]) torch.Size([32, 150])
错误回溯:
ValueError Traceback (most recent call last)
<ipython-input-159-cc8976139dd5> in <module>()
9 #y_pred = torch.argmax(y_pred,dim =-1)
10 print(y_pred.shape )
---> 11 loss = lm_loss(train_dec_labels, y_pred)
12 n_step += 1
13 if n_step % 10 == 0:
3 frames
<ipython-input-158-39ba03042d04> in lm_loss(y_true, y_pred)
15 print(y_true.shape, y_pred_argmax.shape)
16 criterion = nn.CrossEntropyLoss(reduction="none")
---> 17 loss = criterion(y_true.float(), y_pred_argmax.float()[0])
18 #mask = tf.not_equal(y_true, 0)
19 mask = torch.not_equal(y_pred_argmax,0)
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
1119 def forward(self, input: Tensor, target: Tensor) -> Tensor:
1120 return F.cross_entropy(input, target, weight=self.weight,
-> 1121 ignore_index=self.ignore_index, reduction=self.reduction)
1122
1123
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2822 if size_average is not None or reduce is not None:
2823 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
2825
2826
ValueError: Expected input batch_size (32) to match target batch_size (150).
我该如何解决?
解决方案
您的使用存在一些问题nn.CrossEntropyLoss
:
你应该打电话
nn.CrossEntropyLoss
给criterion(y_pred, y_true)
,你似乎已经切换了两者。y_pred
包含您的网络的输出logits ,即它尚未通过softmax:您需要self.softmax_linear_function
在模型中删除)还
y_pred
应该包含所有组件,而不是argmax的结果。y_true
以密集格式传递:它包含真实的类标签,并且比预测少一维y_pred
。
推荐阅读
- javascript - analytics.js URL Calls
- select - Select cases if value is greater than mean of group
- python - PyHive with Kerberos throws Authentication error after few calls
- android - 通过 Utils.bitmapToMat 将 Bitmap 转换为 Mat 会改变它的颜色
- c# - 使用适用于 Windows Server 2016/9 的认证测试工具
- java - 带有子查询的预处理语句给出语法错误
- python - 我们可以统一用python制作游戏吗?
- r - 使用条件 For 循环对数据框中每一列的值求和
- react-native - Expo SDK 需要 Expo 才能运行
- angular - 如何在悬停时打开和关闭 Angular mat 菜单