tensorflow - 如何保存具有常量的自定义层的 keras 模型?
问题描述
我有一个包含常量的自定义层,但无法保存使用该层的模型,保存时会报下面的错误。有人知道该如何保存这样的 tf.keras 模型吗——其中的自定义 tf.keras 层内部含有一个 tf.Variable?我想把模糊核(blur kernel)应用到我的卷积神经网络中,按照这篇论文的做法,让卷积层重新具备平移不变性(shift-invariant)。
# This is the layer
class MaxBlurPooling2D(Layer):
    """Max pooling followed by a strided, per-channel blur.

    ``call`` first max-pools with stride 1 and then runs a depthwise
    convolution with a fixed, normalised 3x3 or 5x5 kernel, using
    ``pool_size`` as the convolution stride.

    NOTE: ``get_config`` below returns ``self.blur_kernel``. Once ``build``
    has run, that attribute is the ``tf.Variable`` created by ``add_weight``,
    which is not JSON serializable -- this is why a model containing this
    version of the layer cannot be saved.
    """

    def __init__(self, pool_size: int = 2, kernel_size: int = 3, **kwargs):
        # Attributes are assigned before the base-class constructor runs.
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.blur_kernel = None  # replaced by a weight in build()
        super().__init__(**kwargs)

    def get_config(self):
        """Return the base layer config extended with this layer's settings."""
        own_settings = {
            'pool_size': self.pool_size,
            # After build() this value is a tf.Variable, not a plain number.
            'blur_kernel': self.blur_kernel,
            "kernel_size": self.kernel_size,
        }
        inherited = super().get_config()
        return {**inherited, **own_settings}

    def build(self, input_shape):
        """Create the non-trainable blur weight for ``input_shape[3]`` channels.

        Raises:
            ValueError: if ``kernel_size`` is neither 3 nor 5.
        """
        size = self.kernel_size
        if size == 3:
            taps = [[1, 2, 1],
                    [2, 4, 2],
                    [1, 2, 1]]
        elif size == 5:
            taps = [[1, 4, 6, 4, 1],
                    [4, 16, 24, 16, 4],
                    [6, 24, 36, 24, 6],
                    [4, 16, 24, 16, 4],
                    [1, 4, 6, 4, 1]]
        else:
            raise ValueError

        channels = input_shape[3]
        weight_shape = (size, size, channels, 1)

        # Normalise the taps so they sum to 1, then store them as float32.
        kernel = tf.constant(taps)
        kernel = tf.cast(kernel / tf.math.reduce_sum(kernel), tf.float32)
        # Repeat every tap once per channel and fold the flat result into the
        # (size, size, channels, 1) layout used for the depthwise kernel.
        kernel = tf.reshape(tf.repeat(kernel, channels), weight_shape)

        self.blur_kernel = self.add_weight(
            name='blur_kernel',
            shape=weight_shape,
            initializer=tf.keras.initializers.constant(kernel),
            trainable=False,
        )
        super().build(input_shape)  # keep this as the last step

    def call(self, x):
        """Max-pool ``x`` with stride 1, then blur with stride ``pool_size``."""
        window = (self.pool_size, self.pool_size)
        pooled = tf.nn.pool(x, window, strides=(1, 1), padding='SAME',
                            pooling_type='MAX', data_format='NHWC')
        return K.depthwise_conv2d(pooled, self.blur_kernel,
                                  padding='same', strides=window)

    def compute_output_shape(self, input_shape):
        """Return the input shape with both spatial sizes halved (rounded up)."""
        batch = input_shape[0]
        out_height = int(tf.math.ceil(input_shape[1] / 2))
        out_width = int(tf.math.ceil(input_shape[2] / 2))
        return batch, out_height, out_width, input_shape[3]
