python - 如何在 Python 中并行扫描多个超参数集?
问题描述
请注意,我必须扫描比可用 CPU 更多的参数集,因此我不确定 Python 是否会根据 CPU 的可用性或其他情况自动安排 CPU 的使用。
这是我尝试过的,但我收到有关参数的错误:
import random
import multiprocessing
from train_nodes import run
import itertools
envs = ["AntBulletEnv-v0", "HalfCheetahBulletEnv-vo", "HopperBulletEnv-v0", "ReacherBulletEnv-v0",
"Walker2DBulletEnv-v0", "InvertedDoublePendulumBulletEnv-v0"]
algs = ["PPO", "A2C"]
seeds = [random.randint(0, 200), random.randint(200, 400), random.randint(400, 600), random.randint(600, 800), random.randint(800, 1000)]
args = list(itertools.product(*[envs, algs, seeds]))
num_cpus = multiprocessing.cpu_count()
with multiprocessing.Pool(num_cpus) as processing_pool:
processing_pool.map(run, args)
run
接受 3 个参数:env、alg 和 seed。由于某种原因,它没有注册所有 3。
解决方案
中的函数multiprocessing.Pool.map
需要一个参数。调整代码的一种方法是编写一个小的包装函数,将env
、alg
和seed
作为一个参数,将它们分开,并将它们传递给run
.
另一种选择是使用multiprocessing.Pool.starmap
,它允许将多个参数传递给函数。
import random
import multiprocessing
import itertools
envs = [
"AntBulletEnv-v0",
"HalfCheetahBulletEnv-vo",
"HopperBulletEnv-v0",
"ReacherBulletEnv-v0",
"Walker2DBulletEnv-v0",
"InvertedDoublePendulumBulletEnv-v0",
]
algs = ["PPO", "A2C"]
seeds = [
random.randint(0, 200),
random.randint(200, 400),
random.randint(400, 600),
random.randint(600, 800),
random.randint(800, 1000),
]
args = list(itertools.product(*[envs, algs, seeds]))
num_cpus = multiprocessing.cpu_count()
# sample implementation or `run`
def run(env, alg, seed):
# do stuff
return random.randint(0, 200)
def wrapper(env_alg_seed):
env, alg, seed = env_alg_seed
return run(env=env, alg=alg, seed=seed)
# use a wrapper
with multiprocessing.Pool(num_cpus) as processing_pool:
# accumulate results in a dictionary
results = processing_pool.map(wrapper, args)
# use starmap and call `run` directly
with multiprocessing.Pool(num_cpus) as processing_pool:
results = processing_pool.starmap(run, args)
推荐阅读
- git - 如何将非常重(123M)的分支推送到gitlab?
- twilio - Twilio 向 Whatsapp 发送消息
- python - Python:如何使用 NumPy ndarray 创建字符串数组
- javascript - ChartJS 图表不显示 - 折线图
- node.js - 如何在 Loopback 3.0 中增加登录访问令牌 TTL?
- css - 如何在角度 6/7/8 中禁用步进器的涟漪效应,它们是删除标签和按钮等涟漪的选项
- r - 如何用四条线绘制图形?
- swift - 聆听环境变化
- bluetooth - 氟化物蓝牙堆栈的文档
- python - GridSearch 中的最佳参数