首页 > 解决方案 > 从 ray.tune 中提取代理

问题描述

我一直在使用 azure 机器学习来训练使用 ray.tune 的强化学习代理。

我的训练功能如下:

    tune.run(
        run_or_experiment="PPO",
        config={
            "env": "Battery",
            "num_gpus" : 1,
            "num_workers": 13,
            "num_cpus_per_worker": 1,
            "train_batch_size": 1024,
            "num_sgd_iter": 20,
            'explore': True,
            'exploration_config': {'type': 'StochasticSampling'},
        },
        stop={'episode_reward_mean': 0.15},
        checkpoint_freq = 200,
        local_dir = 'second_checkpoints'
        
    )

如何从检查点中提取代理,以便可以将我的健身房环境中的操作可视化,如下所示:

while not done:
    action, state, logits = agent.compute_action(obs, state)
    obs, reward, done, info = env.step(action)
    episode_reward += reward
    print('action: ' + str(action) + 'reward: ' + str(reward))


我知道我可以使用这样的东西:

analysis = tune.run('PPO",config={"max_iter": 10}, restore=last_ckpt)

但我不确定如何从 tune.run 中存在的代理中提取计算操作(和奖励)。

标签: reinforcement-learningrayrllib

解决方案


调谐运行用于训练模型。培训结束后,您应该有一些检查点文件。这些文件可以加载,然后在您的环境中播放。

agent = ppo.PPOTrainer(config=config, env=env_name)
agent.restore(checkpoint_file)
obs = env.reset()
action = agent.compute_action(obs)
obs, reward, done, info = env.step(action)

推荐阅读