首页 > 解决方案 > 如何在 tensorflow 2.0 中保存和恢复自定义的 seq2seq 模型?

问题描述

我正在学习关于创建 seq2seq 将英语翻译为西班牙语的 tensorflow 教程。该教程发布在这里。本教程很有帮助且易于理解。但是,我很困惑如何恢复模型。我的理解是,可以通过在新的 python 会话中运行以下代码来恢复模型(编码器和解码器)。

class Encoder(tf.keras.Model):
  """Seq2seq encoder: token embedding followed by a single LSTM layer.

  Args:
    vocab_size: size of the source vocabulary (embedding input dim).
    embedding_dim: dimensionality of the token embeddings.
    enc_units: number of LSTM units (hidden/cell state size).
    batch_sz: batch size, used only to shape the zero initial state.
  """

  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
    super(Encoder, self).__init__()
    self.batch_sz = batch_sz
    self.enc_units = enc_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    # return_sequences gives per-timestep outputs (needed as attention memory);
    # return_state additionally yields the final (h, c) pair.
    self.lstm_layer = tf.keras.layers.LSTM(
        self.enc_units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform')

  def call(self, x, hidden):
    """Embed token ids `x` and run the LSTM from initial state `hidden`.

    Returns:
      (sequence_output, final_h, final_c)
    """
    embedded = self.embedding(x)
    output, state_h, state_c = self.lstm_layer(embedded, initial_state=hidden)
    return output, state_h, state_c

  def initialize_hidden_state(self):
    """Return a zeroed [h, c] initial state of shape (batch_sz, enc_units)."""
    state_shape = (self.batch_sz, self.enc_units)
    return [tf.zeros(state_shape), tf.zeros(state_shape)]


class Decoder(tf.keras.Model):
  """Seq2seq decoder: embedding + attention-wrapped LSTM cell driven by a
  tfa.seq2seq.BasicDecoder with a TrainingSampler (teacher forcing).

  NOTE(review): relies on module-level globals `max_length_input` and
  `max_length_output` (defined elsewhere in the notebook/script) — confirm
  they are in scope before constructing/calling this class.
  """

  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz, attention_type='luong'):
    super(Decoder, self).__init__()
    self.batch_sz = batch_sz
    self.dec_units = dec_units
    # 'luong' (default) or 'bahdanau'; anything else falls back to Luong.
    self.attention_type = attention_type

    # Target-side token embedding.
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

    # Projects RNN outputs to vocabulary logits.
    self.fc = tf.keras.layers.Dense(vocab_size)

    # Fundamental RNN cell; wrapped with attention below.
    self.decoder_rnn_cell = tf.keras.layers.LSTMCell(self.dec_units)

    # TrainingSampler reads the ground-truth inputs at each step (teacher forcing).
    self.sampler = tfa.seq2seq.sampler.TrainingSampler()

    # Create attention mechanism with memory = None
    # (memory must be set later via attention_mechanism.setup_memory with
    # the encoder outputs — this ordering matters: build_rnn_cell below
    # captures this object).
    self.attention_mechanism = self.build_attention_mechanism(self.dec_units, 
                                                              None, self.batch_sz*[max_length_input], self.attention_type)

    # Wrap attention mechanism with the fundamental rnn cell of decoder
    self.rnn_cell = self.build_rnn_cell(batch_sz)

    # Define the decoder with respect to fundamental rnn cell
    self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler=self.sampler, output_layer=self.fc)

  def build_rnn_cell(self, batch_sz):
    """Wrap the base LSTMCell with the already-built attention mechanism."""
    rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnn_cell, 
                                  self.attention_mechanism, attention_layer_size=self.dec_units)
    return rnn_cell

  def build_attention_mechanism(self, dec_units, memory, memory_sequence_length, attention_type='luong'):
    """Return a Bahdanau or Luong attention mechanism over `memory`.

    memory may be None at construction time; it is attached later.
    """
    if(attention_type=='bahdanau'):
      return tfa.seq2seq.BahdanauAttention(units=dec_units, memory=memory, memory_sequence_length=memory_sequence_length)
    else:
      return tfa.seq2seq.LuongAttention(units=dec_units, memory=memory, memory_sequence_length=memory_sequence_length)

  def build_initial_state(self, batch_sz, encoder_state, Dtype):
    """Build the AttentionWrapper state, seeding its cell state with the
    encoder's final [h, c] state."""
    decoder_initial_state = self.rnn_cell.get_initial_state(batch_size=batch_sz, dtype=Dtype)
    decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
    return decoder_initial_state

  def call(self, inputs, initial_state):
    """Run teacher-forced decoding over `inputs`; returns BasicDecoder outputs
    (logits are in outputs.rnn_output)."""
    x = self.embedding(inputs)
    outputs, _, _ = self.decoder(x, initial_state=initial_state, sequence_length=self.batch_sz*[max_length_output-1])
    return outputs

## The variable can be saved from previous session
# Rebuild the (untrained) model objects with the same hyperparameters as the
# training session; depends on vocab_tar_size / vocab_inp_size / embedding_dim /
# units / BATCH_SIZE being defined earlier.
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE, 'luong')
encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)

optimizer = tf.keras.optimizers.Adam()

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track optimizer + both sub-models under the same names used when saving,
# so checkpoint keys line up on restore.
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)


# restoring the latest checkpoint in checkpoint_dir
# NOTE(review): restoration is deferred for subclassed models — variables only
# exist after the model is called once (layers built), so assert_consumed()
# raises here until encoder/decoder have been invoked on real inputs; verify
# against tf.train.Checkpoint docs.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

当我尝试恢复模型时,权重和嵌入不会在新会话中恢复。我也收到以下错误:

status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

status.assert_consumed()

AssertionError: Unresolved object in checkpoint (root).encoder.embedding.embeddings: attributes {
  name: "VARIABLE_VALUE"
  full_name: "encoder_19/embedding_28/embeddings"
  checkpoint_key: "encoder/embedding/embeddings/.ATTRIBUTES/VARIABLE_VALUE"
}

我在这里查看了 tensorflow 的 github 网站,但它不包括使用子类化构建的自定义模型。许多其他链接和 stackoverflow 也没有提供明确的解决方案。

请参阅链接以获取可重现的代码。我本质上是想了解如何在 tf 2.0 中保存和恢复自定义模型。

标签: python, tensorflow, keras

解决方案


推荐阅读