tensorflow - 如何冻结/解冻预训练模型作为 Tensorflow 中子类模型的一部分?
问题描述
我正在尝试使用 Tensorflow >= 2.4 构建一个子类模型,该模型由预训练的卷积基础和顶部的一些密集层组成。然而,子类模型的冻结/解冻一旦之前被训练就没有任何效果。当我对功能 API 做同样的事情时,一切都按预期工作。我真的很感激我在这里缺少的一些提示:遵循代码应该进一步说明我的问题。请原谅我的代码量:
#Setup
import tensorflow as tf
tf.config.run_functions_eagerly(False)
import numpy as np
from tensorflow.keras.regularizers import l1
import matplotlib.pyplot as plt
@tf.function
def create_images_and_labels(img,label, height = 70, width = 70): #Image augmentation
label = tf.cast(label, 'float32')
label = tf.squeeze(label)
img = tf.image.convert_image_dtype(img, tf.float32)
img = tf.image.resize(img, (height, width))
# img = preprocess_input(img)
return img, label
cifar = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar.load_data()
num_classes = len(np.unique(y_train))
ds_train = tf.data.Dataset.from_tensor_slices((x_train, tf.one_hot(y_train, depth = len(np.unique(y_train)))))
ds_train = ds_train.map(lambda img, label: create_images_and_labels(img, label, height = 70, width = 70))
ds_train = ds_train.shuffle(50000)
ds_train = ds_train.batch(50, drop_remainder = True)
ds_val = tf.data.Dataset.from_tensor_slices((x_test, tf.one_hot(y_test, depth = len(np.unique(y_train)))))
ds_val = ds_val.map(lambda img, label: create_images_and_labels(img, label, height = 70, width = 70))
ds_val = ds_val.batch(50, drop_remainder=True)
# for i in ds_train.take(1):
# x, y = i
# for ind in range(x.shape[0]):
# plt.imshow(x[ind,:,:])
# plt.show()
# print(y[ind])
'''
Defining simple subclassed Model consisting of
VGG16
Flatten
Dense Layers
customized what happens in model.fit and model.evaluate (Actually its the standard Keras procedure with custom Metrics)
customized metrics: Loss and Accuracy for Training and Validation Step
added unfreezing Method
'set_trainable_layers'
Arguments:
num_head (How many dense Layers)
num_base (How many VGG Layers)
'''
class Test_Model(tf.keras.models.Model):
def __init__(
self,
num_unfrozen_head_layers,
num_unfrozen_base_layers,
num_classes,
conv_base = tf.keras.applications.VGG16(include_top = False, weights = 'imagenet', input_shape = (70,70,3)),
):
super(Test_Model, self).__init__(name = "Test_Model")
self.base = conv_base
self.flatten = tf.keras.layers.Flatten()
self.dense1 = tf.keras.layers.Dense(2048, activation = 'relu')
self.dense2 = tf.keras.layers.Dense(1024, activation = 'relu')
self.dense3 = tf.keras.layers.Dense(128, activation = 'relu')
self.out = tf.keras.layers.Dense(num_classes, activation = 'softmax')
self.out._name = 'out'
self.train_loss_metric = tf.keras.metrics.Mean('Supervised Training Loss')
self.train_acc_metric = tf.keras.metrics.CategoricalAccuracy('Supervised Training Accuracy')
self.val_loss_metric = tf.keras.metrics.Mean('Supervised Validation Loss')
self.val_acc_metric = tf.keras.metrics.CategoricalAccuracy('Supervised Validation Accuracy')
self.loss_fn = tf.keras.losses.categorical_crossentropy
self.learning_rate = 1e-4
# self.build((None, 32,32,3))
self.set_trainable_layers(num_unfrozen_head_layers, num_unfrozen_base_layers)
@tf.function
def call(self, inputs, training = False):
x = self.base(inputs)
x = self.flatten(x)
x = self.dense1(x)
x = self.dense2(x)
x = self.dense3(x)
x = self.out(x)
return x
@tf.function
def train_step(self, input_data):
x_batch, y_batch = input_data
with tf.GradientTape() as tape:
tape.watch(x_batch)
y_pred = self(x_batch, training = True)
loss = self.loss_fn(y_batch, y_pred)
trainable_vars = self.trainable_weights
gradients = tape.gradient(loss, trainable_vars)
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
self.train_loss_metric.update_state(loss)
self.train_acc_metric.update_state(y_batch, y_pred)
return {"Supervised Loss": self.train_loss_metric.result(),
"Supervised Accuracy":self.train_acc_metric.result()}
@tf.function
def test_step(self, input_data):
x_batch,y_batch = input_data
y_pred = self(x_batch, training = False)
loss = self.loss_fn(y_batch, y_pred)
self.val_loss_metric.update_state(loss)
self.val_acc_metric.update_state(y_batch, y_pred)
return {"Val Supervised Loss": self.val_loss_metric.result(),
"Val Supervised Accuracy":self.val_acc_metric.result()}
@property
def metrics(self):
# We list our `Metric` objects here so that `reset_states()` can be
# called automatically at the start of each epoch
# or at the start of `evaluate()`.
# If you don't implement this property, you have to call
# `reset_states()` yourself at the time of your choosing.
return [self.train_loss_metric,
self.train_acc_metric,
self.val_loss_metric,
self.val_acc_metric]
def set_trainable_layers(self, num_head, num_base):
for layer in [lay for lay in self.layers if not isinstance(lay , tf.keras.models.Model)]:
layer.trainable = False
print(layer.name, layer.trainable)
for block in self.layers:
if isinstance(block, tf.keras.models.Model):
print('Found Submodel', block.name)
for layer in block.layers:
layer.trainable = False
print(layer.name, layer.trainable)
if num_base > 0:
for layer in block.layers[-num_base:]:
layer.trainable = True
print(layer.name, layer.trainable)
if num_head > 0:
for layer in [lay for lay in self.layers if not isinstance(lay, tf.keras.models.Model)][-num_head:]:
layer.trainable = True
print(layer.name, layer.trainable)
'''
Showcase1: First training completely frozen Model, then unfreezing:
unfreezed model doesnt learn
'''
model = Test_Model(num_unfrozen_head_layers= 0, num_unfrozen_base_layers = 0, num_classes = num_classes) # Should NOT learn -> doesnt learn
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
model.set_trainable_layers(10,20) # SHOULD LEARN -> Doesnt learn
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
#DOESNT LEARN
'''
Showcase2: when first training the Model with more trainable Layers than in the second step:
AssertionError occurs
'''
model = Test_Model(num_unfrozen_head_layers= 10, num_unfrozen_base_layers = 2, num_classes = num_classes) # SHOULD LEARN -> learns
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
model.set_trainable_layers(1,1) # SHOULD NOT LEARN -> AssertionError
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
'''
Showcase3: same Procedure as in Showcase2 but optimizer State is transferred to recompiled Model:
Cant set Weigthts because optimizer expects List of Length 0
'''
model = Test_Model(num_unfrozen_head_layers= 10, num_unfrozen_base_layers = 20, num_classes = num_classes) # SHOULD LEARN -> learns
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
opti_state = model.optimizer.get_weights()
model.set_trainable_layers(0,0) # SHOULD NOT LEARN -> Learns
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.optimizer.set_weights(opti_state)
model.fit(ds_train, validation_data = ds_val)
#%%%
'''
Constructing same Architecture with Functional API and running Experiments
'''
import tensorflow as tf
conv_base = tf.keras.applications.VGG16(include_top = False, weights = 'imagenet', input_shape = (70,70,3))
inputs = tf.keras.layers.Input((70,70,3))
x = conv_base(inputs)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(2048, activation = 'relu') (x)
x = tf.keras.layers.Dense(1024,activation = 'relu') (x)
x = tf.keras.layers.Dense(128,activation = 'relu') (x)
out = tf.keras.layers.Dense(num_classes,activation = 'softmax') (x)
isinstance(tf.keras.layers.Flatten(), tf.keras.models.Model)
isinstance(conv_base, tf.keras.models.Model)
def set_trainable_layers(mod, num_head, num_base):
import time
for layer in [lay for lay in mod.layers if not isinstance(lay , tf.keras.models.Model)]:
layer.trainable = False
print(layer.name, layer.trainable)
for block in mod.layers:
if isinstance(block, tf.keras.models.Model):
print('Found Submodel')
for layer in block.layers:
layer.trainable = False
print(layer.name, layer.trainable)
if num_base > 0:
for layer in block.layers[-num_base:]:
layer.trainable = True
print(layer.name, layer.trainable)
if num_head > 0:
for layer in [lay for lay in mod.layers if not isinstance(lay, tf.keras.models.Model)][-num_head:]:
layer.trainable = True
print(layer.name, layer.trainable)
'''
Showcase1: First training frozen Model, then unfreezing, recomiling and retraining:
model behaves as expected
'''
mod = tf.keras.models.Model(inputs,out, name = 'TestModel')
set_trainable_layers(mod, 0 ,0)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model should NOT learn
set_trainable_layers(mod, 10,20)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) #Model SHOULD learn
'''
Showcase2: First training unfrozen Model, then reducing number of trainable Layers:
Model behaves as Expected
'''
mod = tf.keras.models.Model(inputs,out, name = 'TestModel')
set_trainable_layers(mod, 10 ,20)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model SHOULD learn
set_trainable_layers(mod, 0,0)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) #Model should NOT learn
'''
Showcase3: First training unfrozen Model, then reducing number of trainable Layers but also trying to trasnfer Optimizer States:
Behaves as subclassed Model: New Optimizer shouldnt have Weights
'''
mod = tf.keras.models.Model(inputs,out, name = 'TestModel')
set_trainable_layers(mod, 1 ,3)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model SHOULD learn
opti_state = mod.optimizer.get_weights()
set_trainable_layers(mod, 4,8)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.optimizer.set_weights(opti_state)
mod.fit(ds_train, validation_data = ds_val) #Model should NOT learn
解决方案
之所以发生这种情况,是因为 Tensorflow2 中的子类化 API 与函数式或顺序式 API 之间的根本区别之一。
功能或顺序 API 构建层图(将其视为单独的数据结构),子类模型构建整个对象并将其存储为字节码。
这意味着使用子类化您将无法访问内部连接图,并且允许您冻结/解冻层或在其他模型中重用它们的正常行为开始变得奇怪。看到您的实现,我会说 Subclassed 模型是正确的,如果我们处理的是 Tensorflow 以外的库,它应该可以工作。
Francois Chollet 比我在他的一篇推文中解释得更好
推荐阅读
- javascript - 如何刷新 github 个人访问令牌
- apache-spark-sql - Spark SQL从字符串字段返回数据保存json字典样式
- amazon-web-services - 如何销售自定义 AWS AMI
- c# - 无法将工作服务作为 Windows 服务启动
- c# - 从与 AsyncCrudAppService 的关系中获取数据
- sapui5 - SAP Web IDE 警告:“未定义控件 ID。请输入唯一 ID。”
- google-api - Google Analytics API - 使用 t-sql 从 sql server 进行身份验证
- flutter - 无效参数:在 URI 文件中未指定主机:///assets/Upload/Item/214b5d5c-ca86-45f5-bb45-7850559a23bb.jpg
- pyspark - PySpark如何根据行值创建列
- mysql - 用于存储 MySQL 的多个 BOOLEAN 或一个 BIGINT