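lstm - TensorFlow seq2seq model: good loss value, but wrong predictions
Problem description
I am trying to train a Seq2Seq model. It should translate sentences from a source_vocabulary into sentences in a target_vocabulary.
The loss value is 0.28, but the network does not predict words from the target vocabulary. Instead, the network's predictions are negative values. I am not sure whether something in the code is wrong, whether the vocabulary is too large, or whether I have not trained for long enough. I am training on part of a dataset of about 270 000 sentences. Even though the loss value decreases, I do not know whether the network is actually learning anything.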
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib)
# Assumed import: the original post does not show where Dense comes from.
from tensorflow.python.layers.core import Dense


def encDecEmb():
    batch_size = 32
    seq_length = 40
    vocab_size = 289415
    epochs = 10
    embedding_size = 300
    hidden_units = 20
    learning_rate = 0.001
    # shape = (batch_size, seq_length)
    encoder_inputs = tf.placeholder(
        tf.int32, shape=(None, None), name='encoder_inputs')
    decoder_inputs = tf.placeholder(
        tf.int32, shape=(None, None), name='decoder_inputs')
    sequence_length = tf.placeholder(tf.int32, [None], name='sequence_length')
    # Embedding
    embedding = tf.get_variable("embedding", [vocab_size, embedding_size])
    encoder_embedding = tf.nn.embedding_lookup(embedding, encoder_inputs)
    decoder_embedding = tf.nn.embedding_lookup(embedding, decoder_inputs)
    # Encoder
    encoder_cell = tf.contrib.rnn.LSTMCell(hidden_units)
    encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
        encoder_cell, encoder_embedding, dtype=tf.float32)
    projection_layer = Dense(vocab_size, use_bias=False)
    helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedding,
                                               sequence_length=sequence_length)
    # Decoder
    decoder_cell = tf.contrib.rnn.LSTMCell(hidden_units)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=decoder_cell, initial_state=encoder_final_state, helper=helper,
        output_layer=projection_layer)
    decoder_outputs, _final_state, _final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
        decoder)
    logits = decoder_outputs.rnn_output
    training_logits = tf.identity(decoder_outputs.rnn_output, name='logits')
    target_labels = tf.placeholder(tf.int32, shape=(batch_size, seq_length))
    weight_mask = tf.sequence_mask([i for i in range(
        batch_size)], seq_length, dtype=tf.float32, name="weight_mask")
    # Loss
    loss = tf.contrib.seq2seq.sequence_loss(
        logits=training_logits, targets=decoder_inputs, weights=weight_mask)
    # AdamOptimizer, gradient clipping
    optimizer = tf.train.AdamOptimizer(learning_rate)
    gradients = optimizer.compute_gradients(loss)
    capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var)
                        for grad, var in gradients if grad is not None]
    train_opt = optimizer.apply_gradients(capped_gradients)
    # Read files (readCSV_to_int and get_batches are the poster's own helpers, not shown)
    x = readCSV_to_int("./xTest.csv")
    y = readCSV_to_int("./yTest.csv")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    for epoch in range(epochs):
        for batch, (inputs, targets) in enumerate(get_batches(x, y, batch_size)):
            _, loss_value = sess.run([train_opt, loss],
                                     feed_dict={encoder_inputs: inputs, decoder_inputs: targets, target_labels: targets,
                                                sequence_length: [len(inputs[0])] * batch_size})
            print('Epoch{:>3} Batch {:>4}/{} Loss {:>6.4f}'.format(epoch, batch, (len(x) // batch_size),
                                                                  loss_value))
    saver.save(sess, './model_on_testset')
    print("Model Trained and Saved")
Solution
It is perfectly fine for the logits to be negative. They are the raw outputs of the network before softmax normalization: each value is exponentiated (and thus becomes a small positive number) and is then divided by the sum of all exponentiated values, so the results form a probability distribution.
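To make the point about negative logits concrete, here is a minimal NumPy sketch. The logit values are made up for illustration and do not come from the model in the question.

```python
import numpy as np

def softmax(logits):
    """Turn raw logits into probabilities along the last axis."""
    # Subtracting the max leaves the result unchanged but avoids overflow.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

# Hypothetical logits for one decoding step over a 4-word vocabulary.
logits = np.array([-3.2, -0.7, -5.1, -1.9])
probs = softmax(logits)

# The predicted word id is the position of the largest logit;
# softmax preserves the ordering, so argmax over probs gives the same id.
predicted_id = int(np.argmax(logits))

print(probs)         # all entries positive, summing to 1
print(predicted_id)
```

The prediction is therefore the index of the largest logit, not the logit value itself, so seeing negative numbers in the raw output does not by itself indicate a problem.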
The architecture you chose, a vanilla sequence-to-sequence model, is not the easiest one to train. Given that you have only a limited amount of training data (270k sentences), I would use:
- A bidirectional encoder; and
- A decoder with the attention mechanism.
Attention simplifies the gradient flow from the decoder (where the actual loss computation happens) to the encoder: the gradients no longer flow only through the encoder's final state, but also through the attention to all encoder states.
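The idea can be sketched with plain dot-product attention in NumPy. This is an illustration only, not the code of the model in the question: the shapes and random values are assumptions, and the learned projections used in practice are omitted.

```python
import numpy as np

def dot_product_attention(decoder_state, encoder_states):
    """Return (context, weights) for a single decoder step.

    decoder_state:  shape (hidden,)
    encoder_states: shape (src_len, hidden)
    """
    scores = encoder_states @ decoder_state           # one score per source position
    scores = scores - scores.max()                    # numerical stability
    weights = np.exp(scores) / np.exp(scores).sum()   # softmax over source positions
    context = weights @ encoder_states                # weighted sum of all encoder states
    return context, weights

rng = np.random.default_rng(0)
encoder_states = rng.normal(size=(5, 8))   # hypothetical: 5 source positions, hidden size 8
decoder_state = rng.normal(size=8)

context, weights = dot_product_attention(decoder_state, encoder_states)
print(weights.shape, context.shape)
```

Because the context vector is a weighted sum over every encoder state, the loss has a direct path to each source position. In TensorFlow 1.x the corresponding building blocks would be `tf.nn.bidirectional_dynamic_rnn` for the encoder and `tf.contrib.seq2seq.LuongAttention` or `BahdanauAttention` wrapped in `tf.contrib.seq2seq.AttentionWrapper` for the decoder; how to wire them into the code above is left open here.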
Other things also influence performance, e.g. how you segment your data (into words or subwords) and how large a vocabulary you use.
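One common way to keep a word-level vocabulary small is to keep only the most frequent words and map everything else to an unknown token. The sketch below assumes whitespace tokenization; the helper names, the `<unk>` token, and the toy corpus are hypothetical and not part of the original post.

```python
from collections import Counter

def build_vocab(sentences, max_size, unk_token="<unk>"):
    """Map the most frequent words to ids; id 0 is reserved for unk_token."""
    counts = Counter(word for sentence in sentences for word in sentence.split())
    most_common = [word for word, _ in counts.most_common(max_size - 1)]
    return {word: idx for idx, word in enumerate([unk_token] + most_common)}

def encode(sentence, vocab, unk_token="<unk>"):
    """Replace each word by its id, falling back to the unk id for rare words."""
    unk_id = vocab[unk_token]
    return [vocab.get(word, unk_id) for word in sentence.split()]

corpus = ["the cat sat", "the dog sat", "the bird flew"]
vocab = build_vocab(corpus, max_size=3)

print(vocab)                          # {'<unk>': 0, 'the': 1, 'sat': 2}
print(encode("the cat sat", vocab))   # [1, 0, 2]
```

Subword segmentation goes one step further by splitting rare words into smaller units instead of discarding them, which usually needs a dedicated tool and is not shown here.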
In any case, the best way to find out whether your model is learning something is to decode target sentences from the model and inspect them.
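A minimal sketch of such a check is given below: take the argmax over the vocabulary axis at every time step and map the ids back to words. The tiny vocabulary, the logit values, and the `ids_to_sentence` helper are hypothetical; in the question's setup the logits would be fetched from the graph instead.

```python
import numpy as np

def ids_to_sentence(token_ids, id_to_word, end_id):
    """Convert predicted ids to words, stopping at the end-of-sentence id."""
    words = []
    for token_id in token_ids:
        if token_id == end_id:
            break
        words.append(id_to_word.get(int(token_id), "<unk>"))
    return " ".join(words)

# Hypothetical vocabulary and logits of shape (time_steps, vocab_size).
id_to_word = {0: "</s>", 1: "hello", 2: "world"}
logits = np.array([
    [-4.0, -0.5, -2.0],   # step 0: largest logit at id 1
    [-3.0, -2.5, -0.1],   # step 1: largest logit at id 2
    [-0.2, -1.5, -3.0],   # step 2: largest logit at id 0 (end of sentence)
])

predicted_ids = np.argmax(logits, axis=-1)
print(ids_to_sentence(predicted_ids, id_to_word, end_id=0))  # hello world
```

Note that `TrainingHelper` feeds the ground-truth token at every step, so outputs obtained this way can look better than real translations. For a more realistic check, decoding would feed the model's own predictions back in, for example with `tf.contrib.seq2seq.GreedyEmbeddingHelper`.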