首页 > 解决方案 > Keras:在批处理生成器中使用多核

问题描述

我正在使用以下代码来训练模型keras==2.1.6

history = model.fit_generator(
    generator=generator.batch_generator(is_train=True),
    epochs=config.N_EPOCHS,
    steps_per_epoch=100,
    validation_data=generator.batch_generator(is_train=False),
    validation_steps=10,
    verbose=1,
    shuffle=False,
    workers=args.n_thread,
    max_queue_size=args.n_thread*2,
    use_multiprocessing=True,
    callbacks=callbacks)

当我没有指定时,无论指定什么数字,use_multiprocessing=True我都有大约 200% 的 CPU 负载,所以在这种情况下参数是无用的?htopargs.n_threadworkers

默认情况下它使用基于线程的处理?但是为什么线程数应该和workers参数有关,却没有增加CPU负载呢?它与GIL有关吗?

    use_multiprocessing: Boolean.
        If `True`, use process-based threading.
        If unspecified, `use_multiprocessing` will default to `False`.
        Note that because this implementation relies on multiprocessing,
        you should not pass non-picklable arguments to the generator
        as they can't be passed easily to children processes.

此外,当使用时use_multiprocessing=True,例如args.n_thread=6htop我看到 12 CPU 系统的平均负载超过 6,是否因为数据集预处理有一些并行代码,即一些 numpy 操作,是否可以精确使用 6 CPU(600% CPU 负载)?

更新: 查看 keras 代码: https ://github.com/keras-team/keras/blob/30fe4ff1f12ff0c45bac8738b4d2690eadd056b2/keras/utils/data_utils.py#L409 它使用multiprocessing.pool.ThreadPoolvsmultiprocessing.Pool在两种情况下它都使用workers参数,那么差异在哪里来自?

标签: pythonlinuxkeraspython-multiprocessingpython-multithreading

解决方案


推荐阅读