python - Tensorflow - 深度到空间后与conv2d不兼容的形状
问题描述
我在实现超分辨率模型时遇到问题
class SRNet(Model):
def __init__(self, scale=4):
super(SRNet, self).__init__()
self.scale = scale
self.conv1 = Sequential([
layers.Conv2D(filters=64, kernel_size=3,
strides=(1, 1), padding="same", data_format="channels_first"),
layers.ReLU(),
])
self.residualBlocks = Sequential(
[ResidualBlock() for _ in range(16)])
self.convUp = Sequential([
layers.Conv2DTranspose(filters=64, kernel_size=3, strides=(
2, 2), padding="same", data_format="channels_first"),
layers.ReLU(),
layers.Conv2DTranspose(filters=64, kernel_size=3, strides=(
2, 2), padding="same", data_format="channels_first"),
layers.ReLU(),
])
self.reluAfterPixleShuffle = layers.ReLU()
self.convOut = layers.Conv2D(
filters=3, kernel_size=3, strides=(1, 1), padding="same", data_format="channels_first", input_shape=(4, 1440, 2560)) # (kernel, kernel, channel, output)
def call(self, lrCur_hrPrevTran):
lrCur, hrPrevTran = lrCur_hrPrevTran
x = tf.concat([lrCur, hrPrevTran], axis=1)
x = self.conv1(x)
x = self.residualBlocks(x)
x = self.convUp(x)
# pixel shuffle
Subpixel_layer = Lambda(lambda x: tf.nn.depth_to_space(
x, self.scale, data_format="NCHW"))
x = Subpixel_layer(inputs=x)
x = self.reluAfterPixleShuffle(x)
x = self.convOut(x)
return x
错误
/usr/src/app/generator.py:164 call *
x = self.convOut(x)
ValueError: Tensor's shape (3, 3, 64, 3) is not compatible with supplied shape (3, 3, 4, 3)
阅读错误后,我知道 (3, 3, 4, 3) 是 (kernel size, kernel size, channel, output) 意味着只有输入通道不正确
所以我打印出输入的形状
# after pixel shuffle before convOut
print(x.shape)
>>> (1, 4, 1440, 2560) (batch size, channel, height, width)
但是 x 之后的形状pixel shuffle (depth_to_space)
是 (1, 4, 1440, 2560) 通道值是 4 与convOut
需要相同
问题是为什么输入的通道从 4 变为 64 作为错误?
解决方案
我找到了解决方案
首先,在模型
的实现和测试期间,我在训练时使用检查点来保存模型权重,我改变了一些层,所以输入大小也改变了,但我的权重仍然记得之前的输入大小检查点
所以我删除了检查点文件夹,然后一切正常
推荐阅读
- ssh - 无法为 ssh 生成 U2F 公钥/私钥对:FIDO_ERR_RX
- excel - Userform - VBA 代码编译错误:未定义用户定义的类型
- c# - 如何获得统一的旋转坐标之一?
- bash - 执行用户定义 shell 命令时出现问题
- javascript - TypeError:使用 react-context api 时调度不是函数
- python-3.x - 在 Dash Python 中下拉
- assembly - I/O 端口寻址
- php - 如何使用 php 将主机可用性插槽推送到日历中?
- python - 图像上的 OpenCv Python 操作,有没有办法做到这一点?
- c# - 如何除以 0 不是“除以零”异常?