python - 火炬:具有显式多处理的自定义数据加载器
问题描述
我有一个自定义数据加载器。我已经明确使用 pythonmultiprocessing
在我的自定义数据加载器中并行化数据预处理。我在我的数据加载器中使用了 8 个工人( num_threads
) 。multiprocessing
我想知道,这将如何影响我的torch.utils.data.DataLoader
通话?参数是否会num_workers
设置为 8?或者我可以将其保留为 0 吗?我的自定义加载器看起来像这样
def processParams(params):
<some operations on params>
return params
def processParamsParallel(params, pool):
results = pool.map(processParams, params)
return results
class DataLoader(object):
def __init__(self, params, maxId):
self.params = params
self.id = 0
self.maxId = maxId
self.pool = Pool(processes=8)
def __iter__(self):
while self.id<=self.maxId:
if self.id==self.maxId:
self.id = 0
results = processParamsParallel(self.params, self.pool)
self.id+=1
yield results
这是我正在尝试做的一个非常粗略的例子。现在在火炬通话中
dl = DataLoader(params, 50)
dl_torch = torch.utils.data.Dataloader(dl, num_workers = <what_here?>, prefetch_factor = <what_here?>)
在类似的说明中,如果不是由 torch 调用而是由自定义 dataLoader 本身设置,将如何prefetch_factor
影响?num_workers
来自 torch.utils.data
prefetch_factor - 每个工作人员预先加载的样本数。2
意味着将在所有工作人员中预取总共 2 * num_workers 个样本。(默认2
:)
先感谢您!