python - TF2.6:ValueError:模型无法保存,因为尚未设置输入形状
问题描述
我想在 Google Colab 中使用迁移学习创建自定义模型。
import tensorflow as tf
from tensorflow.keras.layers import Conv2D
from tensorflow.python.keras.applications.xception import Xception
class MyModel(tf.keras.Model):
def __init__(self, input_shape, num_classes=5, dropout_rate=0.5):
super(MyModel, self).__init__()
self.weight_dict = {}
self.weight_dict['backbone'] = Xception(input_shape=input_shape, weights='imagenet', include_top=False)
self.weight_dict['outputs'] = Conv2D(num_classes, (1, 1), padding="same", activation="softmax")
self.build((None,) + input_shape)
def call(self, inputs, training=False):
self.weight_dict['backbone'].trainable = False
x = self.weight_dict['backbone'](inputs)
x = self.weight_dict['outputs'](x)
return x
model = MyModel(input_shape=(256, 256, 3))
model.save('./saved')
但是,我遇到了这个错误:
ValueError: Model `<__main__.MyModel object at 0x7fc66134bdd0>` cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling `.fit()` or `.predict()`. To manually set the shapes, call `model.build(input_shape)`.
是的,没有调用.fit()
or .predict()
。.build
但是在类的__init__()
方法中有一个调用。我是什么做的?
解决方案
如果该层尚未构建,compute_output_shape将在该层上调用构建。这假设该层稍后将与与提供的输入形状匹配的输入一起使用。
工作代码如下图
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.applications.xception import Xception
class MyModel(tf.keras.Model):
def __init__(self, input_shape, num_classes=5, dropout_rate=0.5):
super(MyModel, self).__init__()
self.weight_dict = {}
self.weight_dict['backbone'] = Xception(input_shape=input_shape, weights='imagenet', include_top=False)
self.weight_dict['outputs'] = Conv2D(num_classes, (1, 1), padding="same", activation="softmax")
self.build((None,) + input_shape)
def call(self, inputs, training=False):
self.weight_dict['backbone'].trainable = False
x = self.weight_dict['backbone'](inputs)
x = self.weight_dict['outputs'](x)
return x
input_shape=(256, 256, 3)
model=MyModel(input_shape)
model.compute_output_shape(input_shape=(None, 256, 256, 3))
model.save('./saved')
输出:
2.6.0
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 1s 0us/step
INFO:tensorflow:Assets written to: ./saved/assets
有关更多信息,您可以参考此处。
推荐阅读
- jpeg - JPG 文件上的图像标题损坏
- tensorflow-federated - DP-FedAvg 中客户更新的剪辑
- flutter - 如何在颤动中制作空安全版本的列表?
- python - 在勇敢的浏览器中禁用跨域读取阻止 (CORB)
- reactjs - findDOMNode 在 StrictMode 中已弃用 Reactjs 中出现错误
- java - 为什么zuul网关应用程序需要这么长时间才能以优雅的方式关闭
- asp.net-core - 为什么 UseSqlCe() 支持从最新版本的 EntityFrameworkCore.SqlServerCompact40 中删除?
- vb.net - 如何在 VB.Net 中使用 HttpClient 设置 Cookie
- android - 如何修复·libtest_x86_64-unknown-linux-gnu" 模块依赖?
- mongoose - 适配器类型“mongoose”不支持字段类型“CloudinaryImage”