python - Keras 的并行多图像生成器
问题描述
我有一个 Keras 模型,需要使用多个ImageGenerator
来提供来自多个源的数据(但该模型仍然只有 1 个输入)。
我创建了一个可以做到这一点的函数(实际上我使用了 5-6 个生成器)
def multiple_generator(batch_size):
genX1 = train_datagen.flow_from_directory('./directory1',
target_size=(img_height, img_width),
batch_size=batch_size//2,
class_mode='categorical')
genX2 = train_datagen.flow_from_directory('./directory2',
target_size=(img_height, img_width),
batch_size=batch_size//2,
class_mode='categorical')
while True:
X1i = genX1.next()
X2i = genX2.next()
yield np.concatenate([X1i[0],X2i[0]],axis = 0),\
np.concatenate([X1i[1],X2i[1]],,axis = 0)
但是当开始训练时,训练时间比使用单个生成器要长得多。例如,在单个生成器中,每个 epoch 只需要 120s batch_size
,但使用 时multiple_generator
,batch_size = 64
需要 5 分钟,128 每个 epoch 需要 12 分钟。
我认为迭代多个生成器的任务可能会减慢训练时间,我认为并行函数是这样的:
def multiple_generator(batch_size):
pool = Pool(processes=2)
genX1 = pool.apply(train_datagen.flow_from_directory('./directory1',
target_size=(img_height, img_width),
batch_size=batch_size//2,
class_mode='categorical'))
genX2 = pool.apply(train_datagen.flow_from_directory('./directory2',
target_size=(img_height, img_width),
batch_size=batch_size//2,
class_mode='categorical'))
while True:
X1i = genX1.next()
X2i = genX2.next()
yield np.concatenate([X1i[0],X2i[0]],axis = 0),\
np.concatenate([X1i[1],X2i[1]],,axis = 0)
但它返回错误
Traceback (most recent call last):
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2961, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-26-2d1d4ddfacf1>", line 4, in <module>
class_mode='categorical'))
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/pool.py", line 259, in apply
return self.apply_async(func, args, kwds).get()
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/pool.py", line 644, in get
raise self._value
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/pool.py", line 424, in _handle_tasks
put(task)
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/connection.py", line 206, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
TypeError: can't pickle _thread.lock objects
我在处理多进程管理方面没有太多经验,您对此有什么解决方案吗?其他加速发电机的策略总是受欢迎的。非常感谢
解决方案
推荐阅读
- sql - JOIN ON 子句中的 Validate_conversion 函数 - ORA-00932:不一致的数据类型:预期 NUMBER 得到 CHAR
- javascript - Google Charts:在循环中设置多个侦听器会覆盖以前的侦听器
- symfony - 自定义规范器的 API 平台文档
- hibernate - 使用 Spring Boot JPA 检查 MySQL DB 连接 - 如何设置 Hikari 的超时
- date - SAP 网关服务器中的 DatePicker 值不匹配
- javascript - 使用JS在websql中插入表单数据创建动态表单
- keras - 如何引用多输出模型的一个输出
- pepper - Pepper 机器人参与和动画
- html - 来自检查器 HTML 的不同 HTTP 响应
- flutter - 下一页流未在 Flutter 上更新