Keypoint detection implementation

Problem description

I'm working on keypoint detection, specifically locating a cat's left eye, right eye, and mouth.

For example, I have a 64x64 image of a cat.


I also created a version of that image with circles drawn on the cat's left eye, right eye, and mouth.

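Here the training target is a copy of the RGB photo with circles drawn on it, so the network has to reproduce the whole photo, not just the three locations. A common alternative in keypoint detection (not what the question uses) is one heatmap channel per keypoint, with a Gaussian peak at that keypoint's position; with three keypoints that is still a 64x64x3 target, and the coordinates can be read back with a per-channel argmax. The sketch below assumes you have (x, y) pixel coordinates for each keypoint; `make_heatmaps`, `heatmaps_to_keypoints`, `sigma`, and the example coordinates are all hypothetical.

```python
import numpy as np


def make_heatmaps(keypoints, height=64, width=64, sigma=2.0):
    """Encode K keypoints as a (height, width, K) array, one Gaussian peak per channel.

    keypoints: sequence of (x, y) pixel coordinates, e.g. [left_eye, right_eye, mouth].
    sigma: hypothetical spread of each peak, in pixels.
    """
    ys, xs = np.mgrid[0:height, 0:width]  # row-index and column-index grids
    maps = np.zeros((height, width, len(keypoints)), dtype=np.float32)
    for k, (x, y) in enumerate(keypoints):
        # Value 1.0 exactly at the keypoint, decaying with squared distance from it
        maps[..., k] = np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2.0 * sigma ** 2))
    return maps


def heatmaps_to_keypoints(maps):
    """Decode each channel back to the (x, y) position of its strongest response."""
    h, w, k = maps.shape
    flat_idx = maps.reshape(h * w, k).argmax(axis=0)  # flat argmax per channel
    ys, xs = np.unravel_index(flat_idx, (h, w))       # flat index -> (row, col)
    return list(zip(xs.tolist(), ys.tolist()))


# Hypothetical coordinates for left eye, right eye and mouth on a 64x64 image
keypoints = [(20, 25), (44, 25), (32, 45)]
targets = make_heatmaps(keypoints)
print(targets.shape)                   # (64, 64, 3)
print(heatmaps_to_keypoints(targets))  # [(20, 25), (44, 25), (32, 45)]
```

At inference time the same decoding would be applied to the model's predicted heatmaps; a prediction only needs its maximum in the right place, not a pixel-perfect copy of the target.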

I created ImageDataGenerators for the plain images:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(
    featurewise_center=False, 
    featurewise_std_normalization=False, 
    width_shift_range=0.1, 
    height_shift_range=0.1, 
    zoom_range=0.2)

train_generator = train_datagen.flow_from_directory(
    train_path,
    target_size=(64, 64),
    batch_size=32,
    seed=1,
    subset='training')

validation_generator = train_datagen.flow_from_directory(
    val_path,
    target_size=(64, 64),
    batch_size=32,
    seed=1,
    subset='validation')

And the same for the images with the keypoints marked:

heatmap_train_datagen = ImageDataGenerator(
    featurewise_center=False, 
    featurewise_std_normalization=False, 
    width_shift_range=0.1, 
    height_shift_range=0.1, 
    zoom_range=0.2)

heatmap_train_generator = heatmap_train_datagen.flow_from_directory(
    heatmap_train_path,
    target_size=(img_height, img_width),
    batch_size=32,
    seed=1,
    subset='training')

heatmap_validation_generator = heatmap_train_datagen.flow_from_directory(
    heatmap_val_path,
    target_size=(img_height, img_width),
    batch_size=32,
    seed=1,
    subset='validation')

Then I zipped them together:

zipped_train_generator = zip(train_generator, heatmap_train_generator)
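By default `flow_from_directory` uses `class_mode='categorical'`, so each generator yields `(images, labels)` batches and this `zip` produces `((images, labels), (heatmaps, labels))` rather than the `(inputs, targets)` pair that `model.fit` expects. The Keras `ImageDataGenerator` documentation pairs images with masks by passing `class_mode=None` and the same `seed` to both generators. If the labels are kept, a small adapter can drop them; the sketch below is one way to do that, with a hypothetical `image_heatmap_pairs` helper exercised on stand-in generators rather than real `DirectoryIterator`s.

```python
import numpy as np


def image_heatmap_pairs(image_gen, heatmap_gen):
    """Yield (images, heatmaps) batches from two generators of (batch, labels) tuples.

    The class labels attached by flow_from_directory's default class_mode are
    discarded, so the model receives input images and target heatmaps only.
    """
    for (images, _), (heatmaps, _) in zip(image_gen, heatmap_gen):
        yield images, heatmaps


def fake_directory_iterator(num_batches, fill):
    """Hypothetical stand-in for a DirectoryIterator: yields (images, one-hot labels)."""
    for i in range(num_batches):
        yield np.full((2, 64, 64, 3), fill + i, dtype=np.float32), np.eye(2)


pairs = list(image_heatmap_pairs(fake_directory_iterator(3, 0.0),
                                 fake_directory_iterator(3, 10.0)))
print(len(pairs), pairs[0][0].shape, pairs[0][1].shape)  # 3 (2, 64, 64, 3) (2, 64, 64, 3)
```

Unlike these finite stand-ins, real `DirectoryIterator`s never run out, so the adapter would be passed to `model.fit` together with an explicit `steps_per_epoch`.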

This is my model:

Layer (type)                 Output Shape              Param #    
input_1 (InputLayer)         [(None, 64, 64, 3)]       0          
block1_conv1 (Conv2D)        (None, 64, 64, 64)        1792       
block1_conv2 (Conv2D)        (None, 64, 64, 64)        36928      
block1_pool (MaxPooling2D)   (None, 32, 32, 64)        0          
block2_conv1 (Conv2D)        (None, 32, 32, 128)       73856     
block2_conv2 (Conv2D)        (None, 32, 32, 128)       147584     
bottleneck_1 (Conv2D)        (None, 32, 32, 160)       5243040    
bottleneck_2 (Conv2D)        (None, 32, 32, 160)       25760      
upsample_1 (Conv2DTranspose) (None, 64, 64, 3)         1920       
Total params: 5,530,880  
Trainable params: 5,530,880  
Non-trainable params: 0
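The summary is internally consistent, and the parameter counts reveal two things worth checking. Solving k·k·128·160 + 160 = 5,243,040 gives k = 16, so `bottleneck_1` appears to use a 16x16 kernel and holds roughly 95% of all weights; and 1,920 = 2·2·160·3 suggests `upsample_1` is a 2x2 transposed convolution without a bias. Both kernel sizes are inferred from the counts, not stated in the question. The arithmetic, as a small sketch with a hypothetical `conv_params` helper:

```python
def conv_params(kernel, in_ch, out_ch, bias=True):
    """Parameters of a square Conv2D / Conv2DTranspose layer: k*k*in*out weights (+ out biases)."""
    return kernel * kernel * in_ch * out_ch + (out_ch if bias else 0)


layers = {
    "block1_conv1": conv_params(3, 3, 64),
    "block1_conv2": conv_params(3, 64, 64),
    "block2_conv1": conv_params(3, 64, 128),
    "block2_conv2": conv_params(3, 128, 128),
    "bottleneck_1": conv_params(16, 128, 160),        # kernel size inferred from 5,243,040
    "bottleneck_2": conv_params(1, 160, 160),         # kernel size inferred from 25,760
    "upsample_1": conv_params(2, 160, 3, bias=False),  # inferred from 1,920 (no bias)
}
total = sum(layers.values())
print(total)                                     # 5530880
print(f"{layers['bottleneck_1'] / total:.1%}")   # 94.8%
```

The final layer outputs 64x64x3, so the network is trained to regress a full RGB image, matching the circle-marked target images.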

And I train it like this:

history = model.fit((pair for pair in zipped_train_generator),
                    epochs=30,
                    validation_data=(validation_generator, heatmap_validation_generator))

It has been running for over an hour and never seems to finish.
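One likely reason the run never finishes: the iterator returned by `flow_from_directory` loops over the data indefinitely, and a `zip` or generator expression over it has no length, so unless `steps_per_epoch` is passed, Keras keeps drawing batches and the first epoch never ends. The pure-Python sketch below imitates that behaviour with a hypothetical `looping_batches` stand-in and shows how a step count (what `len()` of a Keras `DirectoryIterator` reports: batches per epoch, rounded up) bounds one epoch.

```python
import itertools
import math


def looping_batches(num_samples, batch_size):
    """Hypothetical stand-in for a Keras DirectoryIterator: cycles over the data forever."""
    while True:
        for start in range(0, num_samples, batch_size):
            yield list(range(start, min(start + batch_size, num_samples)))


num_samples, batch_size = 100, 32
# Batches per epoch, rounded up (the last batch is partial)
steps_per_epoch = math.ceil(num_samples / batch_size)

pairs = zip(looping_batches(num_samples, batch_size),
            looping_batches(num_samples, batch_size))
# Iterating over `pairs` without a limit would never stop; islice takes one epoch's worth.
one_epoch = list(itertools.islice(pairs, steps_per_epoch))
print(steps_per_epoch, len(one_epoch))  # 4 4
print(one_epoch[-1][0])                 # [96, 97, 98, 99]
```

For the question's call, that would mean passing something like `steps_per_epoch=len(train_generator)`. The validation side has a related problem: Keras documents `validation_data` as arrays, a dataset, or a generator returning `(inputs, targets)` (with `validation_steps` for an endless generator), not a tuple of two generators. Separately, the Keras docs say `subset` only takes effect when `validation_split` is set on the `ImageDataGenerator`, which the generators above do not do.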

Is this the right way to do keypoint detection? If not, where did I go wrong, and how should I implement it?

Tags: tensorflow, opencv, machine-learning, keras, computer-vision