'Not JSON Serializable:', <tf.Variable 'max_blur_pooling2d_3/blur_kernel:0' shape=(3, 3, 64, 1) dtype=float32, numpy=
array([[[[0.0625],
[0.0625],
[0.0625],
...,
]]]
解决方案
自定义层的 get_config 似乎不应该包含通过 self.add_weight 创建的变量:它是一个 tf.Variable,无法被 JSON 序列化。config 中只需保留 __init__ 接收的参数即可。
这应该可以解决问题
class MaxBlurPooling2D(L.Layer):
    """Max pooling followed by a strided, per-channel blur.

    ``call`` first max-pools with stride 1 and then runs a depthwise
    convolution with a fixed, normalised blur kernel, using ``pool_size`` as
    the convolution stride.

    Args:
        pool_size: max-pooling window size and stride of the blur convolution.
        kernel_size: side length of the blur kernel; must be 3 or 5.
        **kwargs: forwarded to the base layer constructor.
    """

    def __init__(self, pool_size: int = 2, kernel_size: int = 3, **kwargs):
        self.pool_size = pool_size
        self.kernel_size = kernel_size
        super(MaxBlurPooling2D, self).__init__(**kwargs)

    def get_config(self):
        """Return the constructor arguments merged into the base layer config.

        Only the values accepted by ``__init__`` belong here. The weight
        created with ``add_weight`` in ``build`` is a ``tf.Variable``, which
        is not JSON serializable, so it must not be added to the config.
        """
        config = {
            'pool_size': self.pool_size,
            'kernel_size': self.kernel_size,
        }
        base_config = super(MaxBlurPooling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        """Create the non-trainable blur weight for ``input_shape[3]`` channels.

        Raises:
            ValueError: if ``kernel_size`` is neither 3 nor 5.
        """
        if self.kernel_size == 3:
            bk = tf.constant([[1, 2, 1],
                              [2, 4, 2],
                              [1, 2, 1]])
        elif self.kernel_size == 5:
            bk = tf.constant([[1, 4, 6, 4, 1],
                              [4, 16, 24, 16, 4],
                              [6, 24, 36, 24, 6],
                              [4, 16, 24, 16, 4],
                              [1, 4, 6, 4, 1]])
        else:
            raise ValueError(
                'kernel_size must be 3 or 5, got {!r}'.format(self.kernel_size))

        channels = input_shape[3]
        kernel_shape = (self.kernel_size, self.kernel_size, channels, 1)

        # Normalise the taps so they sum to 1 and store them as float32.
        bk = tf.cast(bk / tf.math.reduce_sum(bk), tf.float32)
        # Repeat every tap once per channel, then fold the flat result into
        # the (kernel_size, kernel_size, channels, 1) depthwise-kernel layout.
        bk = tf.reshape(tf.repeat(bk, channels), kernel_shape)

        self.blur_kernel = self.add_weight(
            name='blur_kernel',
            shape=kernel_shape,
            initializer=tf.keras.initializers.constant(bk),
            trainable=False)
        super(MaxBlurPooling2D, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        """Max-pool ``x`` with stride 1, then blur with stride ``pool_size``."""
        x = tf.nn.pool(x, (self.pool_size, self.pool_size),
                       strides=(1, 1), padding='SAME', pooling_type='MAX',
                       data_format='NHWC')
        x = K.backend.depthwise_conv2d(x, self.blur_kernel, padding='same',
                                       strides=(self.pool_size, self.pool_size))
        return x

    def compute_output_shape(self, input_shape):
        """Return the output shape for an input of shape ``input_shape``.

        The stride-1 'SAME' pooling keeps the spatial size and the 'same'
        convolution strides by ``pool_size``, so each spatial dimension
        becomes ``ceil(dim / pool_size)``. Unknown (``None``) dimensions
        stay unknown.
        """
        def strided(dim):
            if dim is None:
                return None
            # Ceiling division with integers only: -(-a // b) == ceil(a / b).
            return -(-int(dim) // self.pool_size)

        return (input_shape[0], strided(input_shape[1]),
                strided(input_shape[2]), input_shape[3])
推荐阅读
- flutter - Flutter 在 MultiProvider 中使用 StreamingSharedPreferences
- python - 可调用对象作为方法不将 self 作为参数传递
- excel - 为每个循环组合多个范围 vba
- kotlin - 可以在没有协程的ViewModel类中调用DAO类的函数吗
- laravel - getCustomAttribute 中的 Laravel 方法返回 cullection null
- javascript - Highcharts Spiderweb 图表 xAxis 标签在长标签名称上消失
- javascript - 无法将引导 JS 加载到 Electron 应用程序中
- javascript - hibext_instdsigdipv2 cookie 来自哪里?
- c++ - 达到限制后自动旋转值的自定义 qt spinbox
- linux - 如果字符串的出现每行恰好一次,如何删除一行?