首页 > 解决方案 > 在 Keras 中使用带有生成器的多处理时,如何为每个 fork 建立单独的数据库连接?

问题描述

我将 Keras 与fit_generator(). 我的生成器连接到数据库(在我的例子中是 MongoDB)以获取每个批次的数据。如果我使用fit_generator()我得到这个警告的多处理标志:

UserWarning: MongoClient opened before fork. Create MongoClient only after forking.

我在以下期间连接到数据库__init__()

class MyCustomGenerator(tf.keras.utils.Sequence):
    def __init__(self, ...):
        collection = MagicMongoDBConnector()

    def __len__(self):
        ...

    def __getitem__(self, idx):
        # Using collection to fetch data from mongoDB
        ...

    def on_epoch_end(self):
        ...

我假设我需要为每个纪元建立一个单独的连接,但不幸的是,没有on_epoch_begin(self)可用的回调(如此处所示

所以有两个问题:
如果使用多处理,Keras 如何以及何时分叉生成器?如何摆脱 MongoClient 警告并在每个分叉内连接?

标签: pythontensorflowkerasmultiprocessing

解决方案


我没有要测试的 mongo 数据库,但这可能有效 - 您可以在每个进程的第一个 get-item 上获取集合(连接?)。

class MyCustomGenerator(tf.keras.utils.Sequence):
    def __init__(self, ...):
        self.collection = None

    def __len__(self):
        ...

    def __getitem__(self, idx):
        if self.collection is None:
            self.collection = MagicMongoDBConnector()
        # Continue with your code
        # Using collection to fetch data from mongoDB
        ...

    def on_epoch_end(self):
        ...

推荐阅读