首页 > 解决方案 > 我们能够修剪预训练的模型吗?示例:MobileNetV2

问题描述

我正在尝试修剪一个预先训练的模型:MobileNetV2,我得到了这个错误。尝试在网上搜索,无法理解。我在Google Colab上运行。

这些是我的进口。

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tensorflow_datasets as tfds
from tensorflow import keras

import os
import numpy as np
import matplotlib.pyplot as plt
import tempfile
import zipfile

这是我的代码。

model_1 = keras.Sequential([
    basemodel,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(1)                            
])

model_1.compile(optimizer='adam',
                loss=keras.losses.BinaryCrossentropy(from_logits=True),
                metrics=['accuracy'])

model_1.fit(train_batches,
            epochs=5,
            validation_data=valid_batches)

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}


model_2 = prune_low_magnitude(model_1, **pruning_params)

model_2.compile(optmizer='adam',
                loss=keres.losses.BinaryCrossentropy(from_logits=True),
                metrics=['accuracy'])

这是我得到的错误。

---> 12 model_2 = prune_low_magnitude(model, **pruning_params)

ValueError: Please initialize `Prune` with a supported layer. Layers should either be a `PrunableLayer` instance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.engine.training.Model'>

标签: pythontensorflowkeraspruning

解决方案


我发现的一件事是我添加到模型中的实验性预处理引发了这个错误。我在我的模型开始时有这个来帮助添加更多的训练样本,但是 keras 修剪代码不喜欢这样的子类模型。同样,代码不喜欢像我对图像居中那样的实验性预处理。从模型中删除预处理为我解决了这个问题。

def classificationModel(trainImgs, testImgs):
  L2_lambda = 0.01
  data_augmentation = tf.keras.Sequential(
  [ layers.experimental.preprocessing.RandomFlip("horizontal", input_shape=IM_DIMS),
    layers.experimental.preprocessing.RandomRotation(0.1),
    layers.experimental.preprocessing.RandomZoom(0.1),])

  model = tf.keras.Sequential()
  model.add(data_augmentation)
  model.add(layers.experimental.preprocessing.Rescaling(1./255, input_shape=IM_DIMS))
...

推荐阅读