python - Keras ConvLSTM 优化结果和内存管理
问题描述
我有用于预测值(线性激活)的图像,而且我对使用神经网络还比较陌生。我拥有的图像包含很多细节并且分辨率很高,但是由于它们是卫星图像,因此由于尺寸很大(8090、12894),这是有问题的。我的目标是 0 和 1 之间的标量。
我的目标是使用排序(RNN 通过 ConvLSTM)和卷积来更好地预测值。
我的步骤如下:
- 函数将图像加载到内存中并沿途处理它们,返回数组
- 生成器从上面的数组中读取数据并提供给模型,目前这不是必需的,但如果我能让模型足够小以在 GPU 上运行,它将会是。
- 模型通过生成器从数组中读取。
我想帮助优化模型以实现我的目标,缩小内存中的模型大小但提高准确性。
以下是我的代码的相关部分:
def build_model(frames=seq_len, channels=3, pixels_x=w, pixels_y=h, kernel_sizing=kernel_sizing):
model = Sequential()
model.add(
ConvLSTM2D(filters=16
, kernel_size=kernel_sizing
, strides = 3
, data_format='channels_last'
, return_sequences = False
, activation='relu', input_shape=(frames, pixels_x, pixels_y, channels))
)
model.add(
Conv2D(filters=16
, kernel_size=(3,3)
, activation='relu')
)
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.5))
model.add(Flatten())
model.add(Dense(10, activation='relu'))
model.add(Dense(1, activation='linear'))
optimizer = tf.keras.optimizers.RMSprop(0.001)
model.compile(
loss = "mse",
optimizer = optimizer,
metrics=['mae', 'mse'])
return model
笔记:
- 通常我必须调整到大约 0.1,这让我失去了细节
- 我已经尝试过 (3,3) 和 (11,11) 的 kernel_sizing 并且想要更多过滤器,但这会占用大量内存
- 根据内核的大小,我可以进一步大步前进,这就是我想知道我是否可以“缩小”内存中模型的大小的地方
- 因为我的目标值是介于 0 和 1 之间的数字,我是否需要使用 MAE 而不是 MSE,因为 MSE 可能会因为非常小的错误而变得奇怪?
- 展平后是否需要更大的致密层?我如何防止它变得巨大?
- 我相信我的许多记忆问题都是由图像太大引起的,但我该如何管理这个问题并保持我的细节水平?
- (从上面)我是否需要剪切图像块并使用块的目标变量增加数据集的大小?如果是这样,这将如何用于测试来自实际预测集的新图像?
解决方案
概括地说,这里有一些很好的建议来改进模型的优化,这样你就可以避免可怕的“ResourceExhaustedError: OOM when allocating tensor”:
- 在 Conv2D 层中使用更大的步幅
- 减少 Dense、Conv2D 层中的神经元数量
- 使用更小的batch_size(或增加steps_per_epoch)
- 使用灰度图像(将有一个通道而不是三个)
- 减少层数
- 使用更多 MaxPooling2D 层,并增加它们的池大小
- 减小图像大小(您可以使用 PIL 或 cv2)
- 应用辍学
- 使用较小的浮点精度,即 np.float16 而不是 np.float32 (最后的手段)
- 如果您使用的是预训练模型,请冻结第一层
希望有帮助
推荐阅读
- python - 尝试将输入转换为列表
- kubernetes - 在 1 个服务多 DC 场景中,流量仍分配到禁用的 POD - Openshift
- operating-system - 没有操作系统的 Grub 是如何编程的?是否可以在未安装其他操作系统的新组装 UEFI PC 中仅安装 grub 引导加载程序?
- pytorch - 将模型的输出作为输入 pytorch 数据加载器传递
- angular - Angular 材质主题
- sql - Postgresql - 使用连接和文本列更新 - 提高性能?
- r - 带有 R 的 htmlTable 的多色列标题
- sql - 如何显示来自 4 个不同表的数据 - 查询设计器
- agda - Agda:重写子表达式
- swiftui - 我们如何为 ForEach SwiftUI 制作自定义步骤?