首页 > 解决方案 > 如何在此代码段中训练嵌入矩阵?

问题描述

我正在关注使用双向 LSTM 实现 NER 标记器的 coursera 作业的代码。

但我无法理解嵌入矩阵是如何更新的。在下面的代码中,build_layers有一个变量embedding_matrix_variable作为 LSTM 的输入。但是,它不会在任何地方更新。

你能帮我理解嵌入是如何被训练的吗?

def build_layers(self, vocabulary_size, embedding_dim, n_hidden_rnn, n_tags):
    initial_embedding_matrix = np.random.randn(vocabulary_size, embedding_dim) / np.sqrt(embedding_dim)
    embedding_matrix_variable = tf.Variable(initial_embedding_matrix, name='embedding_matrix', dtype=tf.float32)

    forward_cell =  tf.nn.rnn_cell.DropoutWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden_rnn, forget_bias=3.0),
        input_keep_prob=self.dropout_ph,
        output_keep_prob=self.dropout_ph,
        state_keep_prob=self.dropout_ph
    )

    backward_cell =  tf.nn.rnn_cell.DropoutWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden_rnn, forget_bias=3.0),
        input_keep_prob=self.dropout_ph,
        output_keep_prob=self.dropout_ph,
        state_keep_prob=self.dropout_ph
    )

    embeddings = tf.nn.embedding_lookup(embedding_matrix_variable, self.input_batch)

    (rnn_output_fw, rnn_output_bw), _ =  tf.nn.bidirectional_dynamic_rnn(
        cell_fw=forward_cell, cell_bw=backward_cell,
        dtype=tf.float32,
        inputs=embeddings,
        sequence_length=self.lengths
    )

    rnn_output = tf.concat([rnn_output_fw, rnn_output_bw], axis=2)
    self.logits = tf.layers.dense(rnn_output, n_tags, activation=None)


def compute_loss(self, n_tags, PAD_index):
    """Computes masked cross-entopy loss with logits."""
    ground_truth_tags_one_hot = tf.one_hot(self.ground_truth_tags, n_tags)
    loss_tensor = tf.nn.softmax_cross_entropy_with_logits(labels=ground_truth_tags_one_hot, logits=self.logits)

    mask = tf.cast(tf.not_equal(self.input_batch, PAD_index), tf.float32)
    self.loss = tf.reduce_mean(tf.reduce_sum(tf.multiply(loss_tensor, mask), axis=-1) / tf.reduce_sum(mask, axis=-1))

标签: pythontensorflowlstmword-embedding

解决方案


在 TensorFlow 中,变量通常不会直接更新(即通过手动将它们设置为某个值),而是使用优化算法和自动微分来训练它们。

当您定义tf.Variable时,您正在向计算图中添加一个节点(维护状态)。在训练时,如果损失节点取决于您定义的变量的状态,TensorFlow 将通过计算图自动遵循链式法则来计算损失函数相对于该变量的梯度。然后,优化算法将利用计算出的梯度来更新参与损失计算的可训练变量的值。

具体来说,您提供的代码构建了一个 TensorFlow 图,其中损失self.loss取决于权重embedding_matrix_variable(即图中这些节点之间存在路径),因此 TensorFlow 将计算关于该变量的梯度,优化器将在最小化损失时更新其值。使用TensorBoard检查 TensorFlow 图可能很有用。


推荐阅读