python - Understanding custom policies in stable-baselines3
Problem description
I was trying to understand the policy networks in stable-baselines3 from this doc page.
As explained in this example, to specify a custom CNN feature extractor, we extend the BaseFeaturesExtractor class and specify it in policy_kwargs.features_extractor_class, with CnnPolicy as the first param:
model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs)
Q1. Can we follow the same approach for a custom MLP feature extractor?
As explained in this example, to specify a custom MLP feature extractor, we extend the ActorCriticPolicy class, override _build_mlp_extractor(), and pass the class as the first param:
class CustomActorCriticPolicy(ActorCriticPolicy):
    ...

model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
Q2. Can we follow the same approach for a custom CNN feature extractor?
I feel we can have either a CNN extractor or an MLP extractor. So it makes no sense to pass MlpPolicy as the first param to the model and then specify a CNN feature extractor in policy_kwargs.features_extractor_class, as in this example. This results in the following policy (containing both features_extractor and mlp_extractor), which I feel is incorrect:
ActorCriticPolicy(
  (features_extractor): Net(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (fc3): Linear(in_features=384, out_features=512, bias=True)
  )
  (mlp_extractor): MlpExtractor(
    (shared_net): Sequential(
      (0): Linear(in_features=512, out_features=64, bias=True)
      (1): ReLU()
    )
    (policy_net): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=16, bias=True)
      (3): ReLU()
    )
    (value_net): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=16, bias=True)
      (3): ReLU()
    )
  )
  (action_net): Linear(in_features=16, out_features=7, bias=True)
  (value_net): Linear(in_features=16, out_features=1, bias=True)
)
Q3. Am I correct in this understanding? If yes, is one of the MLP or CNN feature extractors ignored?
Solution
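Here is what I can say after skimming through all of the library code. MlpPolicy and CnnPolicy differ only in the default BaseFeaturesExtractor class they implement, and that difference only matters when you are not trying to supply a custom class. Let me try to explain. There are two types of policies:
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
class ActorCriticCnnPolicy(ActorCriticPolicy) is simply built on ActorCriticPolicy, and among its parameters we find:
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN
where NatureCNN is a simple BaseFeaturesExtractor implementation made of CNN layers.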
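Now let's answer your questions. The key point is that the model always has the following structure: FeatureExtractor -> MlpExtractor -> Policy/Value nets
So in the policy printed in the question neither part is ignored: the output of features_extractor is the input of mlp_extractor.
You can specify both features_extractor_class and net_arch in policy_kwargs. The default features_extractor_class depends on whether you choose CnnPolicy or MlpPolicy, but if you specify your own class there is no difference. So you can simply use MlpPolicy.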
For example, if you specify 'net_arch': [64, dict(pi=[32, 16], vf=[32, 16])], the mlp_extractor will have one shared dense layer (64 units) connected to the output of the features extractor.
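In summary, my suggestion is to follow the example you found in Q3:
- Specify your own features extractor with whatever structure you want for that network: CNN, dense, or anything else. The output of the features extractor should be an FC layer.
- Specify the net_arch param in the following structure: [x1, x2, ..., dict(pi=[px1, px2, ..], vf=[vx1, vx2, ...])], where the first part gives the sizes of the FC layers of the mlp_extractor and the second part (inside the dict) gives the sizes of the policy and value nets.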
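The net_arch layout described here can be sketched in plain Python, without the library. The helpers below (split_net_arch and layer_shapes are hypothetical names, not SB3 functions) split the list into shared sizes and pi/vf sizes, then list the dense layer shapes that result for a features extractor output of 512, the value taken from the question. This follows the list format used in this answer. As far as I know, stable-baselines3 releases from 1.8.0 onward removed the shared layers and expect net_arch=dict(pi=[...], vf=[...]) instead, so check the version you have installed.

```python
from typing import Dict, List, Tuple, Union

NetArch = List[Union[int, Dict[str, List[int]]]]


def split_net_arch(net_arch: NetArch) -> Tuple[List[int], List[int], List[int]]:
    """Split [x1, x2, ..., dict(pi=[...], vf=[...])] into (shared, pi, vf) sizes."""
    shared: List[int] = []
    pi: List[int] = []
    vf: List[int] = []
    for item in net_arch:
        if isinstance(item, int):
            shared.append(item)
        else:
            pi = list(item.get("pi", []))
            vf = list(item.get("vf", []))
            break  # the dict is expected to be the last element
    return shared, pi, vf


def layer_shapes(in_dim: int, sizes: List[int]) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each dense layer in a chain."""
    shapes = []
    for out_dim in sizes:
        shapes.append((in_dim, out_dim))
        in_dim = out_dim
    return shapes


features_dim = 512  # output size of the features extractor in the question
shared, pi, vf = split_net_arch([64, dict(pi=[32, 16], vf=[32, 16])])

shared_shapes = layer_shapes(features_dim, shared)
latent_dim = shared[-1] if shared else features_dim
pi_shapes = layer_shapes(latent_dim, pi)
vf_shapes = layer_shapes(latent_dim, vf)

print("shared_net:", shared_shapes)  # shared_net: [(512, 64)]
print("policy_net:", pi_shapes)      # policy_net: [(64, 32), (32, 16)]
print("value_net:", vf_shapes)       # value_net: [(64, 32), (32, 16)]
```

These shapes match the Linear layers of shared_net, policy_net and value_net in the policy printed in the question.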