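python-3.x - How do I fix "The truth value of an array with more than one element is ambiguous" when looking up an object in a dictionary?
Problem description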
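I am trying to implement a simple reinforcement learning algorithm. The agent should use Q-learning to move from point A to point B on a square grid. I had this working earlier with a simpler model, but now I need to extend it. I want to store the Q-values produced by the algorithm in a dictionary called self.Q, where each key is a state of the agent and each value is a list of the Q-values for that state. A state is an object of the class State, which has the grid matrix as an attribute. However, when I check whether a state (new_state) is already in the dictionary self.Q (see the code below), I get the following error: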
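A dictionary-based Q-table needs keys that are hashable and that compare to a single True/False. A bare NumPy array is neither: hash() on it raises TypeError: unhashable type: 'numpy.ndarray'. One minimal sketch, assuming the grid is a small array of 0/1 values, is to key the table on an immutable snapshot of the grid built from its shape and bytes. The helper name state_key is hypothetical and not part of the original code:

```python
import numpy as np

ACTIONS = ['Right', 'Left', 'Up', 'Down']


def state_key(grid: np.ndarray) -> tuple:
    """Hypothetical helper: turn a grid into an immutable, hashable key.

    The shape is included so grids with identical bytes but different shapes
    do not collide, and the dtype is normalised to int so that a float grid
    and an int grid holding the same 0/1 values map to the same key.
    """
    g = np.asarray(grid, dtype=int)
    return (g.shape, g.tobytes())


Q = {}
grid = np.zeros((4, 4), dtype=int)
grid[0, 0] = 1  # robot at the start cell
Q[state_key(grid)] = np.random.rand(len(ACTIONS))

same_grid = grid.copy()  # a distinct array object with identical contents
print(state_key(same_grid) in Q)
```

Because the key is a plain tuple of a tuple and a bytes object, the dictionary compares keys with ordinary Python equality and never sees an element-wise array comparison.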
The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
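The message comes from NumPy: == between two arrays compares element-wise and returns a boolean array, and anything that then needs a single True/False, such as an if statement, not, or the dictionary's own key comparison, calls bool() on that array. A minimal sketch that reproduces the error with a stripped-down State class modelled on the one in the question:

```python
import numpy as np


class State:
    def __init__(self, grid):
        self.grid = grid

    def __eq__(self, other):
        # Same pattern as in the question: self.grid == other.grid is an
        # element-wise boolean array, and `and` returns it unchanged.
        return isinstance(other, State) and self.grid == other.grid

    def __hash__(self):
        return hash(str(self.grid))


a = State(np.zeros((2, 2)))
b = State(np.zeros((2, 2)))  # equal contents, different object, same hash

print(type(a == b))  # an ndarray, not a bool

Q = {a: [0, 0, 0, 0]}
raised = False
try:
    # The hashes match, so the dict calls __eq__ and then bool() on the array.
    b in Q
except ValueError as exc:
    raised = True
    print(exc)
```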
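Why does this happen? My code is based on this article: https://medium.com/@curiousily/solving-an-mdp-with-q-learning-from-scratch-deep-reinforcement-learning-for-hackers-part-1-45d1d360c120, and they don't seem to run into this problem. I think it has to do with the fact that the states are separate objects, but I don't know how to fix it.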
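The "separate objects" intuition is close: the dictionary only calls __eq__ when it has to tell apart two different State objects whose hashes match, and that is exactly when the array comparison fails. In the code below this happens, for example, when the robot bumps into a wall, because new_state then holds the same grid as the current state. A minimal sketch of a State class that avoids this, assuming grids are small arrays of 0/1 values: __eq__ collapses the comparison to a single bool with np.array_equal, and __hash__ is derived from the same data so that equal states always hash equally. This is one possible fix, not the only one:

```python
import numpy as np


class State:
    """Agent state; wraps the grid so it can be used as a dict key."""

    def __init__(self, grid):
        # Assumption: grids hold 0/1 values. Store an int copy so later
        # in-place edits to the caller's array cannot mutate a key that is
        # already in the dictionary, and so float/int grids compare alike.
        self.grid = np.array(grid, dtype=int)

    def __eq__(self, other):
        # np.array_equal checks shape and values and returns one bool.
        return isinstance(other, State) and np.array_equal(self.grid, other.grid)

    def __hash__(self):
        # Hash the same data that __eq__ compares.
        return hash((self.grid.shape, self.grid.tobytes()))


Q = {}
s1 = State(np.zeros((4, 4)))
s2 = State(np.zeros((4, 4), dtype=int))  # same values, other dtype and object
Q[s1] = [0, 0, 0, 0]
print(s2 in Q)  # no ValueError
```

Hashing the bytes of the normalised grid instead of str(self.grid) keeps __hash__ consistent with __eq__ even when one grid is created with np.zeros (float) and another with dtype=int, as happens with terminal_grid and E in the code below.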
import numpy as np
import random as rnd
from copy import deepcopy

grid_size = 4
m_A = 0  # Start coordinate in matrix
n_A = 0  # Start coordinate in matrix
m_B = grid_size - 1  # End coordinate
n_B = grid_size - 1  # End coordinate
ACTIONS = ['Right', 'Left', 'Up', 'Down']
eps = 0.1
gamma = 0.7
alpha = 1


class State:
    """Defines the current state of the agent."""

    def __init__(self, grid):
        self.grid = grid

    def __eq__(self, other):
        return isinstance(other, State) and self.grid == other.grid

    def __hash__(self):
        return hash(str(self.grid))


terminal_grid = np.zeros((grid_size, grid_size))
terminal_grid[m_B][n_B]
terminal_state = State(terminal_grid)


class Robot:
    """Implements agent."""

    def __init__(self, row=m_A, col=n_A, cargo=False):
        self.m = row  # Robot position in grid (row)
        self.n = col  # Robot position in grid (col)
        self.carry = cargo  # True if robot carries cargo, False if not
        self.Q = dict()
        self.Q[terminal_state] = [0, 0, 0, 0]

    def move_robot(self, state):
        """Moves the robot according to the given action."""
        m = self.m  # Current row
        n = self.n  # Current col
        p = []  # Probability distribution
        for i in range(len(ACTIONS)):
            p.append(eps/4)
        if self.carry is False:  # If the robot is moving from A to B
            Qmax = max(self.Q[state])
            for i in range(len(p)):
                if self.Q[state][i] == Qmax:
                    p[i] = 1 - eps + eps/4
                    break  # Use if number of episodes is large
        cur_env = deepcopy(state.grid)
        # cur_env = state.grid
        cur_env[m][n] = 0
        action = choose_action(p)
        if action == 'Right':
            if n + 1 >= grid_size or cur_env[m][n+1] == 1:
                Rew = -5  # Reward -5 if we move into wall or another agent
            else:
                n += 1
                Rew = -1  # Reward -1 otherwise
            a = 0  # Action number
        elif action == 'Left':
            if n - 1 < 0 or cur_env[m][n-1] == 1:
                Rew = -5
            else:
                n -= 1
                Rew = -1
            a = 1
        elif action == 'Up':
            if m - 1 < 0 or cur_env[m-1][n] == 1:
                Rew = -5
            else:
                m -= 1
                Rew = -1
            a = 2
        elif action == 'Down':
            if m + 1 >= grid_size or cur_env[m+1][n] == 1:
                Rew = -5
            else:
                m += 1
                Rew = -1
            a = 3
        m = m % grid_size
        n = n % grid_size
        self.m = m
        self.n = n
        cur_env[m][n] = 1
        # print(cur_env)
        new_state = State(cur_env)
        if new_state not in self.Q:  # Check if state is in dictionary
            self.Q[new_state] = np.random.rand(len(ACTIONS))
        return new_state, a


def choose_action(prob):
    """Defines policy to follow."""
    action = np.random.choice(ACTIONS, p=prob)
    return action


def episode(robot):
    """Simulation of one episode."""
    # Initialize E, S
    E = np.zeros((grid_size, grid_size), dtype=int)
    E[m_A][n_A] = 1  # Initializes position of robot
    S = State(E)  # Initializes state of robot
    robot.Q[S] = np.random.rand(len(ACTIONS))
    count = 0
    while robot.carry is False:
        S_new, action_number = robot.move_robot(S)
        m_new = robot.m
        n_new = robot.n
        if m_new != m_B or n_new != n_B:
            R = -1
        else:
            R = 5
            robot.carry = True  # Picks up cargo
        robot.Q[S][action_number] += alpha*(R + gamma*max(robot.Q[S_new]) - robot.Q[S][action_number])
        S = S_new
        # print(E)
        # print()
        count += 1
    return count


nepisodes = []
step_list = []


def simulation():
    """Iterates through all episodes."""
    r1 = Robot()
    for i in range(400):
        nsteps = episode(r1)
        nepisodes.append(i+1)
        step_list.append(nsteps)
        r1.m = m_A
        r1.n = n_A
        print("End of episode!")
        print(nsteps)


simulation()
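For reference, the line robot.Q[S][action_number] += alpha*(...) in episode is the standard tabular Q-learning update, and the loop at the top of move_robot builds an ε-greedy distribution over the four actions (with ties going to the first maximal index because of the break). Written out with the code's names, S for the current state, S' for S_new, a for action_number and 𝒜 for ACTIONS:

```latex
Q(S, a) \leftarrow Q(S, a) + \alpha \left[ R + \gamma \max_{a'} Q(S', a') - Q(S, a) \right]

\pi(a \mid S) =
\begin{cases}
1 - \varepsilon + \dfrac{\varepsilon}{|\mathcal{A}|} & \text{if } a = \arg\max_{a'} Q(S, a') \\[4pt]
\dfrac{\varepsilon}{|\mathcal{A}|} & \text{otherwise}
\end{cases}
```

With the settings above (α = 1, γ = 0.7, ε = 0.1, |𝒜| = 4), the first line reduces to Q(S, a) ← R + 0.7 · max Q(S', a'), so each update overwrites the old estimate entirely, and the greedy action is chosen with probability 0.925 while each other action gets 0.025.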