python - 如何在 tensorflow 2.0 中保存和恢复自定义的 seq2seq 模型?
问题描述
我正在按照 TensorFlow 的一个教程,创建 seq2seq 模型将英语翻译成西班牙语。该教程发布在这里。这个教程很有帮助,也容易理解。但是,我不清楚如何恢复模型。我的理解是,在一个新的 Python 会话中运行以下代码,就可以恢复模型(编码器和解码器)。
class Encoder(tf.keras.Model):
    """Seq2seq encoder: token embedding followed by a single LSTM layer.

    NOTE: the attribute names ``embedding`` and ``lstm_layer`` become part of
    the checkpoint keys (e.g. ``encoder/embedding/embeddings``), so renaming
    them breaks restoring previously saved checkpoints.
    """

    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super().__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.lstm_layer = tf.keras.layers.LSTM(
            enc_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer="glorot_uniform",
        )

    def call(self, x, hidden):
        """Embed ``x`` and run the LSTM starting from ``hidden``.

        ``hidden`` is passed straight to the LSTM as ``initial_state``
        (``initialize_hidden_state`` builds a matching ``[h, c]`` pair).
        Returns ``(output, h, c)``: the full output sequence plus the final
        hidden and cell states.
        """
        embedded = self.embedding(x)
        sequence_output, state_h, state_c = self.lstm_layer(
            embedded, initial_state=hidden
        )
        return sequence_output, state_h, state_c

    def initialize_hidden_state(self):
        """Return ``[h, c]`` zero tensors of shape ``(batch_sz, enc_units)``."""
        zeros_shape = (self.batch_sz, self.enc_units)
        return [tf.zeros(zeros_shape) for _ in range(2)]
class Decoder(tf.keras.Model):
    """Seq2seq decoder: LSTM cell wrapped with Luong or Bahdanau attention (TFA).

    NOTE(review): this class reads the module-level globals
    ``max_length_input`` (in ``__init__``) and ``max_length_output`` (in
    ``call``), which are not defined in this snippet; they must exist before
    the decoder is constructed or called.
    Attribute names are kept stable because they form checkpoint keys.
    """

    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz, attention_type='luong'):
        super().__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.attention_type = attention_type
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.decoder_rnn_cell = tf.keras.layers.LSTMCell(dec_units)
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()

        # No memory is attached yet (memory=None); memory_sequence_length is a
        # list holding batch_sz copies of the global max_length_input.
        source_lengths = batch_sz * [max_length_input]
        self.attention_mechanism = self.build_attention_mechanism(
            dec_units, None, source_lengths, attention_type
        )

        # Wrap the plain LSTM cell with attention, then drive it with BasicDecoder,
        # projecting each step through self.fc.
        self.rnn_cell = self.build_rnn_cell(batch_sz)
        self.decoder = tfa.seq2seq.BasicDecoder(
            self.rnn_cell, sampler=self.sampler, output_layer=self.fc
        )

    def build_rnn_cell(self, batch_sz):
        """Return the decoder LSTM cell wrapped in an AttentionWrapper.

        ``batch_sz`` is accepted for API compatibility but is not used.
        """
        return tfa.seq2seq.AttentionWrapper(
            self.decoder_rnn_cell,
            self.attention_mechanism,
            attention_layer_size=self.dec_units,
        )

    def build_attention_mechanism(self, dec_units, memory, memory_sequence_length, attention_type='luong'):
        """Create a Bahdanau mechanism for ``'bahdanau'``, otherwise a Luong one."""
        mechanism_cls = (
            tfa.seq2seq.BahdanauAttention
            if attention_type == 'bahdanau'
            else tfa.seq2seq.LuongAttention
        )
        return mechanism_cls(
            units=dec_units,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
        )

    def build_initial_state(self, batch_sz, encoder_state, Dtype):
        """Return the wrapper's initial state with ``cell_state`` set to ``encoder_state``."""
        zero_state = self.rnn_cell.get_initial_state(batch_size=batch_sz, dtype=Dtype)
        return zero_state.clone(cell_state=encoder_state)

    def call(self, inputs, initial_state):
        """Embed ``inputs`` and decode from ``initial_state``; return the decoder outputs.

        The decode length is fixed to ``max_length_output - 1`` (global) for
        each of ``self.batch_sz`` sequences.
        """
        embedded = self.embedding(inputs)
        target_lengths = self.batch_sz * [max_length_output - 1]
        outputs, _, _ = self.decoder(
            embedded, initial_state=initial_state, sequence_length=target_lengths
        )
        return outputs
## Rebuild the models exactly as during training, then restore their weights.
# vocab_tar_size, vocab_inp_size, embedding_dim, units and BATCH_SIZE are
# expected to be defined earlier in the script (same values as in training).
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE, 'luong')
encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam()
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

# Restore the latest checkpoint in checkpoint_dir.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    # Without this check, restore(None) is a silent no-op and the models
    # keep their random initialization.
    raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir!r}")
status = checkpoint.restore(latest_checkpoint)
# Subclassed Keras models create their variables lazily on the first call, so
# restore() queues the saved values and assigns them once the layers are
# built. At this point only the objects that already exist can be verified;
# call status.assert_consumed() only after running one forward pass through
# encoder and decoder.
status.assert_existing_objects_matched()
当我尝试在新会话中恢复模型时,权重和嵌入(embeddings)并没有被恢复。我还收到以下错误:
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
# NOTE(review): per the error below, encoder.embedding.embeddings from the
# checkpoint is still unresolved at this point. This is likely because the
# subclassed models have not been called yet, so their variables do not exist
# and the saved values are only queued. assert_consumed() presumably passes
# only after a forward pass through encoder/decoder. Confirm.
status.assert_consumed()
AssertionError: Unresolved object in checkpoint (root).encoder.embedding.embeddings: attributes {
name: "VARIABLE_VALUE"
full_name: "encoder_19/embedding_28/embeddings"
checkpoint_key: "encoder/embedding/embeddings/.ATTRIBUTES/VARIABLE_VALUE"
}
我在这里查看了 TensorFlow 的 GitHub 页面,但其中没有涉及通过子类化(subclassing)构建的自定义模型。其他许多链接以及 Stack Overflow 上的问答也没有给出明确的解决方案。
可复现的代码请参阅链接。本质上,我想了解如何在 TF 2.0 中保存和恢复自定义(子类化)模型。
解决方案
推荐阅读
- html - 如何让页脚远离我的内容?
- html - HTML 标题表内的下拉列表
- ansible - Ansible 模块仅在文件不存在时才创建文件并写入一些数据
- python - 从 scipy dendrogram 中检索叶子颜色
- javascript - 如何在 Qt Creator 的 QWebEngineView 中捕获 javascript 错误?
- php - 如何通过 PHP 中的 AWS SNS 在 SMS 消息中发送 utf-8 字符?
- github - 如何从私有仓库 GitHub 获取信息
- azure - 尝试通过 terraform 添加 LinuxDiagnostic Azure VM Extension 并出现错误
- python - 我如何制作这样的“订购”产品应用程序
- php - 访问 Webroot 之外的文件/目录 - PHP