首页 > 解决方案 > 如何冻结/解冻预训练模型作为 Tensorflow 中子类模型的一部分?

问题描述

我正在尝试使用 Tensorflow >= 2.4 构建一个子类模型,该模型由预训练的卷积基础和顶部的一些密集层组成。然而,子类模型的冻结/解冻一旦之前被训练就没有任何效果。当我对功能 API 做同样的事情时,一切都按预期工作。我真的很感激我在这里缺少的一些提示:遵循代码应该进一步说明我的问题。请原谅我的代码量:

#Setup


import tensorflow as tf
tf.config.run_functions_eagerly(False)
 
import numpy as np
from tensorflow.keras.regularizers import l1 
import matplotlib.pyplot as plt


@tf.function
def create_images_and_labels(img,label, height = 70, width = 70): #Image augmentation

    label = tf.cast(label, 'float32')
    label = tf.squeeze(label)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, (height, width))
  #  img = preprocess_input(img)
    
    return img, label



cifar = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test)  = cifar.load_data()
num_classes = len(np.unique(y_train))

ds_train = tf.data.Dataset.from_tensor_slices((x_train, tf.one_hot(y_train, depth = len(np.unique(y_train)))))
ds_train = ds_train.map(lambda img, label: create_images_and_labels(img, label, height = 70, width = 70))

ds_train = ds_train.shuffle(50000)
ds_train = ds_train.batch(50, drop_remainder = True)


ds_val = tf.data.Dataset.from_tensor_slices((x_test, tf.one_hot(y_test, depth = len(np.unique(y_train)))))
ds_val = ds_val.map(lambda img, label: create_images_and_labels(img, label, height = 70, width = 70))
ds_val = ds_val.batch(50, drop_remainder=True)



# for i in ds_train.take(1):
#     x, y = i
#     for ind in range(x.shape[0]):
#         plt.imshow(x[ind,:,:])
#         plt.show()
#         print(y[ind])


'''
Defining simple subclassed Model consisting of 
VGG16
Flatten
Dense Layers

customized what happens in model.fit and model.evaluate (Actually its the standard Keras procedure with custom Metrics)
customized metrics: Loss and Accuracy for Training and Validation Step

added unfreezing Method 
'set_trainable_layers'
Arguments: 
    num_head (How many dense Layers)
    num_base (How many VGG Layers)
'''



class Test_Model(tf.keras.models.Model):
    
    def __init__(
            self,
            num_unfrozen_head_layers, 
            num_unfrozen_base_layers,
            num_classes,
            conv_base = tf.keras.applications.VGG16(include_top = False, weights = 'imagenet', input_shape = (70,70,3)),
            
            
            
            ):
                super(Test_Model, self).__init__(name = "Test_Model")
                
                self.base = conv_base
                self.flatten = tf.keras.layers.Flatten()
                self.dense1 = tf.keras.layers.Dense(2048, activation = 'relu')
                self.dense2 = tf.keras.layers.Dense(1024, activation = 'relu')
                self.dense3 = tf.keras.layers.Dense(128, activation = 'relu')
                self.out = tf.keras.layers.Dense(num_classes, activation = 'softmax')
                self.out._name = 'out'



                    
                self.train_loss_metric = tf.keras.metrics.Mean('Supervised Training Loss')
                self.train_acc_metric = tf.keras.metrics.CategoricalAccuracy('Supervised Training Accuracy')
                self.val_loss_metric = tf.keras.metrics.Mean('Supervised Validation Loss')
                self.val_acc_metric = tf.keras.metrics.CategoricalAccuracy('Supervised Validation Accuracy')
                self.loss_fn = tf.keras.losses.categorical_crossentropy
                self.learning_rate = 1e-4
                
              #  self.build((None, 32,32,3))
                self.set_trainable_layers(num_unfrozen_head_layers, num_unfrozen_base_layers)
        
                
    @tf.function
    def call(self, inputs, training = False):
        x = self.base(inputs)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        x = self.out(x)
        return x
    @tf.function
    def train_step(self, input_data):
         x_batch, y_batch = input_data
         with tf.GradientTape() as tape: 
                tape.watch(x_batch)
                y_pred = self(x_batch, training = True)
                loss = self.loss_fn(y_batch, y_pred)
                
         trainable_vars = self.trainable_weights
         gradients = tape.gradient(loss, trainable_vars)
            
         self.optimizer.apply_gradients(zip(gradients, trainable_vars))
         self.train_loss_metric.update_state(loss)
         self.train_acc_metric.update_state(y_batch, y_pred)
            
         return {"Supervised Loss": self.train_loss_metric.result(),
                 "Supervised Accuracy":self.train_acc_metric.result()}
     
    @tf.function
    def test_step(self, input_data):
        x_batch,y_batch = input_data
        y_pred = self(x_batch, training = False)
        loss = self.loss_fn(y_batch, y_pred)
        
        self.val_loss_metric.update_state(loss)
        self.val_acc_metric.update_state(y_batch, y_pred)
            
        return {"Val Supervised Loss": self.val_loss_metric.result(),
                "Val Supervised Accuracy":self.val_acc_metric.result()}
    
    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [self.train_loss_metric,
                self.train_acc_metric,
                self.val_loss_metric,
                self.val_acc_metric]  
    
    def set_trainable_layers(self, num_head, num_base):
        
        for layer in [lay for lay in self.layers if not isinstance(lay , tf.keras.models.Model)]: 
            layer.trainable = False
            print(layer.name, layer.trainable)
        for block in self.layers:
            
            if isinstance(block, tf.keras.models.Model):
                print('Found Submodel', block.name)
                for layer in block.layers: 
                    layer.trainable = False
                    print(layer.name, layer.trainable)
                if num_base > 0:    
                    for layer in block.layers[-num_base:]:
                        layer.trainable = True
                        print(layer.name, layer.trainable)
        if num_head > 0: 
            for layer in [lay for lay in self.layers if not isinstance(lay, tf.keras.models.Model)][-num_head:]:
                layer.trainable = True 
                print(layer.name, layer.trainable)
        
        
    
