首页 > 解决方案 > 如何使用 PyTorch DataLoader 进行强化学习?

问题描述

我正在尝试在 PyTorch 中建立一个通用的强化学习框架,以利用所有利用 PyTorch DataSet 和 DataLoader 的高级实用程序,如 Ignite 或 FastAI,但我遇到了一个具有动态性质的阻止程序强化学习数据:

到目前为止,我的 Google 和 StackOverflow 搜索都取得了成果。这里有人知道将 DataLoader 或 DataSet 与强化学习一起使用的现有解决方案或解决方法吗?我讨厌放弃对依赖于那些的所有现有库的访问。

标签: pytorchreinforcement-learningdataloader

解决方案


是一个基于 PyTorch 的框架,是来自 Facebook 的东西。

当涉及到您的问题(毫无疑问是崇高的追求)时:

您可以轻松地创建一个torch.utils.data.Dataset依赖于任何东西,包括模型,像这样(原谅弱抽象,这只是为了证明一点):

import typing

import torch
from torch.utils.data import Dataset


class Environment(Dataset):
    def __init__(self, initial_state, actor: torch.nn.Module, max_interactions: int):
        self.current_state = initial_state
        self.actor: torch.nn.Module = actor
        self.max_interactions: int = max_interactions

    # Just ignore the index
    def __getitem__(self, _):
        self.current_state = self.actor.update(self.current_state)
        return self.current_state.get_data()

    def __len__(self):
        return self.max_interactions

假设,torch.nn.Module-like 网络具有某种update变化的环境状态。总而言之,它只是一个 Python 结构,所以你可以用它来建模很多东西。

您可以指定max_interactions为几乎infinite,或者如果需要在训练期间使用一些回调(可能__len__会在整个代码中多次调用),您可以动态更改它。环境可以进一步提供batches而不是样品。

torch.utils.data.DataLoaderbatch_sampler参数,在那里你可以生成不同长度的批次。由于网络不依赖于第一个维度,您也可以从那里返回您想要的任何批量大小。

顺便提一句。如果每个样本的长度不同,则应使用填充,不同的批量大小与此无关。


推荐阅读