首页 > 解决方案 > 在 TensorFlow 2.0 中使用 tf.keras.utils.plot_model()

问题描述

我需要可视化深度学习模型的输入/输出维度以进行调试。我在 Keras 功能 API 方面有经验,之前使用keras.utils.plot_model()过,这很有帮助。

现在我正在尝试迁移到 Tensorflow 2.0 - 主要是因为更模块化的模型定义等(你好 pytorch!)。但不确定如何tf.keras.utils.plot_model()在此架构中使用。下面的代码 -

class Encoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
    super(Encoder, self).__init__()
    ...

  def call(self, x, hidden):
    ...

class Decoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
    super(Decoder, self).__init__()
    ...

  def call(self, x, hidden, enc_output):
    ...

现在在训练模型时,检查点被保存

for epoch in range(EPOCHS):
    start = time.time()
    enc_hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(inp, targ, enc_hidden)
        total_loss += batch_loss

    if batch % 100 == 0:
        print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                   batch,
                                                   batch_loss.numpy()))
    # saving (checkpoint) the model 
    checkpoint.save(file_prefix = checkpoint_prefix)

    print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                      total_loss / steps_per_epoch))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

我知道这个检查点有模型信息。但我不确定如何从这个检查点获得类似的可视化tf.keras.utils.plot_model()

请建议。

编辑

这就是我定义检查点的方式

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

然后将检查点保存为原始训练代码中所示

checkpoint.save(file_prefix = checkpoint_prefix)

标签: pythontensorflowdeep-learning

解决方案


推荐阅读