python - TensorFlow CNN 过滤器修剪
问题描述
我想在 tensorflow 中对 CNN 过滤器进行修剪。
最初,我使用了 kerassurgeon,但该库似乎不再维护,并且无法正确使用任何 TF 版本 > 2.xx
因此,我尝试实现一个修剪 CNN 过滤器的 POC 代码。
工作流程很简单:
- 过滤器的排名权重并根据排名选择要删除的层/过滤器
- 01:从权重中移除过滤器(过滤器被修剪的层)
- 02:从偏差中移除过滤器(过滤器被修剪的层)
- 03:从权重中删除过滤器(下一层)
- 04:创建基于旧模型的新模型,没有经过修剪的过滤器并设置权重
def prune(model, ...)
# here is some logic to get the correct weight idx from layer index and channel index
# ...
current_weights = current_model.weights
# 01 remove filter from weights
weights_vec = tf.unstack(current_weights[weights_idx], axis=-1)
del weights_vec[channel_idx]
next_weights = tf.stack(weights_vec, -1)
# 02 remove filter from bias
bias_vec = tf.unstack(current_weights[bias_idx])
del bias_vec[channel_idx]
next_bias = tf.stack(bias_vec)
# new tensor
new_weights = current_weights.copy()
new_weights[weights_idx] = next_weights
new_weights[bias_idx] = next_bias
# 03 Change input for next layer
weights_next_layer_idx = 2 * (layer_weight_index+1)
weights_vec = tf.unstack(current_weights[weights_next_layer_idx], axis=-2)
del weights_vec[channel_idx]
new_weights[weights_next_layer_idx] = tf.stack(weights_vec, -2)
# To Numpy
new_weights_numpy = []
for weight in new_weights:
new_weights_numpy.append(weight.numpy())
# 04
new_model = create_model_from_config(next_config)
new_model.build(input_shape)
new_model.set_weights(new_weights_numpy)
代码运行没有任何大问题。但我有一种感觉,我错过了一些东西,因为验证损失受到的影响非常大。现在,L1 最低的过滤器被修剪,因此损失应该不会受到太大影响。
使用自己的 poc 代码的损失(黄色:val loss,蓝色 train loss):
只要只使用 Conv2D 层,Kerasurgeon 就可以工作。因此,本次测试中的网络仅包含 Conv2D 和 Relu。使用 kerassurgeon 的验证损失与在相同验证数据上的 poc 代码中具有相同的预训练模型权重:
kerassurgeon 库仍将 val 损失保持在 0.0006 左右。
为了进一步测试,我使用 poc 和 kerassurgeon 库修剪了一些过滤器。两者都产生相同的权重(没有任何训练只是修剪权重)。
# delete using kerassurgeon
layer = model.layers[0]
pruned_kerassurgeon = delete_channels(model, layer, [3])
# delete using poc
pruned_poc = prune(model, 0, 3)
for i in range(len(pruned_kerassurgeon.weigths)):
pruned_kerassurgeon.weigths[i] - pruned_poc.weights[i]
# results in 0 for every weight (therefore there is no difference in pruning between poc and kerassurgeon)
在 tensorflow 中没有过滤器的情况下,在新模型上设置新权重还不够吗?我错过了什么吗?
解决方案
推荐阅读
- sql-server - 可空表达式上的 Azure Synapse 军事化视图错误
- sql - 仅当两个用户都是 sql server 的朋友时如何获取数据
- java - Java实现互相关二维图像
- python - 安装 ddsp python 模块时出现问题 9ModuleNotFoundError: No module named 'ddsp')
- javascript - 如何使用 WebView2 获取回调数据
- sql - 使用 Hive SQL 获取文件系统目录大小
- shopify - 在 shopify 上完成结帐时保存其他数据
- css - 我需要帮助在标题中的页面链接后面添加图像
- java - Jenkins - 正则表达式匹配/不匹配来自 postbuild 日志中的单词不起作用
- java - LoadingCache 不适用于 CompletionStage