首页 > 解决方案 > 如何从keras中另一层的权重计算自定义过滤器?

问题描述

我正在尝试制作一个模型,以便以后可以进行量化意识培训。由于它共享权重,我使用了 keras 功能 API。我的模型代码如下:

def QCSNet(B, M):

    BB = B * B
    inputs = Input((B, B, 1), name='blocks')

    # equivalent to dividing the image into blocks and taking measurements for each block
    sample = Conv2D(filters=M, kernel_size=B, use_bias=False, name="take_measurements")
    output = sample(inputs)

    # Equivalent to the transpose of Phi, as convolutional weights
    PhiT = tf.transpose(tf.reshape(sample.weights[0], [1, 1, BB, M]), [0, 1, 3, 2])

    # initial reconstruction
    output = tf.nn.conv2d(output, filters=PhiT, strides=1, padding='SAME', name="init_rec1")
    output = tf.nn.depth_to_space(output, B, name="init_rec2")

    # Denoising
    output = tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu',
                                     name="conv_init")(output)

    sharedConv1 = tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu',
                                     name="shared_conv1")
    sharedConv2 = tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu',
                                     name="shared_conv2")
    output = sharedConv1(output)
    output = sharedConv2(output)
    output = sharedConv1(output)
    output = sharedConv2(output)

    output = tf.keras.layers.Conv2D(filters=1, kernel_size=3, padding='same', name="output_conv")(output)

    return Model(inputs=inputs, outputs=output)

我正在重塑“样本”层的过滤器并在“init_rec1”卷积中使用它们:

# Equivalent to the transpose of Phi, as convolutional weights
PhiT = tf.transpose(tf.reshape(sample.weights[0], [1, 1, BB, M]), [0, 1, 3, 2])

# initial reconstruction
output = tf.nn.conv2d(output, filters=PhiT, strides=1, padding='SAME', name="init_rec1")

但是,tf.nn.conv2d()当我尝试量化它时,使用会给我带来问题。我在网上找到的另一种方法是使用tf.keras.layers.Conv2D如下:

x = layers.Conv2D(filters, kernel_size, kernel_initializer=my_filter, trainable=False)(input_tensor)

但据我了解,这将在开始时分配一次内核,而不根据sample.weights[0].

有什么办法可以在转发道具期间更新过滤器值?

标签: tensorflowkerasdeep-learningneural-networkconv-neural-network

解决方案


推荐阅读