python - Unet architecture on the Carvana dataset
Problem description
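I followed zhixuhao's unet implementation and another implementation from kaggle.
The two models are not very different from each other, except that the former has some extra layers and therefore nearly 30 million parameters.
My problem is that I can't get either model to perform well (by which I mean both models end up with a loss of around -800), using binary_crossentropy as the loss and accuracy or dice_coef as the metric. Please help me figure out where I went wrong. Here are some of my suspicions: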
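1) One interesting thing I noticed is that dice_coef reached 1.9 within an epoch, which should be impossible because it should be at most 1. Here is the dice_coef function from the kaggle link:
from keras import backend as K
def dice_coef(y_true, y_pred, smooth=0):
    # Flatten both tensors and compute 2 * |A and B| / (|A| + |B|)
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)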
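For values in [0, 1], each pixel satisfies 2·t·p ≤ t + p, so this ratio cannot exceed 1. The NumPy sketch below reimplements the same formula (dice_coef_np and the toy 4x4 mask are illustrative, not from the post) and shows one way a value near 2 can appear: if the masks reach the metric as raw 0/255 pixel values instead of 0/1, a perfect prediction scores 2·255/256 ≈ 1.99. Whether the masks really arrive as 0/255 here is an assumption about the data pipeline, not something the post confirms.

```python
import numpy as np

def dice_coef_np(y_true, y_pred, smooth=0.0):
    # Same formula as the Keras dice_coef above, written with NumPy.
    y_true_f = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)

# Toy 4x4 mask with a 2x2 "car" region, and a perfect prediction of it.
mask01 = np.zeros((4, 4))
mask01[1:3, 1:3] = 1.0
pred = mask01.copy()

print(dice_coef_np(mask01, pred))        # mask stored as {0, 1}:   1.0
print(dice_coef_np(mask01 * 255, pred))  # mask stored as {0, 255}: 1.9921875
```

The epoch values of 1.87 and 1.96 in the training log are consistent with this scaling, but that is an inference, not a confirmed diagnosis.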
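2) The flow_from_directory() function in Keras does not read .gif files by default (the mask images are in .gif format), so I followed this suggestion and made the corresponding change to flow_from_directory() in keras/preprocessing/image.py.
Then, when reading the masks through it, I passed color_mode = 'grayscale', so the target images have 1 channel, because the last layer of the UNet architecture has a 1-channel output. If I read the images myself with skimage.io.imread(), the gif images have size (1024, 1024), i.e. 1 channel.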
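Before training, it may be worth checking what the mask generator actually yields. The helper below is a hypothetical sketch (describe_mask_batch is not a Keras function); the stand-in batch of 0/255 values assumes that, with no rescale set, grayscale masks keep their original pixel range. With the real pipeline you would pass it next(mask_generator). The last two lines show how a bare (H, W) mask, like the one skimage.io.imread() returns, gets the batch and channel axes the 1-channel output expects.

```python
import numpy as np

def describe_mask_batch(batch):
    """Report shape, dtype and value range of a batch of masks.

    Expects channels-last arrays shaped (batch, height, width, 1), matching
    a model whose last layer is a 1-channel sigmoid convolution.
    """
    batch = np.asarray(batch)
    if batch.ndim != 4 or batch.shape[-1] != 1:
        raise ValueError(f"expected (N, H, W, 1), got {batch.shape}")
    values = np.unique(batch)
    return {
        "shape": batch.shape,
        "dtype": str(batch.dtype),
        "min": float(values.min()),
        "max": float(values.max()),
        "n_distinct": int(values.size),
    }

# Hypothetical stand-in for next(mask_generator): masks stored as 0/255.
fake = np.zeros((4, 8, 8, 1), dtype=np.float32)
fake[:, 2:6, 2:6, :] = 255.0
print(describe_mask_batch(fake))

# A single (H, W) mask needs batch and channel axes to match (N, H, W, 1).
single = fake[0, :, :, 0]
print(single[np.newaxis, ..., np.newaxis].shape)  # (1, 8, 8, 1)
```

If "max" comes back as 255.0 rather than 1.0, the targets are not in the range that sigmoid outputs and binary cross-entropy assume.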
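3) I also suspect the image augmentation might be responsible. I mostly used Keras's default augmentation. Here is the whole image reading and augmentation part: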
from keras.preprocessing.image import ImageDataGenerator
input_shape = (1024, 1024, 3)
batch_size = 4
# we create two instances with the same arguments
data_gen_args = dict(rotation_range=90,
                     width_shift_range=0.1,
                     height_shift_range=0.1)
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)
# Provide the same seed and keyword arguments to the fit and flow methods
seed = 1
image_generator = image_datagen.flow_from_directory(
    'Carvana/train',
    target_size=(input_shape[0], input_shape[1]),
    batch_size=batch_size,
    class_mode=None,
    seed=seed)
mask_generator = mask_datagen.flow_from_directory(
    'Carvana/train_masks',
    target_size=(input_shape[0], input_shape[1]),
    batch_size=batch_size,
    color_mode='grayscale',
    class_mode=None,
    seed=seed)
# combine generators into one which yields image and masks
train_generator = zip(image_generator, mask_generator)
model2 = unet(input_shape)
model2.fit_generator(
    train_generator,
    steps_per_epoch=50,
    epochs=2)
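If the masks do arrive with values from 0 to 255, one possible adjustment is to normalize both arrays before they reach the model. The wrapper below is a hypothetical sketch (normalized_pairs is an illustrative name, not a Keras API). ImageDataGenerator's rescale argument could handle the scaling part on its own, but it would not re-binarize mask edges that may be interpolated during rotation and shifting, which is why the sketch also thresholds the masks.

```python
import numpy as np

def normalized_pairs(pairs, threshold=0.5):
    """Wrap an iterable of (images, masks) batches.

    Images are scaled from [0, 255] to [0, 1]. Masks are scaled the same way
    and then thresholded to {0, 1}, in case augmentation left intermediate
    values along the object boundary.
    """
    for images, masks in pairs:
        images = np.asarray(images, dtype=np.float32) / 255.0
        masks = np.asarray(masks, dtype=np.float32) / 255.0
        masks = (masks > threshold).astype(np.float32)
        yield images, masks
```

Usage would be train_generator = normalized_pairs(zip(image_generator, mask_generator)), keeping the rest of the training call unchanged.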
The training output is:
Found 5088 images belonging to 1 classes.
Found 5088 images belonging to 1 classes.
Epoch 1/2
50/50 [==============================] - 66s 1s/step - loss: -724.1043 - dice_coef: 1.8661
Epoch 2/2
50/50 [==============================] - 64s 1s/step - loss: -829.2828 - dice_coef: 1.9626
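A negative binary cross-entropy has a similar explanation if y_true is allowed to exceed 1. The loss is -(y·log p + (1 - y)·log(1 - p)); for y = 255 the coefficient (1 - y) becomes -254, so the loss is no longer bounded below by 0 and keeps decreasing as p approaches 1. The NumPy sketch below (binary_crossentropy_np is an illustrative name, and its eps clip only resembles the clipping Keras applies) compares y = 1 and y = 255 for the same prediction. That the masks reach the loss as 0/255 is an assumption, not confirmed by the post.

```python
import numpy as np

def binary_crossentropy_np(y_true, y_pred, eps=1e-7):
    """Element-wise binary cross-entropy with predictions clipped to [eps, 1 - eps]."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))

for y_true in (1.0, 255.0):
    # Small positive loss for y = 1; large negative loss for y = 255.
    print(y_true, binary_crossentropy_np(y_true, 0.99))
```

With targets in {0, 1} the loss stays non-negative, so steadily falling values around -700 to -800 point to targets outside that range.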
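Finally, here is the whole network from this kaggle kernel. The only modifications I made were switching the input to channels-last and changing the Concatenate layers from axis = 1 to axis = 3.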
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
from keras.optimizers import Adam
def unet(input_shape):
    input_ = Input(input_shape)
    conv0 = Conv2D(8, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(input_)
    conv0 = Conv2D(8, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv0)
    pool0 = MaxPooling2D(pool_size=(2, 2))(conv0)
    conv1 = Conv2D(16, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool0)
    conv1 = Conv2D(16, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
    conv4 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    conv5 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
    conv5 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    up6 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv5))
    merge6 = Concatenate(axis = 3)([conv4,up6])
    conv6 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
    up7 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge7 = Concatenate(axis = 3)([conv3,up7])
    conv7 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
    up8 = Conv2D(32, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = Concatenate(axis = 3)([conv2,up8])
    conv8 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
    up9 = Conv2D(16, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = Concatenate(axis = 3)([conv1,up9])
    conv9 = Conv2D(16, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = Conv2D(16, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    up10 = Conv2D(16, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv9))
    conv10 = Conv2D(8, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(up10)
    # 1-channel sigmoid output: one foreground probability per pixel
    conv11 = Conv2D(1, 1, activation = 'sigmoid')(conv10)
    model = Model(inputs = input_, outputs = conv11)
    model.compile(optimizer = Adam(lr=0.0005), loss='binary_crossentropy', metrics=[dice_coef])
    return model
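Two shape facts sit behind the axis change described above. Five 2x2 poolings divide the 1024-pixel side by 32, so each skip connection lines up with its upsampled partner; and in channels-last layout (batch, height, width, channels) the channel axis is 3 (equivalently -1), whereas axis=1 would silently stack the feature maps along the height instead of raising an error. In this sketch, NumPy zero arrays stand in for the Keras tensors conv4 and up6.

```python
import numpy as np

# Spatial size after each of the five 2x2 max-pooling steps, for a 1024x1024 input.
size = 1024
sizes = [size]
for _ in range(5):
    size //= 2
    sizes.append(size)
print(sizes)  # [1024, 512, 256, 128, 64, 32]

# Stand-ins for conv4 (after pool3) and up6 (conv5 upsampled back to 64x64).
conv4 = np.zeros((1, 64, 64, 128), dtype=np.float32)
up6 = np.zeros((1, 64, 64, 128), dtype=np.float32)

merge6 = np.concatenate([conv4, up6], axis=3)  # channels-last: join on channels
print(merge6.shape)  # (1, 64, 64, 256)

wrong = np.concatenate([conv4, up6], axis=1)   # joins along height instead
print(wrong.shape)   # (1, 128, 64, 128)
```

Because the wrong axis does not fail immediately here, checking model.summary() for the expected channel counts after each Concatenate is a cheap sanity test.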
Finally, here is the mask output when I test the model on just the first image in the training dataset:
import cv2
import numpy as np
from skimage import io
pic = cv2.resize(io.imread('Carvana/train/train/0cdf5b5d0ce1_01.jpg'), input_shape[:2])
pic = pic.reshape(1, input_shape[0], input_shape[1], input_shape[2])
res = model2.predict(pic)
print(res[0].shape)
res = np.array(res[0])
r = res * 200
g = res * 1
b = res * 70
res = np.concatenate((r, g, b), axis = 2)
io.imshow(res)
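For inspecting a single prediction, one option is to threshold the sigmoid output and blend the resulting binary mask onto the input photo, keeping values in the uint8 range that image viewers expect (float RGB values above 1 are typically clipped on display). The helper mask_overlay, its 0.5 threshold and the alpha blend are illustrative choices, not part of the original code; the default color reuses the 200/1/70 factors above.

```python
import numpy as np

def mask_overlay(image, prob, threshold=0.5, color=(200, 1, 70), alpha=0.5):
    """Blend a colored binary mask onto an RGB uint8 image.

    image: (H, W, 3) uint8 array.
    prob:  (H, W, 1) or (H, W) sigmoid output with values in [0, 1].
    """
    prob = np.asarray(prob)
    if prob.ndim == 3:
        prob = prob[..., 0]
    mask = prob > threshold                      # (H, W) boolean foreground
    out = image.astype(np.float32).copy()
    # Mix only the foreground pixels with the overlay color.
    out[mask] = (1 - alpha) * out[mask] + alpha * np.asarray(color, dtype=np.float32)
    return out.round().astype(np.uint8)
```

With the variables above, something like io.imshow(mask_overlay(pic[0], model2.predict(pic)[0])) would show the predicted car region over the photo.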
Sorry for the long post, but I can't pin down the exact mistake I made. Any help is greatly appreciated.