python - 我们能够修剪预训练的模型吗?示例:MobileNetV2
问题描述
我正在尝试修剪一个预先训练的模型:MobileNetV2,我得到了这个错误。尝试在网上搜索,无法理解。我在Google Colab上运行。
这些是我的进口。
import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tensorflow_datasets as tfds
from tensorflow import keras
import os
import numpy as np
import matplotlib.pyplot as plt
import tempfile
import zipfile
这是我的代码。
model_1 = keras.Sequential([
basemodel,
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(1)
])
model_1.compile(optimizer='adam',
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
model_1.fit(train_batches,
epochs=5,
validation_data=valid_batches)
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
final_sparsity=0.80,
begin_step=0,
end_step=end_step)
}
model_2 = prune_low_magnitude(model_1, **pruning_params)
model_2.compile(optmizer='adam',
loss=keres.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
这是我得到的错误。
---> 12 model_2 = prune_low_magnitude(model, **pruning_params)
ValueError: Please initialize `Prune` with a supported layer. Layers should either be a `PrunableLayer` instance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.engine.training.Model'>
解决方案
我发现的一件事是我添加到模型中的实验性预处理引发了这个错误。我在我的模型开始时有这个来帮助添加更多的训练样本,但是 keras 修剪代码不喜欢这样的子类模型。同样,代码不喜欢像我对图像居中那样的实验性预处理。从模型中删除预处理为我解决了这个问题。
def classificationModel(trainImgs, testImgs):
L2_lambda = 0.01
data_augmentation = tf.keras.Sequential(
[ layers.experimental.preprocessing.RandomFlip("horizontal", input_shape=IM_DIMS),
layers.experimental.preprocessing.RandomRotation(0.1),
layers.experimental.preprocessing.RandomZoom(0.1),])
model = tf.keras.Sequential()
model.add(data_augmentation)
model.add(layers.experimental.preprocessing.Rescaling(1./255, input_shape=IM_DIMS))
...
推荐阅读
- amazon-web-services - 缓冲传入数据并放入 S3
- azure - 使用自定义脚本扩展从 Azure Key Vault 下载证书
- git - 寻求一种方法来克隆 github 项目的问题
- python - 无法使用 for 循环在 Vader Lexicon 中添加新单词。它可以在没有循环的情况下完美运行。我该如何解决这个问题?
- spring-data-mongodb - 何时使用 MongoOperations 和 MongoTemplate?
- ruby - 使用 youtube api,如何在 ruby 中上传私人视频?
- java - Kafka Stream 和 KGlobalTable Join 问题
- oracle - 在 Oracle 11g 中使用 ORACLE“FOR UPDATE SKIP LOCKED”选择非锁定行
- r - 为什么我们在向量中省略 NULL 值而不是用 0 替换?
- php - 如何访问 JSON 的某些部分?