首页 > 解决方案 > Tensorflow 模型修剪为训练和验证损失提供了“nan”

问题描述

我正在尝试修剪一个由 VGG 网络顶部的几个层组成的基本模型。它还包含一个名为 的用户定义层instance_normalization。为了剪枝成功,我定义了get_prunable_weights这一层的功能如下:

### defined for model pruning
    def get_prunable_weights(self):
        return self.weights

我使用以下函数使用名为 的基本模型获取要修剪的模型结构model

def define_prune_model(self, model, img_shape, epochs, batch_size, validation_split=0.1):
        num_images = img_shape[0] * (1 - validation_split)
        end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

        # Define model for pruning.
        pruning_params = {
            'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.5,
                                                                    final_sparsity=0.80,
                                                                    begin_step=0,
                                                                    end_step=end_step)
        }

        model_for_pruning = prune_low_magnitude(model, **pruning_params)

        model_for_pruning.compile(optimizer='adam',
                    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy'])

        model_for_pruning.summary()

        return model_for_pruning

然后,我编写了以下函数来对这个剪枝模型进行训练:

def train_prune_model(self, model_for_pruning, train_images, train_labels,
                     epochs, batch_size, validation_split=0.1):
    callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir='./models/pruned'),
    ]
    model_for_pruning.fit(train_images, train_labels,
                batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                callbacks=callbacks)
    return model_for_pruning

然而,在训练的时候,我发现训练和验证的损失都是nan,最终的模型预测输出完全为零。但是,传递给的基本模型define_prune_model已成功训练并正确预测。

我该如何解决这个问题?先感谢您。

标签: pythontensorflowkerasdeep-learningmodel

解决方案


如果没有更多信息,很难确定问题所在。instance_normalization特别是,您能否提供有关您的自定义层的更多详细信息(最好是代码) ?

假设代码没问题:既然你提到模型在没有修剪的情况下正确训练,是不是那些修剪参数太苛刻了?毕竟,50%从第一个学习步骤开始,这些选项将权重设置为零。

这是我会尝试的:

  • 尝试使用较低级别的稀疏性(尤其是initial_sparsity)。
  • 稍后在训练期间开始应用剪枝(剪枝计划begin_step的参数)。有些人甚至更喜欢在不进行修剪的情况下训练一次模型。然后用 重新训练。prune_low_magnitude()
  • 仅在某些步骤进行修剪,为模型在修剪之间恢复提供时间(frequency参数)。
  • 最后如果它仍然失败,遇到 nan 损失时通常的治疗方法:降低学习率,使用正则化或梯度裁剪,...

推荐阅读