python-3.x - 有没有办法从模型子类 API 中保存在 tensorflow 2.0 中构建的 Keras 模型?
问题描述
有没有办法在训练完成后使用 tf.keras 模型子类化 API 保存整个模型构建?我知道我们可以使用 save_weights 只保存权重,但是有没有办法保存整个模型,以便以后在没有可用代码时可以使用它进行预测?
class MyModel(tf.keras.Model):
def __init__(self, num_classes=10):
super(MyModel, self).__init__(name='my_model')
self.num_classes = num_classes
# Define your layers here.
self.dense_1 = layers.Dense(32, activation='relu')
self.dense_2 = layers.Dense(num_classes, activation='sigmoid')
def call(self, inputs):
# Define your forward pass here,
# using layers you previously defined (in `__init__`).
x = self.dense_1(inputs)
return self.dense_2(x)
model = MyModel(num_classes=10)
# The compile step specifies the training configuration.
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, labels, batch_size=32, epochs=5)
解决方案
您可以使用以下步骤在训练、加载和推理后保存模型:
训练后保存模型
model.save(filepath="model")
# OR
tf.keras.models.save_model(model, filepath="model_")
加载保存的模型
loaded_model = tf.keras.models.load_model(filepath="model_")
使用加载模型进行预测
result = loaded_model.predict(test_db)
推荐阅读
- java - Quarkus 异步
- c# - SQLite考勤日期记录表结构| 删除位置 | 首要的关键
- c++ - 内联函数的函数本地静态对象是否在共享对象文件之间共享?
- swift - 如何在 swift 中以编程方式隐藏 tabBarController 的顶部栏?
- c++11 - 在 Codelite C++ 中,当我尝试同时声明一个向量并对其进行初始化时,我得到一个错误,尽管它在 Xcode 中出于某种原因工作
- r - 在 R 中使用 ggplot2 创建两个变量的有向箭头图
- java - 在 Vaadin 14 中单击按钮时,RadioGroupButton 取消选择值
- linear-programming - 线性规划中的约束数
- java - 包理解问题
- html - 根据单选按钮的输入更改值