python - 如何获得所需的尺寸作为 Tensorflow(Keras)输出?
问题描述
总的来说,我对机器学习和 Tensorflow 还是很陌生。我正在尝试为模型提供 800x800x1 尺寸的图像,并尝试获取 800x800x1 尺寸的图像作为输出。
我试图给模型这个图像,
并尝试重新创建下面给出的所需版本,
到目前为止,我的模型是,
model = models.Sequential()
model.add(layers.Conv2D(32, (5, 5), activation='relu', input_shape=(800, 800, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
以及模型的总结,
Layer (type) Output Shape Param #
=================================================================
conv2d_70 (Conv2D) (None, 796, 796, 32) 832
_________________________________________________________________
max_pooling2d_36 (MaxPooling (None, 398, 398, 32) 0
_________________________________________________________________
conv2d_71 (Conv2D) (None, 396, 396, 64) 18496
_________________________________________________________________
max_pooling2d_37 (MaxPooling (None, 198, 198, 64) 0
_________________________________________________________________
conv2d_72 (Conv2D) (None, 196, 196, 128) 73856
=================================================================
Total params: 93,184
Trainable params: 93,184
Non-trainable params: 0
从最后一个卷积层输出可以看出,形状为(196,196,128)。所以我想知道如何在输出中实现所需的 800x800x1 尺寸。我知道问题出在我的层面,但没有必要的知识来研究问题。非常感谢任何想法或指导方针。
谢谢你。祝你今天过得愉快!
解决方案
确切地说,这就是在 Keras 中实现转置卷积的方式。在这里你应该得到最后一层输出(800x800x1)
。在 Conv2DTranspose中设置非常重要,strides
因为这就是实现上采样的方式。
from tensorflow.keras import layers, models
model = models.Sequential()
model.add(layers.Conv2D(32, (5, 5), activation='relu', padding='same', input_shape=(800, 800, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), padding='same', activation='relu'))
model.add(layers.Conv2DTranspose(64, (3, 3), strides=(2,2), padding='same', activation='relu'))
model.add(layers.Conv2DTranspose(1, (3, 3), strides=(2,2), padding='same', activation='relu'))
model.summary()
推荐阅读
- php - 在php中使用相同的数组键对多维数组进行排序
- javascript - 对键值的双数据绑定对象,用于下拉列表
- java - 为什么我们在资源和授权服务器中需要完全相同的配置
- angular - 遵循 i18n 与 JIT 编译器集成的官方教程,网站仍然是英文
- python - 使用 sys.odcinumberlist 作为参数从 python 执行 PL/SQL 过程
- java - 如何将文本文件中的不同数据存储到 ArrayLists 中?
- security - 将任何用户文件上传到 S3 存储桶是否安全?
- javascript - 在严格模式下使用 delete 删除对象条目
- android - @Parcelize 和枚举类 - 重载解析歧义
- php - Sql 查询未在 php 中产生输出