'''
Showcase1: First training completely frozen Model, then unfreezing: 
    unfreezed model doesnt learn
'''    
    
model = Test_Model(num_unfrozen_head_layers= 0, num_unfrozen_base_layers = 0, num_classes = num_classes)    # Should NOT learn -> doesnt learn
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)


model.set_trainable_layers(10,20) # SHOULD LEARN -> Doesnt learn
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
#DOESNT LEARN
    
'''
Showcase2: when first training the Model with more trainable Layers than in the second step:
    AssertionError occurs
'''
model = Test_Model(num_unfrozen_head_layers= 10, num_unfrozen_base_layers = 2, num_classes = num_classes)    # SHOULD LEARN -> learns
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)


model.set_trainable_layers(1,1) # SHOULD NOT LEARN -> AssertionError
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)

'''
Showcase3: same Procedure as in Showcase2 but optimizer State is transferred to recompiled Model:
    Cant set Weigthts because optimizer expects List of Length 0
    
'''

model = Test_Model(num_unfrozen_head_layers= 10, num_unfrozen_base_layers = 20, num_classes = num_classes)    # SHOULD LEARN -> learns
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)

opti_state = model.optimizer.get_weights()
model.set_trainable_layers(0,0) # SHOULD NOT LEARN -> Learns
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.optimizer.set_weights(opti_state)
model.fit(ds_train, validation_data = ds_val)
    


#%%%    
'''
Constructing same Architecture with Functional API and running Experiments
'''

import tensorflow as tf    
conv_base = tf.keras.applications.VGG16(include_top = False, weights = 'imagenet', input_shape = (70,70,3))
    
inputs = tf.keras.layers.Input((70,70,3))
x = conv_base(inputs)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(2048, activation = 'relu') (x)
x = tf.keras.layers.Dense(1024,activation = 'relu') (x)
x = tf.keras.layers.Dense(128,activation = 'relu') (x)
out = tf.keras.layers.Dense(num_classes,activation = 'softmax') (x)
    


isinstance(tf.keras.layers.Flatten(), tf.keras.models.Model)
isinstance(conv_base, tf.keras.models.Model)


def set_trainable_layers(mod, num_head, num_base):
    import time
    for layer in [lay for lay in mod.layers if not isinstance(lay , tf.keras.models.Model)]: 
        layer.trainable = False
        print(layer.name, layer.trainable)
    for block in mod.layers:
        
        if isinstance(block, tf.keras.models.Model):
            print('Found Submodel')
            for layer in block.layers: 
                layer.trainable = False
                print(layer.name, layer.trainable)
            if num_base > 0:    
                for layer in block.layers[-num_base:]:
                    layer.trainable = True
                    print(layer.name, layer.trainable)
    if num_head > 0: 
        for layer in [lay for lay in mod.layers if not isinstance(lay, tf.keras.models.Model)][-num_head:]:
            layer.trainable = True 
            print(layer.name, layer.trainable)
       
    
    
'''
Showcase1: First training frozen Model, then unfreezing, recomiling and retraining:
    model behaves as expected
'''    
mod = tf.keras.models.Model(inputs,out, name = 'TestModel')    
set_trainable_layers(mod, 0 ,0)    
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model should NOT learn


set_trainable_layers(mod, 10,20)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) #Model SHOULD learn


'''
Showcase2: First training unfrozen Model, then reducing number of trainable Layers:
    Model behaves as Expected
'''


mod = tf.keras.models.Model(inputs,out, name = 'TestModel')    
set_trainable_layers(mod, 10 ,20)    
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model SHOULD learn


set_trainable_layers(mod, 0,0)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) #Model should NOT learn




'''
Showcase3: First training unfrozen Model, then reducing number of trainable Layers but also trying to trasnfer Optimizer States:
    Behaves as subclassed Model: New Optimizer shouldnt have Weights
'''


mod = tf.keras.models.Model(inputs,out, name = 'TestModel')    
set_trainable_layers(mod, 1 ,3)    
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model SHOULD learn

opti_state = mod.optimizer.get_weights()
set_trainable_layers(mod, 4,8)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.optimizer.set_weights(opti_state)
mod.fit(ds_train, validation_data = ds_val) #Model should NOT learn

标签: tensorflowkerasdeep-learning

解决方案


之所以发生这种情况,是因为 Tensorflow2 中的子类化 API 与函数式或顺序式 API 之间的根本区别之一。

功能或顺序 API 构建层图(将其视为单独的数据结构),子类模型构建整个对象并将其存储为字节码。

这意味着使用子类化您将无法访问内部连接图,并且允许您冻结/解冻层或在其他模型中重用它们的正常行为开始变得奇怪。看到您的实现,我会说 Subclassed 模型是正确的,如果我们处理的是 Tensorflow 以外的库,它应该可以工作。

Francois Chollet 比我在他的一篇推文中解释得更好


推荐阅读