keras - Keras:当我使用 fit_generator 时,损失不会改变。但合身效果很好
问题描述
我有一个要在 Keras 上训练的大型数据集,为了避免“内存错误”,我尝试使用 fit_generator 函数。奇怪的是,当我使用 fit_generator 时,损失似乎没有改变,但 fit 函数运行良好。
无论 fir_gen 或 fit 函数是什么,数据集和其他代码都是相同的。
这是一个 lstm - seq2seq 模型。
我搜索了很长时间,发现了另外两个和我一样的问题。1. Keras: network doesn't train with fit_generator() 根据这篇文章,我改变了我的batch_size,但它不起作用。当我尝试将“产量”更改为返回时,它会给我一个错误。2 、Keras没有使用fit_generator()进行训练 这篇文章其实没有答案。
model.fit_generator(generate_train(batch_size=200),
steps_per_epoch=5,
epochs=100,
verbose=1,
callbacks=callbacks_list,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=False,
initial_epoch=initial_epoch
)
def generate_train(batch_size):
steps=0
context_ = np.load(main_path + 'middle_data/context_indexes.npy')
final_target_ = np.load(main_path + 'middle_data/target_indexes.npy')
context_ = context_[:1000]
final_target_ = final_target_[:1000]
while True:
context = context_[steps:steps+batch_size]
final_target = final_target_[steps:steps+batch_size]
processing. . .
outs = . . .
yield [context, final_target], outs
steps += batch_size
if steps == 1000:
steps = 0
当我使用 fit() 时:
Epoch 1/30 loss: 2.5948 - acc: 0.0583
Epoch 2/30 loss: 2.0840 - acc: 0.0836
Epoch 3/30 loss: 1.9226 - acc: 0.0998
Epoch 4/30 loss: 1.8286 - acc: 0.1086
Epoch 5/30 loss: 1.7399 - acc: 0.1139
Epoch 6/30 loss: 1.6509 - acc: 0.1192
Epoch 7/30 loss: 1.5518 - acc: 0.1247
Epoch 8/30 loss: 1.4330 - acc: 0.1316
Epoch 9/30 loss: 1.3117 - acc: 0.1454
Epoch 10/30 loss: 1.1872 - acc: 0.1657
Epoch 11/30 loss: 1.0720 - acc: 0.1893
Epoch 12/30 loss: 0.9589 - acc: 0.2169
. . .
当我使用 fit_generator() 时:
Epoch 1/100 loss: 3.4926 - acc: 0.0370
Epoch 2/100 loss: 2.7239 - acc: 0.0388
Epoch 3/100 loss: 2.6030 - acc: 0.0389
Epoch 4/100 loss: 2.5727 - acc: 0.0408
Epoch 5/100 loss: 2.5628 - acc: 0.0366
Epoch 6/100 loss: 2.5513 - acc: 0.0420
Epoch 7/100 loss: 2.5475 - acc: 0.0387
Epoch 8/100 loss: 2.5508 - acc: 0.0407
Epoch 9/100 loss: 2.5490 - acc: 0.0418
Epoch 10/100 loss: 2.5419 - acc: 0.0401
解决方案
我已经找到了一种解决方案。
我通过'yield'打印来自生成器的所有数据,我发现我错误地重写了每个时期的训练数据。
所以原来的数据在变化,导致了稳定的损失。
推荐阅读
- shortest-path - 初等最短路径问题与最短路径问题
- c# - 向 SOAP 消息添加安全标头
- graphql - Github v4 GraphQL API - 访问 Marketplace 应用结果 (Travis CI)
- reactjs - 嵌套路由链接到两个不同的组件
- azure - 如何使用 GUI 工具浏览 Azure Data Lake gen 2
- javascript - 如何使用 vue.js 嵌套 for 循环遍历两个带有 v-for 的数组
- r - 从向量中递归搜索文件名R
- php - 如何在此代码中格式化“其他”?
- python - Python pandas 转置汇总内容
- date - 如何在 Google Apps 脚本中添加一天事件的开始日期