python - 如何在此代码段中训练嵌入矩阵?
问题描述
我正在关注使用双向 LSTM 实现 NER 标记器的 coursera 作业的代码。
但我无法理解嵌入矩阵是如何更新的。在下面的代码中,build_layers
有一个变量embedding_matrix_variable
作为 LSTM 的输入。但是,它不会在任何地方更新。
你能帮我理解嵌入是如何被训练的吗?
def build_layers(self, vocabulary_size, embedding_dim, n_hidden_rnn, n_tags):
initial_embedding_matrix = np.random.randn(vocabulary_size, embedding_dim) / np.sqrt(embedding_dim)
embedding_matrix_variable = tf.Variable(initial_embedding_matrix, name='embedding_matrix', dtype=tf.float32)
forward_cell = tf.nn.rnn_cell.DropoutWrapper(
tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden_rnn, forget_bias=3.0),
input_keep_prob=self.dropout_ph,
output_keep_prob=self.dropout_ph,
state_keep_prob=self.dropout_ph
)
backward_cell = tf.nn.rnn_cell.DropoutWrapper(
tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden_rnn, forget_bias=3.0),
input_keep_prob=self.dropout_ph,
output_keep_prob=self.dropout_ph,
state_keep_prob=self.dropout_ph
)
embeddings = tf.nn.embedding_lookup(embedding_matrix_variable, self.input_batch)
(rnn_output_fw, rnn_output_bw), _ = tf.nn.bidirectional_dynamic_rnn(
cell_fw=forward_cell, cell_bw=backward_cell,
dtype=tf.float32,
inputs=embeddings,
sequence_length=self.lengths
)
rnn_output = tf.concat([rnn_output_fw, rnn_output_bw], axis=2)
self.logits = tf.layers.dense(rnn_output, n_tags, activation=None)
def compute_loss(self, n_tags, PAD_index):
"""Computes masked cross-entopy loss with logits."""
ground_truth_tags_one_hot = tf.one_hot(self.ground_truth_tags, n_tags)
loss_tensor = tf.nn.softmax_cross_entropy_with_logits(labels=ground_truth_tags_one_hot, logits=self.logits)
mask = tf.cast(tf.not_equal(self.input_batch, PAD_index), tf.float32)
self.loss = tf.reduce_mean(tf.reduce_sum(tf.multiply(loss_tensor, mask), axis=-1) / tf.reduce_sum(mask, axis=-1))
解决方案
在 TensorFlow 中,变量通常不会直接更新(即通过手动将它们设置为某个值),而是使用优化算法和自动微分来训练它们。
当您定义tf.Variable时,您正在向计算图中添加一个节点(维护状态)。在训练时,如果损失节点取决于您定义的变量的状态,TensorFlow 将通过计算图自动遵循链式法则来计算损失函数相对于该变量的梯度。然后,优化算法将利用计算出的梯度来更新参与损失计算的可训练变量的值。
具体来说,您提供的代码构建了一个 TensorFlow 图,其中损失self.loss
取决于权重embedding_matrix_variable
(即图中这些节点之间存在路径),因此 TensorFlow 将计算关于该变量的梯度,优化器将在最小化损失时更新其值。使用TensorBoard检查 TensorFlow 图可能很有用。
推荐阅读
- java - 如何将另一个方法中的变量调用到您的主方法
- javascript - 将 react-native-chart-kit 图表导出为 PDF
- amazon-web-services - AWS CloudWatch - 日志组不存在
- python - 如何在调度程序中向用户输入添加“if”语句 - 初学者问题
- sql - Impala:LIKE 不捕获 CONCAT 输出
- java - log4j2 没有从 MDC 获取 Syslog Appender 的新值
- javascript - 通过“Blob”下载文件时以角度更改文件名
- c# - 从 ASP .Net Core HTML 页面内的本地资源运行 exe 应用程序
- android - 隐式意图 - java.lang.NullPointerException:尝试在空对象引用上调用虚拟方法
- flutter - 自定义反序列化飞镖对象