python - 序列化`tf.Module`s的推荐方法是什么?
问题描述
我有一个tf.Module
包含 (non-picklable)tf.keras.Model
作为子模块的类。我想知道tf.Module
在这种情况下推荐的序列化方法是什么?
我考虑了两种方法:
- 使用类似于
tf.keras.Model.save
. 我希望也许tf.Module
s 能够以相同的方式保存嵌套模块tf.Model.save
。但是tf.Module
,没有实现这样的事情。 - 酸洗,这将是一种序列化的简单方法
tf.Module
,但我不能这样做,因为它tf.keras.Model
是不可腌制的。
这是当前失败的示例代码:
import pickle
import tensorflow as tf
class TestModule(tf.Module):
def __init__(self, model):
self.model = model
def main():
x = tf.keras.layers.Input((3, ))
y = tf.keras.layers.Dense(5)(x)
# Note, model *is not* picklable.
model = tf.keras.Model(x, y)
_ = model(tf.random.uniform((1, 3)))
module_1 = TestModule(model)
module_2 = pickle.loads(pickle.dumps(module_1))
for variable_1, variable_2 in zip(module_1.model.trainable_variables,
module_2.model.trainable_variables):
tf.debugging.assert_equal(variable_1, variable_2)
if __name__ == '__main__':
main()
我应该为每个编写自定义泡菜功能(例如__{get,set}state__
)tf.Module
还是应该创建一个类似的.save
方法keras.Model
?
解决方案
您可以使用Saved Model Format来保存自定义tf.Module
子类。
以下适用于 TensorFlow 2.1:
import tensorflow as tf
class TestModule(tf.Module):
def __init__(self, model):
self.model = model
x = tf.keras.layers.Input((3, ))
y = tf.keras.layers.Dense(5)(x)
model = tf.keras.Model(x, y)
module_1 = TestModule(model)
tf.saved_model.save(module_1, "./foo")
要加载回:
imported = tf.saved_model.load("foo")
断言
module_1 == imported
(或类似的)将AssertionError
在加载后引发,因为我们正在处理不同的 Tensorflow 对象。然而,我们可以迭代模型的权重并逐元素比较它们:
original_weights = module_1.model.weights
imported_weights = imported.model.variables.weights
for weight_idx, _ in enumerate(original_weights):
assert (
original_weights[weight_idx].numpy() == imported_weights[weight_idx].numpy()
).all()
推荐阅读
- xml - infopath xml 上的尾随空值
- string - 用字符串更改的不同值替换字符串中的值
- javascript - 我正在使用 javascript 在模式中创建一个表单,但它不适用于 ajax
- python - 使用方法不会改变我的对象?
- javascript - 如何强制用户从自动完成列表 php 和 jquery 中进行选择
- scala - 如何更改使用的默认时区:从 SQL Server 读取时 spark.read.jdbc
- html - 如果两个元素在可变宽度包装器中,则垂直对齐两个元素
- verilog - 始终不分配输出的循环
- java - 如何在列表视图中获取选中行的值,并获取首选文本字段?
- docker - Docker - 将 yum install 添加到基础镜像