python - 具有多个输入的 TensorFlow 自定义层中构建方法的 Input_shape
问题描述
我必须设计一个接受两个输入X_1
和X_2
. 该层将它们转换为固定大小的向量(10D),然后按以下方式对它们求和
class my_lyr(tf.keras.layers.Layer):
def __init__(self):
pass
def call(self, X_1, X_2):
return X_1 @ self.w1 + X_2 @ self.w2
但是,在初始化 and 之前X_1
,我需要知道 and 的输入 形状。我不确定如何在.X_2
w1
w2
w2
build
def build(self, input_shape):
self.w1 = self.add_weight('w1', shape=[input_shape[-1],10])
// self.w2 = ?????
我想知道如何构建方法通常是在这种情况下编写的。
解决方案
如果你有两个这样的层输入,那么你可以简单地初始化你的权重,如下所示
import tensorflow as tf
from tensorflow import keras
class Linear(keras.layers.Layer):
def __init__(self, units=32):
super(Linear, self).__init__()
self.units = units
def build(self, input_shape):
self.wa = self.add_weight(
shape=(input_shape[0][-1], self.units),
initializer="random_normal",
trainable=True,
)
self.wb = self.add_weight(
shape=(input_shape[1][-1], self.units),
initializer="random_normal",
trainable=True,
)
def call(self, inputs):
return tf.matmul(inputs[0], self.wa) + tf.matmul(inputs[1], self.wb)
传递输入
x = tf.random.normal(shape=(2,2))
linear_layer = Linear(32)
linear_layer([x, x])
<tf.Tensor: shape=(2, 32), dtype=float32, numpy=
array([[-0.08829461, -0.01605312, -0.04368614, -0.08116315, -0.01521384,
0.01132785, 0.10704445, -0.10873697, -0.0525714 , 0.07684848,
0.04586978, 0.01315852, 0.01369547, 0.07404792, 0.10313608,
-0.10851607, 0.04091477, -0.01723676, -0.0326797 , 0.03598418,
-0.11335816, -0.10044714, 0.13555384, 0.01689356, 0.02631954,
0.08226107, -0.08765724, -0.05981663, 0.00531629, 0.02930426,
0.04155847, 0.05339598],
[ 0.20617458, -0.05936547, 0.01735754, -0.06575315, 0.10090968,
-0.07796012, -0.1956767 , -0.03406558, 0.18604615, -0.03547171,
0.02784208, 0.0471364 , -0.10712875, -0.07869454, -0.19457275,
0.13593757, -0.14659101, 0.0384632 , 0.02344182, -0.03861775,
0.08948556, 0.09225713, -0.17395493, 0.10021958, -0.09210777,
-0.09865301, 0.2536609 , -0.02547608, 0.02885125, -0.01271547,
-0.10340843, -0.0338558 ]], dtype=float32)>
推荐阅读
- javascript - Javascript 上的数字根练习
- ios - 在 iOS 模拟器中运行 React Native 应用程序时找不到 UMModuleRegistryAdapter.h
- r - 通过天的分钟序列
- r - R绘图文本:如何向以指数形式呈现的多项式添加常规文本前缀(不是^)
- neo4j - 分叉模式的密码查询
- javascript - 在字节 00 处拆分十六进制字符串的优雅方法?
- elasticsearch - RabbitMQ Metricbeat 缺少队列
- c# - 使用 MSBUILD 复制文件
- c - 如何从 strtok_r() 中保存剩余的字符串?
- unit-testing - 如何在 bazel 测试规则中使用预编译的测试运行程序?