keras - Keras 层中的链接权重
问题描述
假设我将输入分成两个相同大小的部分 I1、I2,并且我希望在我的 keras 网络上具有以下结构——I1->A1,I2->A2,然后是 [A1,A2]->B,其中 B 是输出节点。我可以使用1中的组来执行此操作。但是,我想要求 I1->A1 的连接权重(和其他激活参数)与 I2->A2 的连接权重相同,即我希望 1 和 2 之间具有对称性。(请注意,我不需要 [A1,A2]->B 的对称性。)
解决方案
如果我正确理解了您的问题,则 input_1 到 A_1 和 input_2 到 A_2 的映射已经一个接一个地完成,因为您希望两个输入的映射函数相同。在这种情况下,您可能会考虑使用 RNN,但如果您的输入彼此独立,您可能会考虑TimeDistributed
在 Keras 中使用 wrapper。下面的示例将采用两个输入,并使用Dense
层将输入一一映射,因此Dense
共享权重:
from keras.models import Model
from keras.layers import Input, Dense, TimeDistributed, Concatenate, Lambda
x_dim = 5
hidden_dim = 8
x1 = Input(shape=(1,x_dim,))
x2 = Input(shape=(1,x_dim,))
concat = Concatenate(axis=1)([x1, x2])
hidden_concat = TimeDistributed(Dense(hidden_dim))(concat)
hidden1 = Lambda(lambda x: x[:,:1,:])(hidden_concat)
hidden2 = Lambda(lambda x: x[:,1:,:])(hidden_concat)
model = Model(inputs=[x1,x2], outputs=[hidden1, hidden2])
model.summary()
>>>
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_33 (InputLayer) (None, 1, 5) 0
__________________________________________________________________________________________________
input_34 (InputLayer) (None, 1, 5) 0
__________________________________________________________________________________________________
concatenate_17 (Concatenate) (None, 2, 5) 0 input_33[0][0]
input_34[0][0]
__________________________________________________________________________________________________
time_distributed_9 (TimeDistrib (None, 2, 8) 48 concatenate_17[0][0]
__________________________________________________________________________________________________
lambda_8 (Lambda) (None, 1, 8) 0 time_distributed_9[0][0]
__________________________________________________________________________________________________
lambda_9 (Lambda) (None, 1, 8) 0 time_distributed_9[0][0]
==================================================================================================
Total params: 48
Trainable params: 48
Non-trainable params: 0
推荐阅读
- javascript - 每次单击新链接时,href链接都会复制路径中的文件夹
- reactjs - 用样式组件反应 SSR
- flutter - Flutter FutureProvider 值未在 Builder 方法中更新
- javascript - Three.js 和 Django?
- unreal-engine4 - UE4 Open Level 节点在熟版本中不起作用
- python - 什么参数被传递给这个函数?
- reactjs - 使用 create-react-app 创建新的 react 应用程序时遇到问题
- javascript - 如果我们有任何方法可以从 html 运行带有来自外部 html 页面的 onclick 事件的函数?
- python - 如果满足条件,如何从字典中获取值?
- ios - Swift 5 我们如何解析嵌套在字典中的字典,嵌套在另一个字典中?