python - 如何使生成器线程安全?
问题描述
我有一个看起来像这样的生成器:
def data_generator(data_file, index_list,....):
orig_index_list = index_list
while True:
x_list = list()
y_list = list()
if patch_shape:
index_list = create_patch_index_list(orig_index_list, data_file, patch_shape,
patch_overlap, patch_start_offset,pred_specific=pred_specific)
else:
index_list = copy.copy(orig_index_list)
while len(index_list) > 0:
index = index_list.pop()
add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
skip_blank=skip_blank, permute=permute)
if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels, num_model=num_model,overlap_label=overlap_label)
x_list = list()
y_list = list()
我的数据集大小为 55GB,并存储为 .h5 文件 (data.h5)。读取数据时非常慢。一个 epoch 需要 7000 秒,并且在 6 个 epoch 之后出现分段错误。
我想如果我设置multi_processing = False
它workers > 1
会加快读取数据的速度:
model.fit(multi_processing = False, workers = 8)
但是当我这样做时,我收到以下错误:
RuntimeError: Your generator is NOT thread-safe. Keras requires a thread-safe generator when use_multiprocessing=False, workers > 1.
有没有办法让我的生成器线程安全?或者有没有其他有效的方法来生成这些数据?
解决方案
我相信LockedIterator
我在上面的评论中引用的类是不正确的,应该像下面的例子那样编码:
import threading
class LockedIterator(object):
def __init__(self, it):
self.lock = threading.Lock()
self.it = iter(it)
def __iter__(self): return self
def __next__(self):
with self.lock:
return self.it.__next__()
def gen():
for x in range(10):
yield x
new_gen = LockedIterator(gen())
def worker(g):
for x in g:
print(x, flush=True)
t1 = threading.Thread(target=worker, args=(new_gen,))
t2 = threading.Thread(target=worker, args=(new_gen,))
t1.start()
t2.start()
t1.join()
t2.join()
印刷:
0
1
23
4
5
6
7
8
9
如果您想保证打印输出每行打印一个值,那么我们还需要将一个threading.Lock
实例传递给每个线程并print
在该锁的控制下发出语句,以便序列化打印。
推荐阅读
- apache-spark - 如何在使用 Spark 进行 sqooping 时处理记录中的额外 '\n'?
- firebase - 如何比较 Firestore 查询中的多个字段?
- pine-script - Pine 脚本中的累积线
- azure-ad-b2c - Azure AD B2C - 运行“自助密码重置”后,自定义策略“会话”处于不正确状态
- google-tag-manager - 针对不同产品类别的 Google Ads 转化跟踪,如何?
- android - api 30 android 无法访问电子邮件和拨号器应用程序
- python - AttributeError:模块'flask'没有属性'route'
- javascript - 用于检查是否将新子节点添加到目标节点的事件侦听器
- java - 递归插入二叉搜索树JAVA
- python - 如何用两个类别制作条形图?