Inconsistent trainable_variables with gradient accumulation in TensorFlow

Problem description

I have been working on a BERT model for text classification with TensorFlow. I want to implement a gradient accumulation mechanism so that I can use an effectively larger batch size and only update the weights every n batches. For this I used the approach from "Gradient Accumulation with Custom model.fit in TF.Keras?", but I am combining it with a pretrained BERT model because I want to use it for transfer learning.

The problem is related to trainable_variables. The moment I set one of the model's layers to non-trainable, trainable_variables returns an empty list. Because of that, the gradient_accumulation list is empty and the weights are never updated.

How can I get trainable_variables to contain the actual trainable variables, so that the gradients are applied correctly?
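For context, in a subclassed Keras model a sublayer only creates its weights the first time it is called, so trainable_variables stays empty until the model has actually been built; the pretrained BERT layer is the exception, because from_pretrained already builds it. A minimal, self-contained sketch of that behavior (the Toy class is hypothetical and only illustrates the point):

import tensorflow as tf

class Toy(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(3)

    def call(self, x):
        return self.dense(x)

toy = Toy()
print(toy.trainable_variables)        # [] -- the Dense layer has not been built yet
toy(tf.zeros((1, 4)))                 # the first call builds the layer
print(len(toy.trainable_variables))   # 2 (kernel and bias)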

import tensorflow as tf
from transformers import BertConfig, TFBertModel, BertTokenizerFast
from tensorflow.keras.layers import Dense
import tensorflow_addons as tfa
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.initializers import TruncatedNormal


class BetoClassifier(tf.keras.Model):

    def __init__(self, tag_len, n_gradients, num_layers, *args, **kwargs):
        super(BetoClassifier, self).__init__()
        model_name = 'dccuchile/bert-base-spanish-wwm-uncased'
        
        # Max length of tokens
        # Load transformers config and set output_hidden_states to False
        self.config = BertConfig.from_pretrained(model_name)
        print(self.config)
        self.config.output_hidden_states = False

        # Load BERT tokenizer
        self.tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path = model_name, config = self.config)

        # Load the Transformers BERT model
        transformer_model = TFBertModel.from_pretrained(model_name,  from_pt=True,config=self.config)
        #for layer in transformer_model.layers:
        #    layer.trainable = False
        
        self.bert = transformer_model.layers[0]        
        self.lstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(50, return_sequences=True, dropout=0.1, recurrent_dropout=0.1),trainable=True)
        self.dense1 = tf.keras.layers.Dense(50, activation='relu')        
        self.tag = Dense(units=tag_len, kernel_initializer=TruncatedNormal(stddev=self.config.initializer_range), name='tag')
        
        #gradient accumulation
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
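        # NOTE: only sublayers that already have weights show up in trainable_variables here;
        # the pretrained BERT layer is built by from_pretrained(), but the LSTM/Dense layers
        # are only built on the first call, so the list below can be empty or incomplete.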
        self.gradient_accumulation = [tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False) for v in self.trainable_variables]
        #self.bert.trainable = False
        
        # And combine it all in a model object        

    def call(self,inputs):
        bert_model = self.bert(inputs['input_ids'])[0]
        X = self.lstm(bert_model)
        X = tf.keras.layers.GlobalMaxPool1D()(X)
        X = self.dense1(X)
        dropout = tf.keras.layers.Dropout(0.2)
        X = dropout(X)
        detalle = self.tag(X)

        return {'tag':detalle}

    def train_step(self, data):
        self.n_acum_step.assign_add(1)

        x, y = data
        # Gradient Tape
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Calculate batch gradients
        gradients = tape.gradient(loss, self.trainable_variables)
        # Accumulate batch gradients
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(gradients[i])
 
        # If n_acum_step has reached n_gradients, apply the accumulated gradients to update the variables; otherwise do nothing
        tf.cond(tf.equal(self.n_acum_step, self.n_gradients), self.apply_accu_gradients, lambda: None)

        # update metrics
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        # apply accumulated gradients
        self.optimizer.apply_gradients(zip(self.gradient_accumulation, self.trainable_variables))

        # reset
        self.n_acum_step.assign(0)
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign(tf.zeros_like(self.trainable_variables[i], dtype=tf.float32))
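
The model is then compiled and trained in the usual Keras way, along these lines (the optimizer settings and the train_dataset name are illustrative placeholders, not part of the original code):

model = BetoClassifier(tag_len=5, n_gradients=4, num_layers=1)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    # from_logits=True because the 'tag' layer has no softmax activation
    loss={'tag': CategoricalCrossentropy(from_logits=True)},
    metrics={'tag': CategoricalAccuracy()},
)
# train_dataset is assumed to yield ({'input_ids': ...}, {'tag': one_hot_labels}) pairs
model.fit(train_dataset, epochs=3)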

Tags: tensorflow, keras, deep-learning, huggingface-transformers, bert-language-model

Solution


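One possible way to address this (a sketch of the general idea, with placeholder values; not necessarily the original accepted answer): the LSTM and Dense layers only create their weights on the first forward pass, so build the model first and only then create the accumulator variables.

model = BetoClassifier(tag_len=5, n_gradients=4, num_layers=1)

# One dummy forward pass builds every sublayer (BERT is already built by
# from_pretrained; the LSTM/Dense layers are built here). The shape (1, 16)
# is arbitrary.
_ = model({'input_ids': tf.zeros((1, 16), dtype=tf.int32)}, training=False)

# Recreate the accumulators so they line up one-to-one with the real
# trainable_variables; the list created in __init__ was empty or incomplete.
model.gradient_accumulation = [
    tf.Variable(tf.zeros_like(v), trainable=False)
    for v in model.trainable_variables
]

Equivalently, the creation of gradient_accumulation can be moved out of __init__ into a small helper method that is called after the first forward pass. Any layer freezing (layer.trainable = False) should happen before the accumulators are created, so that they match what tape.gradient() will actually differentiate.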