python - 有没有办法将参数传递给 optuna 中的多个工作?
问题描述
我正在尝试使用 optuna 搜索超参数空间。
在一个特定的场景中,我在一台带有几个 GPU 的机器上训练一个模型。模型和批量大小允许我每 1 个 GPU 运行 1 次训练。因此,理想情况下,我想让 optuna 将所有试验分布在可用的 GPU 上,以便每个 GPU 上始终运行 1 个试验。
在它说的文档中,我应该在一个单独的终端中为每个 GPU 启动一个进程,例如:
CUDA_VISIBLE_DEVICES=0 optuna study optimize foo.py objective --study foo --storage sqlite:///example.db
我想避免这种情况,因为在那之后整个超参数搜索会继续进行多轮。我不想总是手动启动每个 GPU 的进程,检查所有进程何时完成,然后开始下一轮。
我看到study.optimize
有一个n_jobs
说法。乍一看,这似乎是完美的。
例如我可以这样做:
import optuna
def objective(trial):
# the actual model would be trained here
# the trainer here would need to know which GPU
# it should be using
best_val_loss = trainer(**trial.params)
return best_val_loss
study = optuna.create_study()
study.optimize(objective, n_trials=100, n_jobs=8)
这会启动多个线程,每个线程都开始训练。但是,内部的培训师objective
不知何故需要知道它应该使用哪个 GPU。有什么诀窍可以做到这一点吗?
解决方案
经过几次精神崩溃后,我发现我可以使用multiprocessing.Queue
. 要将其纳入目标函数,我需要将其定义为 lambda 函数或类(我猜部分也可以)。例如
from contextlib import contextmanager
import multiprocessing
N_GPUS = 2
class GpuQueue:
def __init__(self):
self.queue = multiprocessing.Manager().Queue()
all_idxs = list(range(N_GPUS)) if N_GPUS > 0 else [None]
for idx in all_idxs:
self.queue.put(idx)
@contextmanager
def one_gpu_per_process(self):
current_idx = self.queue.get()
yield current_idx
self.queue.put(current_idx)
class Objective:
def __init__(self, gpu_queue: GpuQueue):
self.gpu_queue = gpu_queue
def __call__(self, trial: Trial):
with self.gpu_queue.one_gpu_per_process() as gpu_i:
best_val_loss = trainer(**trial.params, gpu=gpu_i)
return best_val_loss
if __name__ == '__main__':
study = optuna.create_study()
study.optimize(Objective(GpuQueue()), n_trials=100, n_jobs=8)
推荐阅读
- javascript - 浏览器窗口不会用 window.location 重新加载?路由器问题?
- html - 谷歌 SEO 标题和元描述
- networking - 创建6in4服务器,但客户端无法正常上网
- php - 如何仅导出 Wordpress 帖子、评论、页面和媒体文件?
- javascript - ReactJS 中的多个上下文
- list - Haskell,可遍历从 Maybe [list] 获取值
- javascript - 如何在动态生成的 mat-table 中将字符串转换为类型日期
- reactjs - Typescript-React 状态:元素隐式具有“任何”类型,因为“状态”类型没有索引签名
- java - 尝试签署openjdk时出现“已签名”错误
- javascript - 对 ASP.NET MVC 控制器的 jQuery Ajax 请求成功后如何运行 JavaScript 函数?