python - tf.agent 策略可以返回所有动作的概率向量吗?
问题描述
我正在尝试使用 TF-Agent TF-Agent DQN Tutorial训练强化学习代理。在我的应用程序中,我有 1 个操作,其中包含 9 个可能的离散值(标记为 0 到 8)。下面是输出env.action_spec()
BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0, dtype=int64), maximum=array(8, dtype=int64))
我想得到包含训练策略计算的所有动作的概率向量,并在其他应用程序环境中做进一步的处理。但是,该策略仅返回log_probability
单个值,而不是所有操作的向量。反正有没有得到概率向量?
from tf_agents.networks import q_network
from tf_agents.agents.dqn import dqn_agent
q_net = q_network.QNetwork(
env.observation_spec(),
env.action_spec(),
fc_layer_params=(32,)
)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
my_agent = dqn_agent.DqnAgent(
env.time_step_spec(),
env.action_spec(),
q_network=q_net,
epsilon_greedy=epsilon,
optimizer=optimizer,
emit_log_probability=True,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=global_step)
my_agent.initialize()
... # training
tf_policy_saver = policy_saver.PolicySaver(my_agent.policy)
tf_policy_saver.save('./policy_dir/')
# making decision using the trained policy
action_step = my_agent.policy.action(time_step)
在dqn_agent.DqnAgent()
DQNAgent中,我设置emit_log_probability=True
了应该定义的Whether policies emit log probabilities or not.
但是,当我运行时action_step = my_agent.policy.action(time_step)
,它会返回
PolicyStep(action=<tf.Tensor: shape=(1,), dtype=int64, numpy=array([1], dtype=int64)>, state=(), info=PolicyInfo(log_probability=<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>))
我也尝试运行action_distribution = saved_policy.distribution(time_step)
,它返回
PolicyStep(action=<tfp.distributions.DeterministicWithLogProbCT 'Deterministic' batch_shape=[1] event_shape=[] dtype=int64>, state=(), info=PolicyInfo(log_probability=<tf.Tensor: shape=(), dtype=float32, numpy=0.0>))
如果 TF.Agent 中没有这样的 API,有没有办法获得这样的概率向量?谢谢。
后续问题:
如果我理解正确,深度 Q 网络应该state
从状态中获取每个动作的输入并输出 Q 值。我可以将这个 Q 值向量传递给 softmax 函数并计算相应的概率向量。实际上我已经用我自己定制的 DQN 脚本(没有 TF-Agent)做了这样的计算。那么问题就变成了:如何从TF-Agent返回Q值向量?
解决方案
在 TF-Agents 框架中执行此操作的唯一方法是调用Policy.distribution()
方法而不是操作方法。这将返回根据网络的 Q 值计算的原始分布。emit_log_probability=True
唯一影响返回的命名元组的info
属性。请注意,此分布可能会受到您通过的操作约束(如果您这样做)的影响;从而非法行为将被标记为概率为 0(即使原始 Q 值可能很高)。PolicyStep
Policy.action()
此外,如果您想查看实际的 Q 值而不是它们生成的分布,那么恐怕如果不直接作用于您的代理附带的 Q 网络,就无法做到这一点(这也是附加到Policy
代理产生的对象上)。如果您想了解如何正确调用该 Q 网络,我建议您在此处查看QPolicy._distribution()
该方法的执行方式。
请注意,这些都不能使用预先实现的驱动程序来完成。您必须要么显式构建自己的集合循环,要么实现自己的 Driver 对象(这基本上是等效的)。
推荐阅读
- python - 如何设置默认的 conda 环境,以便每当我打开终端时它应该被激活而不是 base?
- r - 如何将 r 中的数据帧分成相等数量的记录组,并在两个数据帧中随机平均分割数据
- python - ORtools 为每个学生分配一门常规课程或三门特殊课程
- javascript - 从毫秒转换日期而不本地化它们
- sql - 加入具有不同前缀和名称的列
- datetime - 将时间范围添加到参数中的单个日期 - SQL Report Builder/SSRS
- c# - lambda表达式到简单的c#
- java - Java - 检查部分对象相等性
- javascript - 在 Javascript 中将比较 (< >) 运算符与非数字字符串一起使用
- javascript - 如何从角度的表单对象中获取验证错误?