tensorflow - Keras CNN 如何减少大图像尺寸的 gpu 内存使用量?
问题描述
我正在尝试训练一个 cnn-lstm 模型,我的样本图像大小为 640x640。
我有一个 GTX 1080 ti 11GB。
我正在使用带有 tensorflow 后端的 Keras。
这是模型。
img_input_1 = Input(shape=(1, n_width, n_height, n_channels))
conv_1 = TimeDistributed(Conv2D(96, (11,11), activation='relu', padding='same'))(img_input_1)
pool_1 = TimeDistributed(MaxPooling2D((3,3)))(conv_1)
conv_2 = TimeDistributed(Conv2D(128, (11,11), activation='relu', padding='same'))(pool_1)
flat_1 = TimeDistributed(Flatten())(conv_2)
dense_1 = TimeDistributed(Dense(4096, activation='relu'))(flat_1)
drop_1 = TimeDistributed(Dropout(0.5))(dense_1)
lstm_1 = LSTM(17, activation='linear')(drop_1)
dense_2 = Dense(4096, activation='relu')(lstm_1)
dense_output_2 = Dense(1, activation='sigmoid')(dense_2)
model = Model(inputs=img_input_1, outputs=dense_output_2)
op = optimizers.Adam(lr=0.00001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.001)
model.compile(loss='mean_absolute_error', optimizer=op, metrics=['accuracy'])
model.fit(X, Y, epochs=3, batch_size=1)
现在使用这个模型,我只能在图像大小调整为 60x60 时使用训练数据,任何更大并且我用完 GPU 内存。
我想使用尽可能大的尺寸,因为我想保留尽可能多的歧视性信息。(y 标签将是 0 - 640 之间的鼠标屏幕坐标)
在许多其他人中,我找到了这个答案: https ://ai.stackexchange.com/questions/3938/how-to-handle-images-of-large-sizes-in-cnn
尽管我不确定如何“限制您的 CNN”或“在每个时期流式传输您的数据”,或者这些是否会有所帮助。
如何减少使用的内存量,以便增加图像大小?
是否有可能牺牲训练时间/计算速度来支持更高分辨率的数据,同时保持模型的有效性?
注意:以上模型不是最终的,只是一个基本的支出。
解决方案
您的Dense
层可能正在破坏训练。为了提供一些上下文,假设您使用的是640x640x3
图像大小。让我们也忘记LSTM
层,假设这是一个非时间序列任务(当然,时间序列问题的复杂性会变得更糟)。
这是输出大小。
Conv1
->640x640x96
Maxpool1
->210x210x96
(应用程序)Conv2
->210x210x128
现在到了瓶颈。然后,您flatten()
将输出并将其发送到Dense
图层。这个密集层有210x210x128x4096
参数(即23,121,100,800
)。假设32-bit
精度,您的密集层将占用大约 86GB(我希望我的计算是正确的,但我向您保证这不是一个小数字)。
所以你的选择很少。
- 首先也是最明显的,减小
Dense
图层大小。 - 减少小批量的大小
- 减少
Conv
层的通道深度。 - 您可能想考虑是否真的希望输入为
640x640x3
. 根据您要实现的目标,您可能可以使用较小的图像来做到这一点。
推荐阅读
- ruby-on-rails - 将多个数组 + 变量组合成单个 JSON 对象
- sql - 更新 SQL 中的多个值
- java - Java - 递归双阶乘算法
- azure - 标准连接器转向高级连接器。逻辑应用程序是否也受到影响
- file - 请求正文中的动态文件
- sql - How to send a query request to SQL Server via tcp ip (with a general tool, such as Packet Sender) on the same PC?
- git - VSCode:区别“创建分支”和“创建分支”
- c++ - 为什么使用 _Bool 和 bool 而不是 int8_t 或 char?
- sql - 仅当列具有确定的值时,我如何才能返回不为空的记录
- swift - 无法将 CGFLoat 转换为 CGPoint