How do I sweep multiple hyperparameter sets in parallel in Python?

Problem description

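Note that I have to sweep more parameter sets than there are available CPUs, so I'm not sure whether Python will automatically schedule the work based on how many CPUs are available.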

This is what I tried, but I get an error about the arguments:

import random
import multiprocessing
from train_nodes import run
import itertools

envs = ["AntBulletEnv-v0", "HalfCheetahBulletEnv-v0", "HopperBulletEnv-v0", "ReacherBulletEnv-v0",
        "Walker2DBulletEnv-v0", "InvertedDoublePendulumBulletEnv-v0"]
algs = ["PPO", "A2C"]
seeds = [random.randint(0, 200), random.randint(200, 400), random.randint(400, 600), random.randint(600, 800), random.randint(800, 1000)]

args = list(itertools.product(*[envs, algs, seeds]))

num_cpus = multiprocessing.cpu_count()

with multiprocessing.Pool(num_cpus) as processing_pool:
    processing_pool.map(run, args)

run takes three arguments: env, alg and seed. For some reason, it doesn't receive all three.

Tags: python, python-3.x, multiprocessing, python-multiprocessing

Solution


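The function passed to multiprocessing.Pool.map must take a single argument. map sends each element of args, here an (env, alg, seed) tuple, to the function as one positional argument, so run is called as run((env, alg, seed)) and raises a TypeError about the missing alg and seed arguments. One way to adapt the code is to write a small wrapper function that takes the (env, alg, seed) tuple as its single argument, unpacks it, and passes the values to run.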
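The failure and the wrapper fix can be reproduced without any worker processes, because the built-in map uses the same calling convention as Pool.map for a single iterable. In this sketch run is a hypothetical stand-in for train_nodes.run with the same three parameters.

```python
import itertools


def run(env, alg, seed):
    # Hypothetical stand-in for train_nodes.run with the same three parameters.
    return f"{alg} on {env}, seed {seed}"


def wrapper(env_alg_seed):
    # Receives the whole (env, alg, seed) tuple as one argument and unpacks it.
    env, alg, seed = env_alg_seed
    return run(env, alg, seed)


args = list(itertools.product(["HopperBulletEnv-v0"], ["PPO", "A2C"], [7]))

# Built-in map passes each tuple as a single positional argument, just like
# Pool.map, so this ends up calling run(("HopperBulletEnv-v0", "PPO", 7)).
error_message = None
try:
    list(map(run, args))
except TypeError as exc:
    error_message = str(exc)
print("map(run, args) failed:", error_message)

# Routing each tuple through the wrapper gives run its three arguments.
print(list(map(wrapper, args)))
```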

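Another option is multiprocessing.Pool.starmap, which unpacks each tuple and passes its elements to the function as separate arguments.

In both cases the pool does the scheduling for you: Pool(num_cpus) starts num_cpus worker processes, queues all 60 parameter sets (6 environments × 2 algorithms × 5 seeds), and hands out the next chunk of work to each worker as soon as it becomes free. Having more parameter sets than CPUs is therefore not a problem.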
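To check how a pool behaves when there are more parameter sets than workers, the sketch below runs 12 tasks on 2 workers with starmap. It uses multiprocessing.pool.ThreadPool, a thread-backed subclass of Pool with the same starmap interface and task queue, only so that it runs without a __main__ guard; for CPU-bound training, multiprocessing.Pool is the one to use. run is a hypothetical stand-in that reports which worker executed it.

```python
import itertools
import threading
from multiprocessing.pool import ThreadPool


def run(env, alg, seed):
    # Hypothetical stand-in for train_nodes.run: records which worker ran it.
    return env, alg, seed, threading.current_thread().name


# 3 envs x 2 algs x 2 seeds = 12 parameter sets, but only 2 workers.
args = list(itertools.product(
    ["AntBulletEnv-v0", "HopperBulletEnv-v0", "Walker2DBulletEnv-v0"],
    ["PPO", "A2C"],
    [1, 2],
))

# Tasks that don't fit on a worker yet wait in the pool's queue and are
# dispatched (in chunks) as workers become free.
with ThreadPool(processes=2) as pool:
    results = pool.starmap(run, args)

# starmap returns results in the same order as args, whichever worker ran them.
for env, alg, seed, worker in results:
    print(f"{env} {alg} seed={seed} ran on {worker}")
```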

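import itertools
import multiprocessing
import random

envs = [
    "AntBulletEnv-v0",
    "HalfCheetahBulletEnv-v0",
    "HopperBulletEnv-v0",
    "ReacherBulletEnv-v0",
    "Walker2DBulletEnv-v0",
    "InvertedDoublePendulumBulletEnv-v0",
]
algs = ["PPO", "A2C"]
seeds = [
    random.randint(0, 200),
    random.randint(200, 400),
    random.randint(400, 600),
    random.randint(600, 800),
    random.randint(800, 1000),
]

# All 6 * 2 * 5 = 60 (env, alg, seed) combinations
args = list(itertools.product(envs, algs, seeds))

num_cpus = multiprocessing.cpu_count()


# Sample implementation of `run`; replace it with the real one from train_nodes.
# It must be defined at module level so the pool can pickle it.
def run(env, alg, seed):
    # do stuff
    return random.randint(0, 200)


# Takes one (env, alg, seed) tuple, unpacks it and forwards it to `run`
def wrapper(env_alg_seed):
    env, alg, seed = env_alg_seed
    return run(env=env, alg=alg, seed=seed)


# The guard stops worker processes from re-running the pool code when they
# import this module (required with the "spawn" start method on Windows/macOS)
if __name__ == "__main__":
    # Option 1: Pool.map with the single-argument wrapper
    with multiprocessing.Pool(num_cpus) as processing_pool:
        # Returns a list of results in the same order as `args`
        results = processing_pool.map(wrapper, args)

    # Option 2: Pool.starmap unpacks each tuple and calls `run` directly
    with multiprocessing.Pool(num_cpus) as processing_pool:
        results = processing_pool.starmap(run, args)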
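If each result should be recorded against its parameter set as soon as that run finishes (for example to log progress during a long sweep), Pool.imap_unordered can be combined with a single-argument wrapper that returns the parameters along with the result. A minimal sketch, where run is a hypothetical placeholder and run_tagged and sweep are illustrative names, not part of any library:

```python
import multiprocessing


def run(env, alg, seed):
    # Hypothetical placeholder for train_nodes.run; replace with real training.
    return len(env) + seed


def run_tagged(params):
    # imap_unordered passes one argument, so unpack the (env, alg, seed) tuple
    # here and return it with the result so out-of-order results stay labelled.
    env, alg, seed = params
    return params, run(env, alg, seed)


def sweep(args, processes=None):
    # args: a list of (env, alg, seed) tuples.
    # processes=None lets Pool start os.cpu_count() workers.
    results = {}
    with multiprocessing.Pool(processes) as pool:
        # Results arrive in completion order, not in the order of args.
        for params, result in pool.imap_unordered(run_tagged, args):
            results[params] = result
            print(f"finished {len(results)}/{len(args)}: {params}")
    return results
```

Call sweep(args) from under an if __name__ == "__main__": guard, which is required with the spawn start method (the default on Windows and macOS). The returned dictionary maps each (env, alg, seed) tuple to its result.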
