python-3.x - Keras functional API is slower than Sequential / does not improve
Problem description
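Solved! (The Embedding layer in the functional model was created with trainable=False; it has to be trainable=True, as it is in the Sequential model.)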
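I am currently converting my Keras model from Sequential to the functional API. The Sequential model reaches an accuracy of 1 after 10 epochs, but the functional API model does not even reach 0.7 and stops improving. Apart from the input layer, the two networks should be identical.
Sequential: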
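# Imports are assumed; the original post did not show them.
from keras.models import Sequential
from keras.layers import Embedding, SpatialDropout1D, LSTM, Dense
from keras.optimizers import Adam

model = Sequential()
model.add(Embedding(20000, 256, input_length=30))
model.add(SpatialDropout1D(0.4))
model.add(LSTM(256, dropout=0.3, recurrent_dropout=0.3))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0001), metrics=['accuracy'])
print(model.summary())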
The output is:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding_6 (Embedding)      (None, 30, 256)           5120000
_________________________________________________________________
spatial_dropout1d_5 (Spatial (None, 30, 256)           0
_________________________________________________________________
lstm_5 (LSTM)                (None, 256)               525312
_________________________________________________________________
dense_6 (Dense)              (None, 1)                 257
=================================================================
Total params: 5,645,569
Trainable params: 5,645,569
Non-trainable params: 0
_________________________________________________________________
None
For the functional API:
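# Imports are assumed; the original post did not show them.
from keras.models import Model
from keras.layers import Input, Embedding, SpatialDropout1D, LSTM, Dense
from keras.optimizers import Adam

inputs = Input(shape=(31,))
embed = Embedding(20000, 256, trainable=False)(inputs)
drop = SpatialDropout1D(0.4)(embed)
lstm = LSTM(256, dropout=0.3, recurrent_dropout=0.3)(drop)
acti = Dense(1, activation='sigmoid')(lstm)
model = Model(inputs=inputs, outputs=acti)
model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0001), metrics=['accuracy'])
print(model.summary())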
which results in:
Model: "model_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_8 (InputLayer)         (None, 31)                0
_________________________________________________________________
embedding_7 (Embedding)      (None, 31, 256)           5120000
_________________________________________________________________
spatial_dropout1d_6 (Spatial (None, 31, 256)           0
_________________________________________________________________
lstm_6 (LSTM)                (None, 256)               525312
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 257
=================================================================
Total params: 5,645,569
Trainable params: 525,569
Non-trainable params: 5,120,000
_________________________________________________________________
None
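For reference, the parameter counts in both summaries follow directly from the layer sizes. The sketch below reproduces them in plain Python, without Keras. The helper names are made up for illustration, and the LSTM formula assumes the standard layout of four gates, each with an input kernel, a recurrent kernel and a bias.

```python
def embedding_params(vocab_size, dim):
    # One dim-sized vector per vocabulary entry.
    return vocab_size * dim


def lstm_params(input_dim, units):
    # Four gates, each with an input kernel, a recurrent kernel and a bias.
    return 4 * (input_dim * units + units * units + units)


def dense_params(input_dim, units):
    # Weight matrix plus one bias per output unit.
    return input_dim * units + units


emb = embedding_params(20000, 256)
lstm = lstm_params(256, 256)
dense = dense_params(256, 1)

total = emb + lstm + dense
# With the embedding excluded from training, only the LSTM and Dense weights remain.
trainable_without_embedding = total - emb

print(emb, lstm, dense, total, trainable_without_embedding)
```

The totals are the same for both models because the sequence length (30 versus 31) does not affect any weight shape. Only the trainable/non-trainable split differs.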
Did I overlook something, or can someone explain my results?
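Solution
The two models are not identical. The functional version creates its Embedding layer with trainable=False, so its 5,120,000 weights stay frozen at their initial values and only the LSTM and Dense layers (525,569 parameters) are trained. The Sequential model trains all 5,645,569 parameters. Setting trainable=True on the Embedding layer, or dropping the argument since True is the default, makes the functional model match.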
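A minimal sketch of the fix: build the functional model with a trainable Embedding layer instead of the question's trainable=False. It assumes TensorFlow's bundled Keras (`tensorflow.keras`) and uses the newer `learning_rate` argument name in place of the question's `lr`. `build_functional_model` and `count_trainable` are hypothetical helper names, not part of Keras.

```python
import math

from tensorflow.keras.layers import Dense, Embedding, Input, LSTM, SpatialDropout1D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


def build_functional_model(train_embedding=True):
    """Functional-API version of the model from the question."""
    inputs = Input(shape=(31,))
    # trainable=True lets the optimizer update the 20000 x 256 embedding matrix.
    embed = Embedding(20000, 256, trainable=train_embedding)(inputs)
    drop = SpatialDropout1D(0.4)(embed)
    lstm = LSTM(256, dropout=0.3, recurrent_dropout=0.3)(drop)
    acti = Dense(1, activation="sigmoid")(lstm)
    model = Model(inputs=inputs, outputs=acti)
    model.compile(loss="binary_crossentropy",
                  optimizer=Adam(learning_rate=0.0001),
                  metrics=["accuracy"])
    return model


def count_trainable(model):
    """Sum the number of scalars over all trainable weight tensors."""
    return sum(math.prod(int(d) for d in w.shape) for w in model.trainable_weights)


frozen = count_trainable(build_functional_model(train_embedding=False))
unfrozen = count_trainable(build_functional_model(train_embedding=True))
print(frozen, unfrozen)
```

The two printed counts should correspond to the "Trainable params" lines of the functional and Sequential summaries in the question. Whether accuracy then matches the Sequential run still has to be confirmed by training.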