python - 如何使用功能 API 在 keras 中实现合并功能
问题描述
print("Building model...")
ques1_enc = Sequential()
ques1_enc.add(Embedding(output_dim=64, input_dim=vocab_size, weights=[embedding_weights], mask_zero=True))
ques1_enc.add(LSTM(100, input_shape=(64, seq_maxlen), return_sequences=False))
ques1_enc.add(Dropout(0.3))
ques2_enc = Sequential()
ques2_enc.add(Embedding(output_dim=64, input_dim=vocab_size, weights=[embedding_weights], mask_zero=True))
ques2_enc.add(LSTM(100, input_shape=(64, seq_maxlen), return_sequences=False))
ques2_enc.add(Dropout(0.3))
model = Sequential()
model.add(Merge([ques1_enc, ques2_enc], mode="sum"))
model.add(Dense(2, activation="softmax"))
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
print("Building model costs:", time.time() - start)
print("Training...")
checkpoint = ModelCheckpoint(filepath=os.path.join("C:/Users/", "quora_dul_best_lstm.hdf5"), verbose=1, save_best_only=True)
model.fit([x_ques1train, x_ques2train], ytrain, batch_size=32, epochs=1, validation_split=0.1, verbose=2, callbacks=[checkpoint])
print("Training neural network costs:", time.time() - start)
我想将上述代码转换为 keras 中的功能 API,因为不支持顺序 API Merge() 函数。我已经尝试了很长时间,但几乎没有错误。关于属性的详细信息:ques_pairs 包含预处理数据,word2index 包含字数,seq_maxlen 包含问题一或二的最大长度。我试图在 Quora Question Pair Dataset https://www.kaggle.com/c/quora-question-pairs上实现这个模型
解决方案
我会给你一个小例子,你可以将它应用到你自己的模型中:
from keras.layers import Input, Dense, Add
input1 = Input(shape=(16,))
output1 = Dense(8, activation='relu')(input1)
output1 = Dense(4, activation='relu')(output1) # Add as many layers as you like like this
input2 = Input(shape=(16,))
output2 = Dense(8, activation='relu')(input2)
output2 = Dense(4, activation='relu')(output2) # Add as many layers as you like like this
output_full = Add()([output1, output2])
output_full = Dense(1, activation='sigmoid')(output_full) # Add as many layers as you like like this
model_full = Model(inputs=[input1, input2], outputs=output_full)
您需要首先Input
为每个模型部件定义一个,然后向两个模型添加层(如代码所示)。然后您可以使用Add
图层添加它们。最后,您调用Model
输入层和输出层的列表。
model_full
然后可以像任何其他模型一样编译和训练。
推荐阅读
- excel - Excel(如果声明,也许?)生日列表或 Vlookup
- openshift - 使用 OpenShift Origin 3.11 的图像推/拉非常慢
- postcss - 如何使用汇总处理来自特定 .scss 文件(而不是 entry.js)的 css/scss
- javascript - Javascript:在新数组中映射两个数组的值
- apache-spark - 如何解决这个错误 org.apache.spark.sql.catalyst.errors.package$TreeNodeException
- java - 骆驼卡夫卡 SSL
- c# - wpf DataGridTemplateColumn.CellTemplate 如何启用/禁用复选框
- awk - 根据不同的分隔符拆分列并应用条件
- ssl - Comodo SSL - 本地网络
- python - 什么是更 Pythonic 的方式来测试我的部分代码?