python - Tensorflow 模型修剪为训练和验证损失提供了“nan”
问题描述
我正在尝试修剪一个由 VGG 网络顶部的几个层组成的基本模型。它还包含一个名为 的用户定义层instance_normalization
。为了剪枝成功,我定义了get_prunable_weights
这一层的功能如下:
### defined for model pruning
def get_prunable_weights(self):
return self.weights
我使用以下函数使用名为 的基本模型获取要修剪的模型结构model
:
def define_prune_model(self, model, img_shape, epochs, batch_size, validation_split=0.1):
num_images = img_shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
# Define model for pruning.
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.5,
final_sparsity=0.80,
begin_step=0,
end_step=end_step)
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)
model_for_pruning.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model_for_pruning.summary()
return model_for_pruning
然后,我编写了以下函数来对这个剪枝模型进行训练:
def train_prune_model(self, model_for_pruning, train_images, train_labels,
epochs, batch_size, validation_split=0.1):
callbacks = [
tfmot.sparsity.keras.UpdatePruningStep(),
tfmot.sparsity.keras.PruningSummaries(log_dir='./models/pruned'),
]
model_for_pruning.fit(train_images, train_labels,
batch_size=batch_size, epochs=epochs, validation_split=validation_split,
callbacks=callbacks)
return model_for_pruning
然而,在训练的时候,我发现训练和验证的损失都是nan
,最终的模型预测输出完全为零。但是,传递给的基本模型define_prune_model
已成功训练并正确预测。
我该如何解决这个问题?先感谢您。
解决方案
如果没有更多信息,很难确定问题所在。instance_normalization
特别是,您能否提供有关您的自定义层的更多详细信息(最好是代码) ?
假设代码没问题:既然你提到模型在没有修剪的情况下正确训练,是不是那些修剪参数太苛刻了?毕竟,50%
从第一个学习步骤开始,这些选项将权重设置为零。
这是我会尝试的:
- 尝试使用较低级别的稀疏性(尤其是
initial_sparsity
)。 - 稍后在训练期间开始应用剪枝(剪枝计划
begin_step
的参数)。有些人甚至更喜欢在不进行修剪的情况下训练一次模型。然后用 重新训练。prune_low_magnitude()
- 仅在某些步骤进行修剪,为模型在修剪之间恢复提供时间(
frequency
参数)。 - 最后如果它仍然失败,遇到 nan 损失时通常的治疗方法:降低学习率,使用正则化或梯度裁剪,...
推荐阅读
- c# - 如何在 Angular 上创建下载文件服务?
- azure - 使用 powershell 覆盖 ftp 文件
- php - 当我编辑然后更新它时如何按 id 显示选择选项?
- android - packageInfo.requestedPermissions 返回不正确的权限
- mysql - 基于联接从多个表中检索数据时出错
- python - 如何从python传递kwargs以更新sql中的多个列
- firebase - 如何为 Firebase Cloud 消息添加 Google API 控制台创建的 API 密钥
- r - 以 png 格式导出使用 rworldmap 创建的地图
- c# - 选择包含部分关键字的引号之间的所有内容
- javascript - 如何将对象类型的 json 数组传递给 MVC 控制器类