首页 > 解决方案 > TensorFlow CNN 过滤器修剪

问题描述

我想在 tensorflow 中对 CNN 过滤器进行修剪。

最初,我使用了 kerassurgeon,但该库似乎不再维护,并且无法正确使用任何 TF 版本 > 2.xx

因此,我尝试实现一个修剪 CNN 过滤器的 POC 代码。

工作流程很简单:

def prune(model, ...)
    # here is some logic to get the correct weight idx from layer index and channel index
    # ...

    current_weights = current_model.weights

    # 01 remove filter from weights
    weights_vec = tf.unstack(current_weights[weights_idx], axis=-1)
    del weights_vec[channel_idx]
    next_weights = tf.stack(weights_vec, -1)

    # 02 remove filter from bias
    bias_vec = tf.unstack(current_weights[bias_idx])
    del bias_vec[channel_idx]
    next_bias = tf.stack(bias_vec)

    # new tensor
    new_weights = current_weights.copy()
    new_weights[weights_idx] = next_weights
    new_weights[bias_idx] = next_bias

    # 03 Change input for next layer
    weights_next_layer_idx = 2 * (layer_weight_index+1)
    weights_vec = tf.unstack(current_weights[weights_next_layer_idx], axis=-2)
    del weights_vec[channel_idx]
    new_weights[weights_next_layer_idx] = tf.stack(weights_vec, -2)

    # To Numpy
    new_weights_numpy = []
    for weight in new_weights:
        new_weights_numpy.append(weight.numpy())

    # 04
    new_model = create_model_from_config(next_config)
    new_model.build(input_shape)
    new_model.set_weights(new_weights_numpy)

代码运行没有任何大问题。但我有一种感觉,我错过了一些东西,因为验证损失受到的影响非常大。现在,L1 最低的过滤器被修剪,因此损失应该不会受到太大影响。

使用自己的 poc 代码的损失(黄色:val loss,蓝色 train loss):

self_written_lib

只要只使用 Conv2D 层,Kerasurgeon 就可以工作。因此,本次测试中的网络仅包含 Conv2D 和 Relu。使用 kerassurgeon 的验证损失与在相同验证数据上的 poc 代码中具有相同的预训练模型权重:

在此处输入图像描述

kerassurgeon 库仍将 val 损失保持在 0.0006 左右。

为了进一步测试,我使用 poc 和 kerassurgeon 库修剪了一些过滤器。两者都产生相同的权重(没有任何训练只是修剪权重)。

# delete using kerassurgeon
layer = model.layers[0]
pruned_kerassurgeon = delete_channels(model, layer, [3])

# delete using poc
pruned_poc = prune(model, 0, 3)

for i in range(len(pruned_kerassurgeon.weigths)):
    pruned_kerassurgeon.weigths[i] - pruned_poc.weights[i]

# results in 0 for every weight (therefore there is no difference in pruning between poc and kerassurgeon)

在 tensorflow 中没有过滤器的情况下,在新模型上设置新权重还不够吗?我错过了什么吗?

标签: pythontensorflowconv-neural-networktensorflow2.0pruning

解决方案


推荐阅读