首页 > 解决方案 > 为什么 Tensorflow “save_model” 失败了?

问题描述

我正在使用 Tensorflow 2.0.0,并尝试保存模型。这是我正在使用的代码(感谢@m-innat 建议简化示例模型)

class SimpleModel( tf.keras.Model ):
    def __init__( self, **kwargs ):
        super( SimpleModel, self ).__init__( **kwargs )
        self.conv = tf.keras.layers.Conv1D( filters = 5, kernel_size = 3, padding = "SAME" )
        self.dense = tf.keras.layers.Dense( 1 )
    def call( self, x ):
        x = self.conv( x )
        x = self.dense( x )
        return x

simple_model = SimpleModel()

input_shape = ( 3, 4, 5 )
x = tf.random.normal( shape = input_shape )
y = tf.random.normal( shape = ( 3, 4, 1 ) )
y_pred = simple_model( x )
print( "y_pred", y_pred )

tf.keras.models.save_model( translation_model, 
    "/content/gdrive/MyDrive/SimpleModel.tf", save_format = "tf" )

但是,save_model调用给出了错误:

AttributeError: 'NoneType' object has no attribute 'shape'

调用堆栈中没有任何内容表明潜在的问题是什么。你能帮忙吗?

标签: tensorflow

解决方案


该错误与未设置图层的输入形状有关。这可以通过调用一次方法simple_model.fitsimple_model.predict.

例如,在您的代码中,您可以调用y_pred = simple_model.predict( x ).

这样,当我在下面的代码中检查时,模型就被正确保存了。

import tensorflow as tf

class SimpleModel( tf.keras.Model ):
    def __init__( self, **kwargs ):
        super( SimpleModel, self ).__init__( **kwargs )
        self.conv = tf.keras.layers.Conv1D( filters = 5, kernel_size = 3, padding = "SAME" )
        self.dense = tf.keras.layers.Dense( 1 )
    def call( self, x ):
        x = self.conv( x )
        x = self.dense( x )
        return x

simple_model = SimpleModel()

input_shape = ( 3, 4, 5 )
x = tf.random.normal( shape = input_shape )
y = tf.random.normal( shape = ( 3, 4, 1 ) )
y_pred = simple_model.predict( x )
print( "y_pred", y_pred )

tf.keras.models.save_model( simple_model, 
    "/content/gdrive/MyDrive/SimpleModel.tf", save_format = "tf" )

# Output:
# y_pred [[[-0.4533468 ]
#  [ 1.3261242 ]
#  [-1.0296338 ]
#  [-1.1136482 ]] ...

model = tf.keras.models.load_model('/content/gdrive/MyDrive/SimpleModel.tf')
model.predict(x)

# Output:
#array([[[-0.4533468 ],
#        [ 1.3261242 ],
#        [-1.0296338 ],
#        [-1.1136482 ]], ...

推荐阅读