首页 > 解决方案 > Keras 函数式 API 中的可重用块

问题描述

目标是使用 Keras 的功能 API 创建一个层块,该 API 可以像“普通”Keras 层一样使用(也是语法方面的)。这是一个玩具示例

from tensorflow.keras import layers as kl

def layer_block(prev_layer, args):
    # some code using 'args'
    layer = kl.Dense(units=prev_layer.shape[1])(prev_layer)
    layer = kl.Dense(units=5)(layer)
    layer = kl.Dense(units=prev_layer.shape[1])(layer)

    return layer

这个块被调用 usinglayer_block(prev_layer, args)与 Keras 的功能 API 的语法相矛盾。它应该看起来像layer_block(args)(prev_layer).

到目前为止的方法是用另一个块包装这个块:

def outer_block(args):
    def layer_block(prev_layer, args):
        # some code using 'args'
        layer = kl.Dense(units=prev_layer.shape[1])(prev_layer)
        layer = kl.Dense(units=5)(layer)
        layer = kl.Dense(units=prev_layer.shape[1])(layer)

        return layer
    return lambda prev_layer: layer_block(prev_layer, args)

现在出现两个问题:

  1. 有没有更简单的方法来实现这一点?
  2. 这种方式有效还是对性能有负面影响?

先感谢您!

标签: tensorflowkeraskeras-layer

解决方案


您所做的不会影响性能,您可以完美地创建图层。

您的两种方法中的任何一种都没有问题,但是如果您确实想让它作为实际层工作,请将其转换为模型。

这可能不适用于每个 keras 版本:

class LayerBlock(tensorflow.keras.Model): #not sure if it works in normal keras (without tf)

    def __init__(self):
        super(LayerBlock, self).__init__(outer_units)
        self.layer1 = kl.Dense(units=outer_units)
        self.layer2 = kl.Dense(units=5)
        self.layer3 = kl.Dense(units=outer_units)

    def call(self, inputs):
        x = self.layer1(inputs)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

本教程似乎建议您可以使用tf.keras.Layer而不是tf.keras.Model,但这对我来说听起来很奇怪。它可以在急切模式下工作,但它缺少build带有self.built=True语句的方法。


推荐阅读