tensorflow - 为什么 Tensorflow “save_model” 失败了?
问题描述
我正在使用 Tensorflow 2.0.0,并尝试保存模型。这是我正在使用的代码(感谢@m-innat 建议简化示例模型)
class SimpleModel( tf.keras.Model ):
def __init__( self, **kwargs ):
super( SimpleModel, self ).__init__( **kwargs )
self.conv = tf.keras.layers.Conv1D( filters = 5, kernel_size = 3, padding = "SAME" )
self.dense = tf.keras.layers.Dense( 1 )
def call( self, x ):
x = self.conv( x )
x = self.dense( x )
return x
simple_model = SimpleModel()
input_shape = ( 3, 4, 5 )
x = tf.random.normal( shape = input_shape )
y = tf.random.normal( shape = ( 3, 4, 1 ) )
y_pred = simple_model( x )
print( "y_pred", y_pred )
tf.keras.models.save_model( translation_model,
"/content/gdrive/MyDrive/SimpleModel.tf", save_format = "tf" )
但是,save_model
调用给出了错误:
AttributeError: 'NoneType' object has no attribute 'shape'
调用堆栈中没有任何内容表明潜在的问题是什么。你能帮忙吗?
解决方案
该错误与未设置图层的输入形状有关。这可以通过调用一次方法simple_model.fit
或simple_model.predict
.
例如,在您的代码中,您可以调用y_pred = simple_model.predict( x )
.
这样,当我在下面的代码中检查时,模型就被正确保存了。
import tensorflow as tf
class SimpleModel( tf.keras.Model ):
def __init__( self, **kwargs ):
super( SimpleModel, self ).__init__( **kwargs )
self.conv = tf.keras.layers.Conv1D( filters = 5, kernel_size = 3, padding = "SAME" )
self.dense = tf.keras.layers.Dense( 1 )
def call( self, x ):
x = self.conv( x )
x = self.dense( x )
return x
simple_model = SimpleModel()
input_shape = ( 3, 4, 5 )
x = tf.random.normal( shape = input_shape )
y = tf.random.normal( shape = ( 3, 4, 1 ) )
y_pred = simple_model.predict( x )
print( "y_pred", y_pred )
tf.keras.models.save_model( simple_model,
"/content/gdrive/MyDrive/SimpleModel.tf", save_format = "tf" )
# Output:
# y_pred [[[-0.4533468 ]
# [ 1.3261242 ]
# [-1.0296338 ]
# [-1.1136482 ]] ...
model = tf.keras.models.load_model('/content/gdrive/MyDrive/SimpleModel.tf')
model.predict(x)
# Output:
#array([[[-0.4533468 ],
# [ 1.3261242 ],
# [-1.0296338 ],
# [-1.1136482 ]], ...
推荐阅读
- ios - 删除 tableView 中的项目会从存储的值中删除错误的对应文件(swift NSCoding)
- java - 有什么方法可以为多个 Runner 类运行的多个功能文件创建一个 Cucumber 报告?
- jmeter - JMeter正则表达式提取器一对一映射
- javascript - 使用 require.context 后如何动态加载 Vue 组件?
- android - Android 矢量绘图
不支持, 不支持 - azure - 如何使用PowerShell检查Azure虚拟机上是否启用了备份
- c# - 即使在为操作设置超时属性后超时异常
- angular5 - Angular 5 Observable 多次调用
- javascript - 将 id 添加到对象数组
- jquery - Prototype .observe 说 attachEvent 不是函数