python - 在 keras-rl 中定义动作值
问题描述
我在 keras-rl 中有一个自定义环境,在构造函数中有以下配置
def __init__(self, data):
#Declare the episode as the first episode
self.episode=1
#Initialize data
self.data=data
#Declare low and high as vectors with -inf values
self.low = numpy.array([-numpy.inf])
self.high = numpy.array([+numpy.inf])
self.observation_space = spaces.Box(self.low, self.high, dtype=numpy.float32)
#Define the space of actions as 3 (I want them to be 0, 1 and 2)
self.action_space = spaces.Discrete(3)
self.currentObservation = 0
self.limit = len(data)
#Initiates the values to be returned by the environment
self.reward = None
如您所见,我的代理将执行 3 个动作,根据动作,将在下面的函数 step() 中计算不同的奖励:
def step(self, action):
assert self.action_space.contains(action)
#Initiates the reward
self.reward=0
#get the reward
self.possibleGain = self.data.iloc[self.currentObservation]['delta_next_day']
#If action is 1, calculate the reward
if(action == 1):
self.reward = self.possibleGain-self.operationCost
#If action is 2, calculate the reward as negative
elif(action==2):
self.reward = (-self.possibleGain)-self.operationCost
#If action is 0, no reward
elif(action==0):
self.reward = 0
#Finish episode
self.done=True
self.episode+=1
self.currentObservation+=1
if(self.currentObservation>=self.limit):
self.currentObservation=0
#Return the state, reward and if its done or not
return self.getObservation(), self.reward, self.done, {}
问题是,如果我打印每一集的动作,它们是 0、2 和 4。我希望它们是 0、1 和 2。如何强制代理使用 keras 仅识别这 3 个动作-rl?
解决方案
我不确定为什么要给self.action_space = spaces.Discrete(3)
您操作,0,2,4
因为我无法使用您发布的代码片段重现您的错误,所以我建议您使用以下方法来定义您的操作
self.action_space = gym.spaces.Box(low=np.array([1]),high= np.array([3]), dtype=np.int)
这就是我从动作空间采样时得到的。
actions= gym.spaces.Box(low=np.array([1]),high= np.array([3]), dtype=np.int)
for i in range(10):
print(actions.sample())
[1]
[3]
[2]
[2]
[3]
[3]
[1]
[1]
[2]
[3]
希望这可以帮助!
推荐阅读
- tensorflow - 在神经网络框架中是否有前向映射/扭曲的实现?
- c++ - C++ setenv 解析其他变量
- c# - 使用 SMO 复制数据库但不传输所有数据
- jenkins - 我想在job2中使用job1工作区(使用jenkins管道cd到job1工作区)但我收到错误
- java - javafx 场景构建器滚动窗格
- ruby-on-rails - NoMethodError:未定义的方法`define_instance_method'
- maven - Wildfly 17 服务器提供的 Maven 依赖项
- node.js - 出现错误:Node Sass 找不到当前环境的绑定:Linux 64-bit with Node.js 11.x
- amazon-web-services - 如果我超过 AWS SES 上的每秒电子邮件配额会怎样?
- android - 无法将 Map<> 存储到 Firebase 存储中