tensorflow - 当我使用 tensorflow2 的子类化 API 时,出现了一些奇怪的问题:有些层可以重用,有些不能,为什么
问题描述
在下面的两个例子中,都是使用子类化来构建相同的模型。重用层会出现一些奇怪的问题。一个不能重用所有层,另一个不能重用部分层,例如卷积, BatchNormization,但可以重用激活层。
为什么?
张量流版本:2.0.0
1. 使用 tensorflow 中已有的层。
所有层都不能重用,例如卷积,BatchNormailzation,Activation。
在下面的代码中,当我在调用函数中将 'conv2' 更改为 'conv' 或将 'bn2' 更改为 'bn' 或 'ac2' 为 'ac' 时,会引发错误。
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU
class Models(tf.keras.Model):
def __init__(self):
super().__init__()
self.conv = Conv2D(16, (3, 3), padding='same')
self.bn = BatchNormalization()
self.ac = ReLU()
self.conv2 = Conv2D(32, (3, 3), padding='same')
self.bn2 = BatchNormalization()
self.ac2 = ReLU()
def call(self, x, **kwargs):
x = self.conv(x)
x = self.bn(x)
x = self.ac(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.ac2(x)
return x
m = Models()
m.build(input_shape=(2, 8, 8, 3))
m.summary()
在重新启动层时会抛出一些错误,例如:
- 重用 BatchNormalization 层:
ValueError: Input 0 of layer batch_normalization is incompatible with the layer: expected axis 3 of input shape to have value 16 but received input with shape [2, 8, 8, 32]
- 复用卷积层:
ValueError: Input 0 of layer conv2d is incompatible with the layer: expected axis -1 of input shape to have value 3 but received input with shape [2, 8, 8, 16]
- 重用激活层:
ValueError: You tried to call `count_params` on re_lu_1, but the layer isn't built. You can build it manually via: `re_lu_1.build(batch_input_shape)`.
2. 使用从 tensorflow 扩展而来的自定义层。
在下面的代码中,重用卷积/BactchNorization 层的结果与之前的代码类似,但是激活层可以重用!
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras import layers
class DoubleConv(layers.Layer):
def __init__(self, mid_kernel_numbers, out_kernel_number):
"""
初始化含有两个卷积的卷积块
:param mid_kernel_numbers: 中间特征图的通道数
:param out_kernel_number: 输出特征图的通道数
"""
super().__init__()
self.conv1 = layers.Conv2D(mid_kernel_numbers, (3, 3), padding='same')
self.conv2 = layers.Conv2D(out_kernel_number, (3, 3), padding='same')
self.bn = layers.BatchNormalization()
self.bn2 = layers.BatchNormalization()
self.ac = layers.ReLU()
self.ac2 = layers.ReLU()
def call(self, input, **kwargs):
"""正向传播"""
x = self.conv1(input)
x = self.bn(x)
x = self.ac(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.ac2(x)
return x
class Model(tf.keras.Model):
def __init__(self):
"""
构建模型的类
"""
super().__init__()
# 初始化卷积块
self.block = DoubleConv(16, 32)
def call(self, x, **kwargs):
x = self.block(x)
return x
m = Model()
m.build(input_shape=(2, 8, 8, 3))
m.summary()
在重新启动层时会抛出一些错误,例如:
- 重用 BatchNormalization 层:
ValueError: Input 0 of layer batch_normalization is incompatible with the layer: expected axis 3 of input shape to have value 16 but received input with shape [2, 8, 8, 32]
- 复用卷积层:
AttributeError: 'DoubleConv' object has no attribute 'conv'
我的猜测有两种可能:
一个与层的名称有关。另一个与参数有关。激活层不需要参数。 但是这些并不能解释为什么会有差异。
解决方案
推荐阅读
- javascript - Vue Route BroadCrumbs 值正在刷新页面
- javascript - 我的目标是获得一个按钮,当单击该按钮时,应该显示表格,但预先显示数据并且按钮什么也不做
- python - 下载以 HTML 格式下载的 Jupyter 笔记本时如何更改保存的文件夹?
- c# - 动态自定义组件列表上的 Blazor 双向绑定
- sql - 如何检索具有两个具有两个不同值的列的客户端
- python - 从字典创建嵌套列表
- java - HttpsURLConnection 返回 --> java.io.IOException: com.android.okhttp.Address 上的流意外结束
- joomla - Joomla 联系人搜索错误(0 - 类 stdClass 的对象无法转换为字符串)
- python - python kivyMD pos_hint 无法正常工作
- reactjs - 在 ApexCharts 中仅更改一个 x 轴标签的字体大小?