python - keras segmentation InvalidArgumentError: Incompatible shapes: [32,256,256,3] vs. [32,256,256,4]
Problem description
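I am trying to train a UNET to detect road damage after natural disasters. My images are 256x256x3 and the masks are 256x256x1. The pixels correspond to 4 classes: 0 = background, 1 = undamaged road, 2 = damaged road and 3 = foliage obstructing the view. It looks like this (only 3 of the classes are visible in the image):
[example image]
I have put these images in folders with the following structure. I also have validation data with the same structure: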
-- train images
    -- img
        -- 1.png
        -- 2.png
-- train masks
    -- img
        -- 1_mask.png
        -- 2_mask.png
My training and validation data generators are as follows:
from tensorflow.keras.preprocessing import image  # assumes tf.keras (TF 2.x)

# identical augmentation settings for images and masks
image_data_gen = image.ImageDataGenerator(rotation_range = 40, horizontal_flip = True,
                                          vertical_flip = True, zoom_range = 0.2,
                                          shear_range = 0.2, width_shift_range = 0.2,
                                          height_shift_range = 0.2)
mask_data_gen = image.ImageDataGenerator(rotation_range = 40, horizontal_flip = True,
                                         vertical_flip = True, zoom_range = 0.2,
                                         shear_range = 0.2, width_shift_range = 0.2,
                                         height_shift_range = 0.2)
# no augmentation for validation data
validimg_data_gen = image.ImageDataGenerator()
validmask_data_gen = image.ImageDataGenerator()
# the shared seed keeps image and mask batches in sync;
# batch_size (32) and color_mode ('rgb') are left at their defaults
image_array_gen = image_data_gen.flow_from_directory(directory=train_images_path, class_mode = None,
                                                     target_size = (256,256), seed = 909)
mask_array_gen = mask_data_gen.flow_from_directory(directory=train_segs_path, class_mode = None,
                                                   target_size = (256,256), seed = 909)
valid_image_array_gen = validimg_data_gen.flow_from_directory(directory= val_images_path, class_mode = None,
                                                              target_size = (256,256), seed = 909)
valid_mask_array_gen = validmask_data_gen.flow_from_directory(directory= val_segs_path, class_mode = None,
                                                              target_size = (256,256), seed = 909)
# combine generators into one which yields image and masks
train_generator = zip(image_array_gen, mask_array_gen)
valid_generator = zip(valid_image_array_gen, valid_mask_array_gen)
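Because the model ends in a 4-channel softmax trained with categorical_crossentropy, each mask batch has to reach the loss as one-hot vectors of shape (batch, 256, 256, 4), while flow_from_directory yields one class id per pixel. A minimal sketch of wrapping the zipped generator to do that conversion, assuming the masks are grayscale PNGs whose pixel values are the class ids 0 to 3 (with color_mode='rgb', all three channels carry the same id, so only channel 0 is read). `one_hot_mask_batch` and `with_one_hot_masks` are hypothetical helpers, not Keras functions. The rounding is only a safeguard: geometric augmentations may interpolate between neighbouring label values, and a rounded in-between value can still land on the wrong class.

```python
import numpy as np

NUM_CLASSES = 4  # 0 background, 1 undamaged road, 2 damaged road, 3 foliage


def one_hot_mask_batch(mask_batch, num_classes=NUM_CLASSES):
    """Convert a (batch, H, W, C) batch of class-id masks to (batch, H, W, num_classes).

    Only channel 0 is read, so this accepts masks loaded as 'grayscale' (C=1)
    as well as grayscale PNGs loaded as 'rgb' (C=3, identical channels).
    """
    # Round to the nearest integer id; an id >= num_classes raises IndexError.
    class_ids = np.rint(mask_batch[..., 0]).astype(np.int64)
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(num_classes, dtype=np.float32)[class_ids]


def with_one_hot_masks(pairs, num_classes=NUM_CLASSES):
    """Hypothetical wrapper: yield (images, one-hot masks) from (images, masks) pairs."""
    for img_batch, mask_batch in pairs:
        yield img_batch, one_hot_mask_batch(mask_batch, num_classes)


# Small stand-in batch instead of real generator output.
imgs = np.zeros((2, 4, 4, 3), dtype=np.float32)
masks = np.zeros((2, 4, 4, 3), dtype=np.float32)
masks[0, 0, 0, :] = 2.0  # one pixel of class 2 ("damaged road")
x, y = next(with_one_hot_masks([(imgs, masks)]))
print(y.shape)     # (2, 4, 4, 4)
print(y[0, 0, 0])  # [0. 0. 1. 0.]
```

In the question's code this would be used as train_generator = with_one_hot_masks(zip(image_array_gen, mask_array_gen)), and likewise for the validation pair.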
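I know that, because this is a segmentation problem, class_mode should be None. When I run the cell above, the message "Found x images belonging to 1 classes." appears, where x is the number of training or validation images I have in the respective "img" subfolder. I thought the error might be that my data is in the "img" subfolder, so Keras thinks all the images belong to the class img, when in fact there are 4 pixel classes. However, if I put the data directly in the train images and train masks folders, I get the message "Found 0 images belonging to 0 classes."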
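The two messages follow from how flow_from_directory infers classes: every immediate subdirectory of `directory` is treated as one class, and image files lying directly in `directory` are ignored. The sketch below mimics that rule on throwaway folders laid out like the question's. `summarize_like_flow_from_directory` is a hypothetical helper, not a Keras function, and its extension list is an assumption. With class_mode=None no labels are yielded, so the single "img" class is just a container and is not what breaks training.

```python
import os
import tempfile

# Image extensions to count (assumed subset of what Keras accepts).
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".ppm", ".tif", ".tiff"}


def summarize_like_flow_from_directory(root):
    """Hypothetical helper mimicking flow_from_directory's class inference.

    Each immediate subdirectory of `root` is a class; files directly in `root` are skipped.
    """
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    n_images = 0
    for cls in classes:
        for _, _, filenames in os.walk(os.path.join(root, cls)):
            n_images += sum(
                os.path.splitext(f)[1].lower() in IMAGE_EXTS for f in filenames
            )
    return f"Found {n_images} images belonging to {len(classes)} classes."


def touch(path):
    """Create an empty file, including any missing parent folders."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    open(path, "wb").close()


with tempfile.TemporaryDirectory() as tmp:
    # Layout from the question: train images/img/1.png, 2.png
    nested = os.path.join(tmp, "nested", "train images")
    touch(os.path.join(nested, "img", "1.png"))
    touch(os.path.join(nested, "img", "2.png"))
    # Images placed directly in the folder passed as `directory`
    flat = os.path.join(tmp, "flat", "train images")
    touch(os.path.join(flat, "1.png"))
    touch(os.path.join(flat, "2.png"))
    nested_msg = summarize_like_flow_from_directory(nested)
    flat_msg = summarize_like_flow_from_directory(flat)

print(nested_msg)  # Found 2 images belonging to 1 classes.
print(flat_msg)    # Found 0 images belonging to 0 classes.
```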
When I try to train my model with
results = model.fit_generator(train_generator, steps_per_epoch=int(train_samples/batch_size), epochs=30, validation_data=valid_generator, validation_steps=int(valid_samples/batch_size))
I get the error:
InvalidArgumentError: Incompatible shapes: [32,256,256,3] vs. [32,256,256,4]
[[node gradient_tape/categorical_crossentropy/mul/BroadcastGradientArgs (defined at <ipython-input-120-22f04a70298f>:3) ]] [Op:__inference_train_function_19597]
Function call stack:
train_function
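The code I use to define the model is at the end of this question. Note that if I change the last layer to
conv10 = Conv2D(1, 1, activation = 'softmax')(conv9)
the error disappears and training completes without problems. This is another reason I think the problem is that my data is in the subfolder img. But how can I tell the model that there are 4 pixel classes without getting this error?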
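The shapes in the error can be traced to the generator defaults rather than to the folder layout: flow_from_directory uses batch_size=32 and color_mode='rgb' unless told otherwise, so, assuming the masks are grayscale PNGs, each mask batch arrives as (32, 256, 256, 3), while the model outputs (32, 256, 256, 4). The trace shows categorical_crossentropy multiplying y_true and y_pred element-wise (the mul node), and TensorFlow broadcasts like NumPy, so trailing dimensions 3 and 4 cannot be reconciled. With Conv2D(1, ...) the output is (32, 256, 256, 1), which does broadcast against 3 channels, which is why the error disappears; but a softmax over a single channel is 1.0 for every pixel, so that model has nothing to learn. The sketch below reproduces only the shape logic with NumPy (np.broadcast_shapes needs NumPy 1.20 or later) and random stand-in masks; it does not call Keras.

```python
import numpy as np

BATCH, H, W, NUM_CLASSES = 32, 256, 256, 4

# Random stand-in for a batch of single-channel masks holding class ids 0..3.
rng = np.random.default_rng(0)
mask_gray = rng.integers(0, NUM_CLASSES, size=(BATCH, H, W, 1), dtype=np.uint8)

# color_mode='rgb' turns a grayscale mask into 3 identical channels.
mask_rgb = np.repeat(mask_gray, 3, axis=-1)


def shapes_compatible(a, b):
    """True if NumPy-style broadcasting accepts the two shapes."""
    try:
        np.broadcast_shapes(a, b)
        return True
    except ValueError:
        return False


out_4 = (BATCH, H, W, NUM_CLASSES)  # Conv2D(4, 1, activation='softmax')
out_1 = (BATCH, H, W, 1)            # Conv2D(1, 1, activation='softmax')
print(mask_rgb.shape)                            # (32, 256, 256, 3)
print(shapes_compatible(mask_rgb.shape, out_4))  # False: the InvalidArgumentError
print(shapes_compatible(mask_rgb.shape, out_1))  # True: why the 1-channel model "works"


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


# A softmax over a single channel is 1.0 for every pixel, whatever the logits.
print(np.unique(softmax(rng.normal(size=(2, 4, 4, 1)))))  # [1.]
```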
Model
from tensorflow.keras.layers import Conv2D, Dropout, Input, MaxPooling2D, UpSampling2D, concatenate
from tensorflow.keras.models import Model

def unet(pretrained_weights = None,input_size = (256,256,3)):
    inputs = Input(input_size)
    # contracting path
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
    # bottleneck
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    drop5 = Dropout(0.5)(conv5)
    # expanding path with skip connections
    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
    merge6 = concatenate([drop4,up6], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    # 1x1 conv: one softmax channel per class, output shape (batch, 256, 256, 4)
    conv10 = Conv2D(4, 1, activation = 'softmax')(conv9)
    model = Model(inputs = inputs, outputs = conv10)
    return(model)

model = unet()
model.compile(optimizer='adam', loss='categorical_crossentropy' ,metrics=['categorical_accuracy'])
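An alternative to one-hot encoding the masks is to keep one integer class id per pixel and change the loss. Keras provides 'sparse_categorical_crossentropy' (with the matching metric 'sparse_categorical_accuracy'), which takes integer targets, so in recent tf.keras versions grayscale masks of shape (batch, 256, 256, 1), loaded with color_mode='grayscale' on the mask generators, should pair with the 4-channel softmax output via model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']). The NumPy sketch below uses plain re-implementations (not the Keras functions) to show why the two losses agree: with a one-hot target, the categorical sum over classes keeps only the log-probability of the true class, which the sparse form indexes directly.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def categorical_ce(y_true_onehot, y_pred):
    """Mean over pixels of -sum_c y_true[c] * log(y_pred[c]); both (batch, H, W, C)."""
    return float(np.mean(-np.sum(y_true_onehot * np.log(y_pred), axis=-1)))


def sparse_categorical_ce(y_true_ids, y_pred):
    """Mean over pixels of -log(y_pred[true class]); y_true_ids is (batch, H, W, 1) ints."""
    true_class_prob = np.take_along_axis(y_pred, y_true_ids, axis=-1)[..., 0]
    return float(np.mean(-np.log(true_class_prob)))


rng = np.random.default_rng(0)
y_pred = softmax(rng.normal(size=(2, 8, 8, 4)))  # stand-in 4-class model output
ids = rng.integers(0, 4, size=(2, 8, 8, 1))      # grayscale-style integer mask
onehot = np.eye(4)[ids[..., 0]]                  # same mask, one-hot encoded

cat_loss = categorical_ce(onehot, y_pred)
sparse_loss = sparse_categorical_ce(ids, y_pred)
print(np.isclose(cat_loss, sparse_loss))  # True
```

The sparse variant avoids materialising a 4-channel target for every pixel, at the cost of requiring clean integer ids in the mask files.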
Solution
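Your train_images_path should be the path to the directory that contains the class subdirectories. The same goes for the other directories you pass to flow_from_directory. For example, if you have the following directory structure for classifying dogs and cats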
c:
---train_images
   ---img
      --- cats
      --- dogs
then in flow_from_directory your path would be
r'c:\train_images\img'
and flow_from_directory will create the classes cats and dogs.