python - Calculating loss from action and reward in Tensorflow
Problem description
I'm trying to calculate the loss in an RL project with 3 discrete actions. I have the output prediction of my model (from tf.layers.dense()), e.g. for 3 possible actions and a batch size of 2:
[[10, 20.2, 4.3],
[5, 3, 8.9]]
I have the action that was taken by the agent (e.g.):
[[1],
[2]]
And I have the reward for taking that action from the environment (e.g):
[[30.0],
[15.0]]
I want to calculate the loss for the taken action, using the action as an index together with the reward. I don't have any information for the actions that weren't taken. If the loss were just the difference between the reward and the predicted value of the taken action, I'd expect it (for the examples above) to be:
[[0, 9.8, 0],
[0, 0, 6.1]]
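The expected matrix above can be reproduced with a framework-free sketch: for each row, subtract the prediction at the taken action from the reward and leave every other entry at zero. The function name `taken_action_errors` is hypothetical, and the nested lists simply stand in for the tensors.

```python
# Hypothetical helper (illustration only, not a TensorFlow API).
def taken_action_errors(predictions, actions, rewards):
    """Return a matrix shaped like predictions, holding reward - prediction
    at each row's taken action and 0.0 everywhere else."""
    result = []
    for row, (action,), (reward,) in zip(predictions, actions, rewards):
        errors = [0.0] * len(row)
        errors[action] = reward - row[action]
        result.append(errors)
    return result


predictions = [[10, 20.2, 4.3], [5, 3, 8.9]]
actions = [[1], [2]]
rewards = [[30.0], [15.0]]
print(taken_action_errors(predictions, actions, rewards))
# approximately [[0.0, 9.8, 0.0], [0.0, 0.0, 6.1]] (up to float rounding)
```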
I've tried:
updated = tf.scatter_update(logits, action, reward)
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=updated, logits=logits)
But this gives AttributeError: 'Tensor' object has no attribute '_lazy_read'. I believe this is because the inputs are Tensors, not the Variables that scatter_update() requires.
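Independently of the Tensor-vs-Variable problem, `tf.scatter_update` writes whole slices along the first dimension of its variable: each index selects a row, and the matching update must be a full row. The pure-Python sketch below mimics that behaviour with a hypothetical helper `scatter_update_rows`, to show why indices like `[[1], [2]]` would target entire rows rather than one entry per row (and why index 2 is not even a valid row for a batch of 2).

```python
# Hypothetical helper mimicking first-dimension scatter semantics
# (illustration only, not a TensorFlow API).
def scatter_update_rows(ref, indices, updates):
    """Return a copy of ref where row indices[i] is replaced by updates[i]."""
    out = [list(row) for row in ref]  # copy; the real op mutates the Variable
    for index, update in zip(indices, updates):
        if len(update) != len(out[index]):  # IndexError if the row doesn't exist
            raise ValueError("each update must be a full row")
        out[index] = list(update)
    return out


logits = [[10, 20.2, 4.3], [5, 3, 8.9]]
# Index 1 selects the whole second row, not the element in column 1.
print(scatter_update_rows(logits, [1], [[30.0, 30.0, 30.0]]))
# [[10, 20.2, 4.3], [30.0, 30.0, 30.0]]
```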
How can I calculate loss for this?
Solution
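You can't use scatter_update here: it only works on a tf.Variable (hence the error), and it indexes along the first dimension only, so it replaces whole rows rather than one element per row. It's worth looking at how gather_nd and scatter_nd work; the following code solves your problem.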
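To see what gather_nd and scatter_nd do with index pairs, here is a pure-Python sketch of their semantics for a 2-D input and `[row, col]` index pairs, written as hypothetical helpers `gather_nd_2d` and `scatter_nd_2d`. It ignores the extra leading index dimensions TensorFlow supports, and it mirrors `tf.scatter_nd`'s documented behaviour of starting from zeros and summing updates that hit the same index.

```python
# Hypothetical helpers (illustration only): 2-D versions of the
# gather_nd / scatter_nd semantics.
def gather_nd_2d(params, indices):
    """Pick params[row][col] for every [row, col] pair in indices."""
    return [params[row][col] for row, col in indices]


def scatter_nd_2d(indices, updates, shape):
    """Build a zero matrix of the given shape and add updates[i] at indices[i]."""
    rows, cols = shape
    out = [[0.0] * cols for _ in range(rows)]
    for (row, col), value in zip(indices, updates):
        out[row][col] += value  # duplicate indices accumulate
    return out


output = [[10, 20.2, 4.3], [5, 3, 8.9]]
pairs = [[0, 1], [1, 2]]             # [batch row, taken action]
taken = gather_nd_2d(output, pairs)  # [20.2, 8.9]
rewards = [30.0, 15.0]
errors = [r - q for r, q in zip(rewards, taken)]
print(scatter_nd_2d(pairs, errors, (2, 3)))
# approximately [[0.0, 9.8, 0.0], [0.0, 0.0, 6.1]]
```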
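import tensorflow as tf  # TensorFlow 1.x graph-mode API (reset_default_graph, Session)

num_actions = 3
batch_size = 2
tf.reset_default_graph()

# Model output: shape (batch_size, num_actions)
output = tf.convert_to_tensor([[10, 20.2, 4.3], [5, 3, 8.9]])
# Actions taken and rewards received: both shape (batch_size, 1)
a = tf.convert_to_tensor([[1], [2]])
r = tf.convert_to_tensor([[30.0], [15.0]])

# Pair each row index with its action, giving [[0, 1], [1, 2]], then reshape
# to (batch_size, 1, 2) so gather_nd returns shape (batch_size, 1), like r.
a_idx = tf.reshape(tf.range(batch_size), [-1, 1])
a_reshaped = tf.reshape(tf.concat([a_idx, a], axis=1), [-1, 1, 2])

# Predicted value of the taken action in each row: [[20.2], [8.9]]
taken_pred = tf.gather_nd(output, a_reshaped)
# Write reward - prediction at the taken action, zeros everywhere else
loss = tf.scatter_nd(a_reshaped, r - taken_pred, (batch_size, num_actions))

with tf.Session() as sess:
    print(sess.run(loss))  # approximately [[0, 9.8, 0], [0, 0, 6.1]]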
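An equivalent formulation, swapped in here rather than taken from the answer, avoids building index pairs: multiply `(reward - prediction)` for every action by a one-hot mask of the taken action. In TensorFlow that idea would use `tf.one_hot` on the action indices (after dropping their trailing dimension) plus broadcasting of `r` against `output`; the sketch below only checks the arithmetic in plain Python, with hypothetical helpers `one_hot` and `masked_errors`.

```python
# Hypothetical helpers (illustration only): one-hot masking in plain Python.
def one_hot(index, depth):
    """Return a list of length depth with 1.0 at index and 0.0 elsewhere."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


def masked_errors(output, actions, rewards, num_actions):
    """(reward - prediction) for every action, zeroed except the taken one."""
    result = []
    for row, (action,), (reward,) in zip(output, actions, rewards):
        mask = one_hot(action, num_actions)
        # Broadcast the row's reward across all actions, then apply the mask.
        result.append([m * (reward - q) for m, q in zip(mask, row)])
    return result


output = [[10, 20.2, 4.3], [5, 3, 8.9]]
print(masked_errors(output, [[1], [2]], [[30.0], [15.0]], num_actions=3))
# approximately [[0.0, 9.8, 0.0], [0.0, 0.0, 6.1]]
```

The mask version computes a difference for every action and throws most of them away, which is cheap for 3 actions; note that a non-finite prediction in an untaken slot would still leak through as NaN (0 * inf), which the gather/scatter version avoids.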