
TuringEmmy 2020-03-23 21:24

Implementing MobileNet V1 in TensorFlow and Training It on CIFAR-10

Paper: MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications


Depthwise separable convolution

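A depthwise separable convolution factorizes a standard convolution into two steps. The first is a depthwise convolution, which applies one $D_k \times D_k$ filter to each input channel. The second is a pointwise convolution, a $1 \times 1$ convolution that combines the channels. In the formulas below:

  • $D_k$ is the kernel size.
  • $D_f$ is the width and height of the (square) feature map.
  • $M$ is the number of input channels.
  • $N$ is the number of output channels.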
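The factorization can be checked numerically. Below is a minimal NumPy sketch; it assumes stride 1, no padding and no batch dimension, and all helper names are hypothetical.

In the sketch, the depthwise step filters each of the $M$ channels separately, and the pointwise step mixes the $M$ channels into $N$ outputs. Composing the two steps gives the same result as a standard convolution whose kernel is constrained to the product `d[a, b, m] * p[m, n]`. That constraint is where the savings come from: the standard kernel has $D_k^2 M N$ free weights, the factored one only $D_k^2 M + M N$.

```python
import numpy as np


def conv2d_valid(x, k):
    """Standard convolution (cross-correlation), stride 1, no padding.
    x: (H, W, M) input, k: (Dk, Dk, M, N) kernel -> (H-Dk+1, W-Dk+1, N)."""
    dk = k.shape[0]
    h, w = x.shape[0] - dk + 1, x.shape[1] - dk + 1
    out = np.zeros((h, w, k.shape[3]))
    for i in range(h):
        for j in range(w):
            patch = x[i:i + dk, j:j + dk, :]            # (Dk, Dk, M)
            out[i, j] = np.tensordot(patch, k, axes=3)  # sum over Dk, Dk and M -> (N,)
    return out


def depthwise_valid(x, d):
    """Depthwise convolution: one Dk x Dk filter per channel, d: (Dk, Dk, M)."""
    dk = d.shape[0]
    h, w = x.shape[0] - dk + 1, x.shape[1] - dk + 1
    out = np.zeros((h, w, x.shape[2]))
    for i in range(h):
        for j in range(w):
            patch = x[i:i + dk, j:j + dk, :]
            out[i, j] = (patch * d).sum(axis=(0, 1))  # no mixing across channels
    return out


def pointwise(x, p):
    """1x1 convolution: p (M, N) mixes channels at every spatial position."""
    return x @ p


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 6, 4))  # 6x6 feature map, M = 4 channels
d = rng.standard_normal((3, 3, 4))  # depthwise filters, Dk = 3
p = rng.standard_normal((4, 8))     # pointwise weights, N = 8

separable = pointwise(depthwise_valid(x, d), p)
k = d[:, :, :, None] * p[None, None, :, :]  # equivalent standard kernel, shape (3, 3, 4, 8)
standard = conv2d_valid(x, k)
print(separable.shape, np.allclose(separable, standard))  # (4, 4, 8) True
```

The loops favor clarity over speed. In the Keras model of this article, the same two steps are a `DepthwiseConv2D` layer followed by a `Conv2D` with `kernel_size=(1, 1)`.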

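Computational cost (multiply-adds) of a standard convolution:

$$D_k^2 \cdot M \cdot D_f^2 \cdot N$$

Computational cost of a depthwise separable convolution (depthwise term plus pointwise term):

$$D_k^2 \cdot M \cdot D_f^2 + D_f^2 \cdot M \cdot N$$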
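Dividing the two costs gives the reduction factor $\frac{1}{N} + \frac{1}{D_k^2}$ stated in the paper. With $3 \times 3$ kernels and a large $N$, the separable version therefore needs roughly 8 to 9 times fewer multiply-adds.

The sketch below checks this for a single layer. The sizes ($D_k = 3$, $M = N = 512$, $D_f = 14$) are an illustrative choice and are not taken from the code in this article.

```python
def standard_conv_cost(d_k, m, n, d_f):
    """Multiply-adds of a standard D_k x D_k convolution: D_k^2 * M * D_f^2 * N."""
    return d_k ** 2 * m * d_f ** 2 * n


def separable_conv_cost(d_k, m, n, d_f):
    """Depthwise term D_k^2 * M * D_f^2 plus pointwise term D_f^2 * M * N."""
    return d_k ** 2 * m * d_f ** 2 + d_f ** 2 * m * n


# Illustrative layer: 3x3 kernel, 512 -> 512 channels, 14x14 feature map
d_k, m, n, d_f = 3, 512, 512, 14
std = standard_conv_cost(d_k, m, n, d_f)
sep = separable_conv_cost(d_k, m, n, d_f)
ratio = sep / std
print(std, sep, round(ratio, 4))  # 462422016 52283392 0.1131

# The ratio simplifies to 1/N + 1/D_k^2
assert abs(ratio - (1 / n + 1 / d_k ** 2)) < 1e-12
```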


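  • Lightweight convolutional neural network
  • Far fewer parameters and much less computation, while still reaching competitive accuracy
  • Built on depthwise separable convolutions, which factorize along the channel dimension. These are not the same as spatially separable convolutions, which split a k × k kernel into k × 1 and 1 × k.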
Model construction
import tensorflow as tf


def conv_block(
        inputs,
        filters,
        kernel_size=(3, 3),
        strides=(1, 1)
):
    """Standard convolution -> BatchNormalization -> ReLU6."""
    x = tf.keras.layers.Conv2D(filters, kernel_size=kernel_size, strides=strides, padding='same', use_bias=False)(
        inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    return tf.keras.layers.ReLU(6.0)(x)  # ReLU capped at 6 (ReLU6)


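def depthwise_conv_block(
        inputs,
        pointwise_conv_filters,
        strides=(1, 1)
):
    """Depthwise separable block: 3x3 depthwise -> BN -> ReLU6, then 1x1 pointwise -> BN -> ReLU6."""
    # Depthwise step: one 3x3 filter per input channel; a stride of 2 downsamples here
    x = tf.keras.layers.DepthwiseConv2D((3, 3), padding='same', strides=strides, use_bias=False)(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU(6.0)(x)

    # Pointwise step: a 1x1 convolution mixes channels and sets the output width
    x = tf.keras.layers.Conv2D(pointwise_conv_filters, kernel_size=(1, 1), padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)

    return tf.keras.layers.ReLU(6.0)(x)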


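def mobilenet_v1(
        inputs,
        classes
):
    """Stem convolution, 13 depthwise separable blocks, global average pooling and a softmax classifier."""
    # Each stride-2 layer halves the resolution: 32 -> 16 -> 8 -> 4 -> 2 -> 1 for CIFAR-10 inputs
    x = conv_block(inputs, 32, strides=(2, 2))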
    x = depthwise_conv_block(x, 64)
    x = depthwise_conv_block(x, 128, strides=(2, 2))
    x = depthwise_conv_block(x, 128)
    x = depthwise_conv_block(x, 256, strides=(2, 2))
    x = depthwise_conv_block(x, 256)
    x = depthwise_conv_block(x, 512, strides=(2, 2))
    x = depthwise_conv_block(x, 512)
    x = depthwise_conv_block(x, 512)
    x = depthwise_conv_block(x, 512)
    x = depthwise_conv_block(x, 512)
    x = depthwise_conv_block(x, 512)
    x = depthwise_conv_block(x, 1024, strides=(2, 2))
    x = depthwise_conv_block(x, 1024)

    x = tf.keras.layers.GlobalAveragePooling2D()(x)

    x = tf.keras.layers.Dense(classes, activation='softmax')(x)

    return x


inputs = tf.keras.Input(shape=(32, 32, 3))
model = tf.keras.Model(inputs=inputs, outputs=mobilenet_v1(inputs, 10))
model.summary()
Parameter summary
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 16, 16, 32)        864       
_________________________________________________________________
batch_normalization_27 (Batc (None, 16, 16, 32)        128       
_________________________________________________________________
re_lu_27 (ReLU)              (None, 16, 16, 32)        0         
_________________________________________________________________
depthwise_conv2d_13 (Depthwi (None, 16, 16, 32)        288       
_________________________________________________________________
batch_normalization_28 (Batc (None, 16, 16, 32)        128       
_________________________________________________________________
re_lu_28 (ReLU)              (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 16, 16, 64)        2048      
_________________________________________________________________
batch_normalization_29 (Batc (None, 16, 16, 64)        256       
_________________________________________________________________
re_lu_29 (ReLU)              (None, 16, 16, 64)        0         
_________________________________________________________________
depthwise_conv2d_14 (Depthwi (None, 8, 8, 64)          576       
_________________________________________________________________
batch_normalization_30 (Batc (None, 8, 8, 64)          256       
_________________________________________________________________
re_lu_30 (ReLU)              (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 8, 8, 128)         8192      
_________________________________________________________________
batch_normalization_31 (Batc (None, 8, 8, 128)         512       
_________________________________________________________________
re_lu_31 (ReLU)              (None, 8, 8, 128)         0         
_________________________________________________________________
depthwise_conv2d_15 (Depthwi (None, 8, 8, 128)         1152      
_________________________________________________________________
batch_normalization_32 (Batc (None, 8, 8, 128)         512       
_________________________________________________________________
re_lu_32 (ReLU)              (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 8, 8, 128)         16384     
_________________________________________________________________
batch_normalization_33 (Batc (None, 8, 8, 128)         512       
_________________________________________________________________
re_lu_33 (ReLU)              (None, 8, 8, 128)         0         
_________________________________________________________________
depthwise_conv2d_16 (Depthwi (None, 4, 4, 128)         1152      
_________________________________________________________________
batch_normalization_34 (Batc (None, 4, 4, 128)         512       
_________________________________________________________________
re_lu_34 (ReLU)              (None, 4, 4, 128)         0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 4, 4, 256)         32768     
_________________________________________________________________
batch_normalization_35 (Batc (None, 4, 4, 256)         1024      
_________________________________________________________________
re_lu_35 (ReLU)              (None, 4, 4, 256)         0         
_________________________________________________________________
depthwise_conv2d_17 (Depthwi (None, 4, 4, 256)         2304      
_________________________________________________________________
batch_normalization_36 (Batc (None, 4, 4, 256)         1024      
_________________________________________________________________
re_lu_36 (ReLU)              (None, 4, 4, 256)         0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 4, 4, 256)         65536     
_________________________________________________________________
batch_normalization_37 (Batc (None, 4, 4, 256)         1024      
_________________________________________________________________
re_lu_37 (ReLU)              (None, 4, 4, 256)         0         
_________________________________________________________________
depthwise_conv2d_18 (Depthwi (None, 2, 2, 256)         2304      
_________________________________________________________________
batch_normalization_38 (Batc (None, 2, 2, 256)         1024      
_________________________________________________________________
re_lu_38 (ReLU)              (None, 2, 2, 256)         0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 2, 2, 512)         131072    
_________________________________________________________________
batch_normalization_39 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_39 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
depthwise_conv2d_19 (Depthwi (None, 2, 2, 512)         4608      
_________________________________________________________________
batch_normalization_40 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_40 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 2, 2, 512)         262144    
_________________________________________________________________
batch_normalization_41 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_41 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
depthwise_conv2d_20 (Depthwi (None, 2, 2, 512)         4608      
_________________________________________________________________
batch_normalization_42 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_42 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 2, 2, 512)         262144    
_________________________________________________________________
batch_normalization_43 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_43 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
depthwise_conv2d_21 (Depthwi (None, 2, 2, 512)         4608      
_________________________________________________________________
batch_normalization_44 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_44 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 2, 2, 512)         262144    
_________________________________________________________________
batch_normalization_45 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_45 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
depthwise_conv2d_22 (Depthwi (None, 2, 2, 512)         4608      
_________________________________________________________________
batch_normalization_46 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_46 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 2, 2, 512)         262144    
_________________________________________________________________
batch_normalization_47 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_47 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
depthwise_conv2d_23 (Depthwi (None, 2, 2, 512)         4608      
_________________________________________________________________
batch_normalization_48 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_48 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
conv2d_25 (Conv2D)           (None, 2, 2, 512)         262144    
_________________________________________________________________
batch_normalization_49 (Batc (None, 2, 2, 512)         2048      
_________________________________________________________________
re_lu_49 (ReLU)              (None, 2, 2, 512)         0         
_________________________________________________________________
depthwise_conv2d_24 (Depthwi (None, 1, 1, 512)         4608      
_________________________________________________________________
batch_normalization_50 (Batc (None, 1, 1, 512)         2048      
_________________________________________________________________
re_lu_50 (ReLU)              (None, 1, 1, 512)         0         
_________________________________________________________________
conv2d_26 (Conv2D)           (None, 1, 1, 1024)        524288    
_________________________________________________________________
batch_normalization_51 (Batc (None, 1, 1, 1024)        4096      
_________________________________________________________________
re_lu_51 (ReLU)              (None, 1, 1, 1024)        0         
_________________________________________________________________
depthwise_conv2d_25 (Depthwi (None, 1, 1, 1024)        9216      
_________________________________________________________________
batch_normalization_52 (Batc (None, 1, 1, 1024)        4096      
_________________________________________________________________
re_lu_52 (ReLU)              (None, 1, 1, 1024)        0         
_________________________________________________________________
conv2d_27 (Conv2D)           (None, 1, 1, 1024)        1048576   
_________________________________________________________________
batch_normalization_53 (Batc (None, 1, 1, 1024)        4096      
_________________________________________________________________
re_lu_53 (ReLU)              (None, 1, 1, 1024)        0         
_________________________________________________________________
global_average_pooling2d_1 ( (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                10250     
=================================================================
Total params: 3,239,114
Trainable params: 3,217,226
Non-trainable params: 21,888
_________________________________________________________________
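The parameter totals can be reproduced by hand, which is a useful check that the architecture is wired as intended. The pure-Python sketch below makes these assumptions:

  • The layer configuration is that of `mobilenet_v1()`: 3×3 kernels, `use_bias=False` in every convolution, and a biased `Dense` head.
  • Each `BatchNormalization` stores four vectors per channel: trainable gamma and beta, plus non-trainable moving mean and moving variance.

The hypothetical flag `include_stem_bn` toggles whether the BatchNormalization after the first standard convolution is counted. That layer contributes 128 parameters, 64 of them non-trainable.

```python
# Output channels of the 13 depthwise separable blocks, in the order used by mobilenet_v1()
BLOCK_FILTERS = [64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024]


def bn_params(channels):
    """(total, non-trainable) parameters of one BatchNormalization layer."""
    return 4 * channels, 2 * channels  # gamma, beta | moving_mean, moving_variance


def count_params(classes=10, stem_filters=32, include_stem_bn=True):
    """Return (total, trainable, non_trainable) for the model built by mobilenet_v1()."""
    total = 3 * 3 * 3 * stem_filters  # stem 3x3 Conv2D on 3 input channels, no bias
    frozen = 0
    if include_stem_bn:
        t, f = bn_params(stem_filters)
        total, frozen = total + t, frozen + f
    c_in = stem_filters
    for c_out in BLOCK_FILTERS:
        for weights, bn_channels in ((3 * 3 * c_in, c_in),    # depthwise 3x3 + BN
                                     (c_in * c_out, c_out)):  # pointwise 1x1 + BN
            t, f = bn_params(bn_channels)
            total += weights + t
            frozen += f
        c_in = c_out
    total += c_in * classes + classes  # Dense kernel + bias
    return total, total - frozen, frozen


print(count_params(include_stem_bn=False))  # (3238986, 3217162, 21824)
print(count_params(include_stem_bn=True))   # (3239114, 3217226, 21888)
```

Strides are deliberately absent: they change output shapes but not parameter counts. That is why the stride-2 and stride-1 blocks with the same channel widths show identical `Param #` values in the summary.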
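Data preparation
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0  # scale pixels to [0, 1]
x_test = x_test.astype('float32') / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)  # one-hot labels for categorical_crossentropy
y_test = tf.keras.utils.to_categorical(y_test, 10)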
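For reference, the two preprocessing steps amount to a dtype cast plus a division, and a row lookup in an identity matrix. The NumPy sketch below (the function name `preprocess` is hypothetical) is intended to mirror `x / 255.0` and `tf.keras.utils.to_categorical(y, 10)`. It assumes arrays shaped like CIFAR-10, where `load_data()` returns labels of shape `(n, 1)`, and it runs on a small synthetic batch rather than the real dataset.

```python
import numpy as np


def preprocess(x, y, num_classes=10):
    """Scale uint8 pixels to [0, 1] and one-hot encode integer labels of shape (n, 1)."""
    x = x.astype(np.float32) / 255.0
    y = np.eye(num_classes, dtype=np.float32)[y.reshape(-1)]  # row y[i] of the identity matrix
    return x, y


# Synthetic batch with CIFAR-10 shapes: 4 white 32x32 RGB images, labels shaped (4, 1)
x = np.full((4, 32, 32, 3), 255, dtype=np.uint8)
y = np.array([[3], [0], [9], [3]])
xs, ys = preprocess(x, y)
print(xs.max(), ys.shape, ys[0].argmax())  # 1.0 (4, 10) 3
```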
Model training
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['categorical_accuracy', 'Recall', 'AUC']
)

model.fit(x_train, y_train, batch_size=10, epochs=10)
model.save('mobilenet_v1_cifar10.h5')
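The one-hot labels are required by `loss='categorical_crossentropy'`, which computes $-\sum_c y_c \log p_c$ per sample. Keras also offers `'sparse_categorical_crossentropy'`, which takes integer labels directly, so `to_categorical` could be skipped; the metrics passed to `compile` would then have to be adjusted as well.

The simplified NumPy sketch below is not Keras's actual implementation, which differs in details such as epsilon handling. It shows that both forms give the same loss on the same predictions.

```python
import numpy as np


def categorical_crossentropy(y_onehot, probs, eps=1e-7):
    """Batch mean of -sum_c y_c * log(p_c), with one-hot labels."""
    probs = np.clip(probs, eps, 1.0)  # avoid log(0)
    return float(np.mean(-np.sum(y_onehot * np.log(probs), axis=1)))


def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Same loss, with integer class ids instead of one-hot vectors."""
    probs = np.clip(probs, eps, 1.0)
    return float(np.mean(-np.log(probs[np.arange(len(labels)), labels])))


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])  # softmax outputs: 2 samples, 3 classes
labels = np.array([0, 2])
onehot = np.eye(3)[labels]
print(categorical_crossentropy(onehot, probs), sparse_categorical_crossentropy(labels, probs))
```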
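Model inference
import cv2
import numpy as np
import tensorflow as tf

model = tf.keras.models.load_model('mobilenet_v1_cifar10.h5')  # the model saved after training

img = cv2.imread('cat.png', cv2.IMREAD_COLOR)  # uint8, BGR channel order, any size
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # CIFAR-10 training images are RGB
img = cv2.resize(img, (32, 32))  # the model expects 32x32 inputs
img = img.astype(np.float32) / 255.0  # same scaling as the training data
img = np.expand_dims(img, 0)  # add the batch dimension: (1, 32, 32, 3)
pred = model.predict(img)

print(pred, np.argmax(pred, axis=1))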
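`model.predict` returns a `(1, 10)` array of softmax probabilities whose column order follows the CIFAR-10 label order. A small helper can turn that array into readable class names; the helper name `top_k` is hypothetical. The probabilities below are made up for illustration and are not output from the trained model.

```python
import numpy as np

# CIFAR-10 label order (index 0..9) used by tf.keras.datasets.cifar10
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']


def top_k(pred, k=3):
    """Return the k most likely (class name, probability) pairs for a (1, 10) prediction."""
    probs = np.asarray(pred)[0]
    order = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(CIFAR10_CLASSES[i], float(probs[i])) for i in order]


fake_pred = np.array([[0.01, 0.02, 0.05, 0.60, 0.02, 0.20, 0.04, 0.03, 0.02, 0.01]])
print(top_k(fake_pred))  # [('cat', 0.6), ('dog', 0.2), ('bird', 0.05)]
```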

Summary:

  1. Overview and implementation of the MobileNet V1 architecture
  2. Model training and prediction

Exercise

Using the material covered here, independently complete the four objectives mentioned in this chapter.
