tensorflow - 自定义 keras 层未显示在 model.summary 中
问题描述
当我尝试.summary
使用自定义图层查看模型时,我得到以下输出:
Model: "functional_29"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_118 (InputNode) [(None, 1)] 0
__________________________________________________________________________________________________
input_119 (InputNode) [(None, 1)] 0
__________________________________________________________________________________________________
tf_op_layer_strided_slice_156 ( [(1,)] 0 input_118[0][0]
__________________________________________________________________________________________________
tf_op_layer_strided_slice_157 ( [(1,)] 0 input_119[0][0]
__________________________________________________________________________________________________
input_120 (InputNode) [(None, 1)] 0
_________________________________________________________________________________________________
tf_op_layer_concat_106 (TensorF [(2,)] 0 tf_op_layer_strided_slice_162[0][
tf_op_layer_strided_slice_163[0][
...
__________________________________________________________________________________________________
tf_op_layer_strided_slice_164 ( [(1,)] 0 input_120[0][0]
__________________________________________________________________________________________________
tf_op_layer_node_128_output (Te [()] 0 tf_op_layer_Relu_55[0][0]
==================================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
__________________________________________________________________________________________________
这是为什么?如何将所有这些操作包装在标签下MyLayer
?
解决方案
您可以通过从tf.keras.layers.Layer
.
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Input
class DirectLayer(tf.keras.layers.Layer):
def __init__(self, name = "direct_layer", **kwargs):
super(DirectLayer, self).__init__(name=name, **kwargs)
def build(self, input_shape):
self.w = self.add_weight(
shape=input_shape[1:],
initializer="random_normal",
trainable=True,
)
self.b = self.add_weight(
shape=input_shape[1:], initializer="random_normal", trainable=True
)
def call(self, inputs):
return tf.multiply(inputs, self.w) + self.b
x_in = Input(shape=[10])
x = DirectLayer(name="my_layer")(x_in)
x_out = DirectLayer()(x)
model = Model(x_in, x_out)
x = tf.ones([16,10])
tf.print(model(x))
tf.print(model.summary())
我创建了一个名为 DirectLayer 的简单层。我构建了一个使用该层两次的模型。输入层只是为了指定输入数据的形状。
如您所见,您可以轻松指定图层的名称。
汇总函数产生以下结果:
Model: "functional_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 10)] 0
_________________________________________________________________
my_layer (DirectLayer) (None, 10) 20
_________________________________________________________________
direct_layer (DirectLayer) (None, 10) 20
=================================================================
Total params: 40
Trainable params: 40
Non-trainable params: 0
_________________________________________________________________
None
推荐阅读
- azure-logic-apps - 将库导入到逻辑应用 Javascript 代码步骤
- python-3.x - 从字典列表中提取元组列表,一些值用逗号和单引号分隔,一些没有
- java - 我应该在哪里初始化我的开关小部件 Android
- azure - 如何将文件从 azure 存储复制到 vm
- python - 将 Pandas 数据框转换为所需的 python 字典
- ssh - 如何在ansible中检测无法访问的目标主机
- c# - 为什么我无法访问 dbset 中的更新方法?
- c# - C# - Moq - System.Text.Json 自定义 JsonConverter - 如何模拟调用接受 Ref Struct 参数的方法?
- azure-devops - 跳过插件下载 Terraform
- html - 可访问性:在嵌套列表 (ul) 中使用 ARIA role="group"