python - 批量输入 Keras 共享参数
问题描述
我正在构建网络来对一组 N 个输入进行排名。理想情况下,它们应该全部同时输入并共享参数。他们的目标向量应该是一个 N-hot 向量来匹配输入。
这意味着我的输入应该是 (Batch_size, N, sequence_length, feature_length)
但是对于任何大于 3 维的输入,keras 都会抛出错误,如下所示:
ValueError:输入 0 与层 lstm_2 不兼容:预期 ndim=3,发现 ndim=4
我目前的 keras 设置是:
x = Input(shape=(72,300))
aux_input = Input(shape=(72, 4))
probs = Input(shape=(1,))
#dim_red_1 = Dense(100)(x)
dim_red_2 = Dense(20, activation='tanh')(x)
cat = concatenate([dim_red_2, aux_input])
encoded = LSTM(64)(cat)
cat2 = concatenate([encoded, probs])
output = Dense(1, activation='sigmoid')(cat2)
lstm_model = Model(inputs=[x, aux_input, probs], outputs=output)
lstm_model.compile(optimizer='ADAM', loss='binary_crossentropy', metrics=['accuracy'])
有没有办法通过 Keras 实现这一目标?
解决方案
尽管您的代码看起来不错,但请确保导入正确的包:
import numpy as np
from tensorflow.python.keras import Input
from tensorflow.python.keras.engine.training import Model
from tensorflow.python.keras.layers import Dense, LSTM, Concatenate
a = np.zeros(shape=[1000, 72, 300])
b = np.zeros(shape=[1000, 72, 4])
c = np.zeros(shape=[1000, 1])
d = np.zeros(shape=[1000, 1])
x = Input(shape=(72, 300))
aux_input = Input(shape=(72, 4))
probs = Input(shape=(1,))
dim_red_2 = Dense(20, activation='tanh')(x)
cat = Concatenate()([dim_red_2, aux_input])
encoded = LSTM(64)(cat)
cat2 = Concatenate()([encoded, probs])
output = Dense(1, activation='sigmoid')(cat2)
lstm_model = Model(inputs=[x, aux_input, probs], outputs=output)
lstm_model.compile(optimizer='ADAM', loss='binary_crossentropy', metrics=['accuracy'])
lstm_model.summary()
lstm_model.fit([a, b, c], d, batch_size=256)
输出:
256/1000 [======>.......................] - ETA: 2s - loss: 0.6931 - acc: 1.0000
512/1000 [==============>...............] - ETA: 1s - loss: 0.6910 - acc: 1.0000
768/1000 [======================>.......] - ETA: 0s - loss: 0.6885 - acc: 1.0000
1000/1000 [==============================] - 1s 1ms/step - loss: 0.6859 - acc: 1.00
推荐阅读
- android - Kotlin如何实现android setOnClickListener之类的语法
- hbase - clone_snapshot 和 copyTable 有什么区别?
- python - 我可以从 Power BI 中的个人 Python 脚本调用函数吗
- jenkins - Jenkins 在 Mac 系统上构建期间未能生成新进程
- css - 当其他列只有一个项目时,如何对一列中的项目数组使用网格布局?
- ssl - Terraform Init/apply/destroy - SSL 连接问题
- php - 如何在后端的laravel中从数据库中获取数据的分页
- r - spatstat 中 Cox 过程模型中簇大小的含义
- reactjs - react中如何根据props有条件地渲染数据属性
- sql-server - SQL 触发器 - 分配金额