python - Seq2Seq 模型学会在几次迭代后只输出 EOS 令牌 (<\s>)
问题描述
我正在使用NMT创建一个在康奈尔电影对话语料库上训练的聊天机器人。
我的代码部分来自https://github.com/bshao001/ChatLearner和https://github.com/chiphuyen/stanford-tensorflow-tutorials/tree/master/assignments/chatbot
在训练过程中,我打印了一个从批次中馈送到解码器的随机输出答案,以及我的模型预测的相应答案,以观察学习进度。
我的问题:仅经过大约 4 次训练迭代后,模型就学会了<\s>
为每个时间步输出 EOS 令牌 ( )。即使在训练继续进行时,它也始终将其输出为响应(使用 logits 的 argmax 确定)。偶尔,很少,模型会输出一系列周期作为其答案。
我还在训练期间打印了前 10 个 logit 值(不仅仅是 argmax),以查看其中是否存在正确的单词,但它似乎是在预测词汇中最常见的单词(例如 i、you、?、. )。即使是这些前 10 个单词在训练期间也没有太大变化。
我已确保正确计算编码器和解码器的输入序列长度,并<s>
相应地添加了 SOS ( ) 和 EOS(也用于填充)标记。我还在损失计算中执行掩蔽。
这是一个示例输出:
训练迭代1:
Decoder Input: <s> sure . sure . <\s> <\s> <\s> <\s> <\s> <\s> <\s>
<\s> <\s>
Predicted Answer: wildlife bakery mentality mentality administration
administration winston winston winston magazines magazines magazines
magazines
...
训练迭代 4:
Decoder Input: <s> i guess i had it coming . let us call it settled .
<\s> <\s> <\s> <\s> <\s>
Predicted Answer: <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>
<\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>
经过几次迭代后,它决定只预测 EOS(很少有一些时期)
我不确定是什么导致了这个问题,并且已经坚持了一段时间。任何帮助将不胜感激!
更新:我让它训练了十万次迭代,它仍然只输出 EOS(和偶尔的周期)。几次迭代后,训练损失也没有减少(从一开始就保持在 47 左右)
解决方案
最近我也在研究 seq2seq 模型。我之前遇到过你的问题,就我而言,我是通过更改损失函数来解决的。
你说你使用面具,所以我猜你tf.contrib.seq2seq.sequence_loss
像我一样使用。
我改为tf.nn.softmax_cross_entropy_with_logits
,它可以正常工作(并且计算成本更高)。
(编辑 2018 年 5 月 10 日。对不起,我需要编辑,因为我发现我的代码中有一个严重的错误)
tf.contrib.seq2seq.sequence_loss
可以很好地工作,如果 , , 的形状logits
是targets
正确mask
的。官方文档中定义:
tf.contrib.seq2seq.sequence_loss
loss=tf.contrib.seq2seq.sequence_loss(logits=decoder_logits,
targets=decoder_targets,
weights=masks)
#logits: [batch_size, sequence_length, num_decoder_symbols]
#targets: [batch_size, sequence_length]
#weights: [batch_size, sequence_length]
好吧,即使形状不符合它仍然可以工作。但结果可能很奇怪(很多#EOS #PAD ...等)。
由于decoder_outputs
, 和decoder_targets
可能具有与所需相同的形状(在我的情况下, mydecoder_targets
具有形状[sequence_length, batch_size]
)。所以试着用它tf.transpose
来帮助你重塑张量。
推荐阅读
- java - 如何在Java中的现有线程中执行Callable
- c# - 为什么字典总是新的初始化?
- matching - 如何使用 SPSS 模糊命令修复案例控制匹配
- sql - 如何制定一个允许用户从今天起只输入 DateTime 的 CHECK 约束?
- c++ - 在这个例子中,指针的指针做什么?
- django - 嵌套序列化程序中的媒体 URL 不完整
- javascript - 在 React Native Firebase 中重新加载应用程序后无法更新状态?
- jenkins - 如何修复詹金斯管道中的 JNLPLauncher 异常
- python - 始终在 Google Cloud 上运行 Python 脚本
- php - 具有适当权限的现有文件路径的 Nginx“没有此类文件或目录错误”