首页 > 解决方案 > 序列化`tf.Module`s的推荐方法是什么?

问题描述

我有一个tf.Module包含 (non-picklable)tf.keras.Model作为子模块的类。我想知道tf.Module在这种情况下推荐的序列化方法是什么?

我考虑了两种方法:

  1. 使用类似于tf.keras.Model.save. 我希望也许tf.Modules 能够以相同的方式保存嵌套模块tf.Model.save。但是tf.Module,没有实现这样的事情。
  2. 酸洗,这将是一种序列化的简单方法tf.Module,但我不能这样做,因为它tf.keras.Model是不可腌制的。

这是当前失败的示例代码:

import pickle

import tensorflow as tf


class TestModule(tf.Module):
    def __init__(self, model):
        self.model = model


def main():
    x = tf.keras.layers.Input((3, ))
    y = tf.keras.layers.Dense(5)(x)
    # Note, model *is not* picklable.
    model = tf.keras.Model(x, y)

    _ = model(tf.random.uniform((1, 3)))

    module_1 = TestModule(model)
    module_2 = pickle.loads(pickle.dumps(module_1))

    for variable_1, variable_2 in zip(module_1.model.trainable_variables,
                                      module_2.model.trainable_variables):
        tf.debugging.assert_equal(variable_1, variable_2)


if __name__ == '__main__':
    main()

我应该为每个编写自定义泡菜功能(例如__{get,set}state__tf.Module还是应该创建一个类似的.save方法keras.Model

标签: pythontensorflowkerastensorflow2.0

解决方案


您可以使用Saved Model Format来保存自定义tf.Module子类。

以下适用于 TensorFlow 2.1:

import tensorflow as tf

class TestModule(tf.Module):
    def __init__(self, model):
        self.model = model


x = tf.keras.layers.Input((3, ))
y = tf.keras.layers.Dense(5)(x)
model = tf.keras.Model(x, y)
module_1 = TestModule(model)

tf.saved_model.save(module_1, "./foo")

要加载回:
imported = tf.saved_model.load("foo")

断言
module_1 == imported(或类似的)将AssertionError在加载后引发,因为我们正在处理不同的 Tensorflow 对象。然而,我们可以迭代模型的权重并逐元素比较它们:

original_weights = module_1.model.weights
imported_weights = imported.model.variables.weights

for weight_idx, _ in enumerate(original_weights):
  assert (
      original_weights[weight_idx].numpy() == imported_weights[weight_idx].numpy()
      ).all()

推荐阅读