首页 > 解决方案 > 跨 MRI 切片的最大池化

问题描述

我正在尝试为 MRI 扫描诊断实施机器学习模型。我有形状 (x, 256, 256, 3) 的输入,其中我们有 3 个颜色通道,其中 x 是序列中的切片数。我阅读了MRNet论文,我想在 TensorFlow Keras 中实现类似的架构。我不想使用 AlexNet 特征提取器,而是使用 VGG16。

论文中的模型架构:

我们预测系统的主要构建块是 MRNet,一个卷积神经网络 (CNN),将 3 维 MRI 系列映射到概率 [15](图 2)。MRNet 的输入尺寸为 s × 3 × 256 × 256,其中 s 是 MRI 系列中的图像数量(3 是颜色通道的数量)。首先,每个二维 MRI 图像切片通过基于 AlexNet 的特征提取器,以获得包含每个切片特征的 × 256 × 7 × 7 张量。然后应用全局平均池化层将这些特征减少到 s × 256。然后我们跨切片应用最大池化以获得 256 维向量,该向量被传递给全连接层和 sigmoid 激活函数以获得预测0 到 1 范围。

到目前为止,一切都很好。我有一个顺序模型,第一步添加了特征提取器,然后应用 GlobalAveragePooling2D() 将特征简化为形状(x,512)。然后我必须在切片上使用 MaxPool,但我没有办法解决这个问题。

feature_extractor = VGG16(weights='imagenet', include_top=False, input_shape=(256, 256, 3))
model = Sequential()
model.add(feature_extractor)         #output shape: (x, 8, 8, 512)
model.add(GlobalAveragePooling2D())  #output shape: (x, 512)
# Here i have to add a Layer witch Pools over the slices.
model.add(                         )  #output shape(1, 512)

model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

示例 Scan 的形状为 (44, 256, 256, 3)。当它通过 VGG16 时,其特征的维度为 (44, 8, 8, 512)。在 GlobalAverage Pooling 之后,我得到了 (44, 512)。然后,这个二维数组必须以某种方式转换为 (1, 512) 的形状。我的意思是如果我在一个简单的二维 NumPy 数组上进行操作,我需要一个像 np.max 这样的函数在 0 轴上

np.max(x, axis=0)

也许你可以给我一个提示或对此有一个方法。非常感谢你的帮助 :)

################################################# ############################## 编辑:01.05.2021

我玩弄了你的方法@Aaron Keesing,但拟合模型并没有以某种方式训练它。在 25 个 epochs 之后,我仍然具有相同的精度。准确度是我的两个班级的分布(我只是在冠状平面上训练并且异常)

准确度指标

在这种情况下,例如我有 500 个案例,80% 的案例确实有异常,而 20% 没有。

# Dataset train, overall 500 cases
Absolute:
 abnormal  acl  meniscus
1         0    0           184
               1           118
0         0    0           100
1         1    1            63
               0            35
dtype: int64
Relative:
 abnormal  acl  meniscus
1         0    0           0.368
               1           0.236
0         0    0           0.200
1         1    1           0.126
               0           0.070

###########################################################
# Dataset valid, overall 100 cases
Absolute:
 abnormal  acl  meniscus
1         1    1           27
0         0    0           25
1         1    0           23
          0    0           20
               1            5
dtype: int64
Relative:
 abnormal  acl  meniscus
1         1    1           0.27
0         0    0           0.25
1         1    0           0.23
          0    0           0.20
               1           0.05

标签: pythontensorflowmachine-learningkeras

解决方案


您应该能够使用GlobalAveragePooling1D图层。但请注意,它需要一个批次维度。由于您输入的是一系列图像,因此您的输入应该是 5 维的,第一个维度是 batch_size(可以是 1)。

我认为图像 CNN 不适用于 5D 输入,因此您可以使用TimeDistributed图层应用于图像序列,这将为您提供 shape 特征序列(x, 512),然后应用GlobalAveragePooling1D以获得最终特征向量。

所以也许这样的事情可能会奏效。请注意,您必须指定序列中的图像数量x(可以是None):

vgg16 = VGG16(weights='imagenet', include_top=False, input_shape=(256, 256, 3))
feature_extractor = Sequential()
feature_extractor.add(vgg16)         #output shape: (bs, 8, 8, 512)
feature_extractor.add(GlobalAveragePooling2D())  #output shape: (bs, 512)

model = Sequential()
model.add(TimeDistributed(feature_extractor, input_shape=(x, 256, 256, 3)))  #output shape(bs, x, 512)
# Here i have to add a Layer witch Pools over the slices.
model.add(GlobalAveragePooling1D())   #output shape(bs, 512)
model.add(Dense(1, activation='sigmoid'))   #output shape(bs, 1)

您可以一次只放置一批一个 MRI 序列,这样就可以了bs = 1

这会产生以下模型结构x = None

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
time_distributed (TimeDistri (None, None, 512)         14714688
_________________________________________________________________
global_average_pooling1d (Gl (None, 512)               0
_________________________________________________________________
dense (Dense)                (None, 1)                 513
=================================================================
Total params: 14,715,201
Trainable params: 14,715,201
Non-trainable params: 0
_________________________________________________________________

推荐阅读