首页 > 解决方案 > 具有多个输入的 TensorFlow 自定义层中构建方法的 Input_shape

问题描述

我必须设计一个接受两个输入X_1X_2. 该层将它们转换为固定大小的向量(10D),然后按以下方式对它们求和

class my_lyr(tf.keras.layers.Layer):
    def __init__(self):
        pass
    def call(self, X_1, X_2):
        return X_1 @ self.w1 + X_2 @ self.w2  

但是,在初始化 and 之前X_1,我需要知道 and 的输入 形状。我不确定如何在.X_2w1w2w2build

def build(self, input_shape):
    self.w1 = self.add_weight('w1', shape=[input_shape[-1],10])
    // self.w2 = ?????

我想知道如何构建方法通常是在这种情况下编写的。

标签: pythontensorflowmachine-learningdeep-learningneural-network

解决方案


如果你有两个这样的层输入,那么你可以简单地初始化你的权重,如下所示

import tensorflow as tf 
from tensorflow import keras 

class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.wa = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            initializer="random_normal",
            trainable=True,
        )

        self.wb = self.add_weight(
            shape=(input_shape[1][-1], self.units),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        return tf.matmul(inputs[0], self.wa) + tf.matmul(inputs[1], self.wb)

传递输入

x = tf.random.normal(shape=(2,2))
linear_layer = Linear(32)
linear_layer([x, x])
<tf.Tensor: shape=(2, 32), dtype=float32, numpy=
array([[-0.08829461, -0.01605312, -0.04368614, -0.08116315, -0.01521384,
         0.01132785,  0.10704445, -0.10873697, -0.0525714 ,  0.07684848,
         0.04586978,  0.01315852,  0.01369547,  0.07404792,  0.10313608,
        -0.10851607,  0.04091477, -0.01723676, -0.0326797 ,  0.03598418,
        -0.11335816, -0.10044714,  0.13555384,  0.01689356,  0.02631954,
         0.08226107, -0.08765724, -0.05981663,  0.00531629,  0.02930426,
         0.04155847,  0.05339598],
       [ 0.20617458, -0.05936547,  0.01735754, -0.06575315,  0.10090968,
        -0.07796012, -0.1956767 , -0.03406558,  0.18604615, -0.03547171,
         0.02784208,  0.0471364 , -0.10712875, -0.07869454, -0.19457275,
         0.13593757, -0.14659101,  0.0384632 ,  0.02344182, -0.03861775,
         0.08948556,  0.09225713, -0.17395493,  0.10021958, -0.09210777,
        -0.09865301,  0.2536609 , -0.02547608,  0.02885125, -0.01271547,
        -0.10340843, -0.0338558 ]], dtype=float32)>

推荐阅读