deep-learning - 为什么我在 cartpole 上的 DQN 不收敛?
问题描述
这是我的 DQN 实现的代码。我检查了许多其他人的存储库中的代码,但找不到任何差异。我想 learn() 中有一些错误,但我找不到任何差异。我在 pytorch 官方网站上查看了 DQN 代码,我查看了他们的动作、奖励、状态、下一个状态完全没问题。
这是我的代理类代码。
class DQNAgent():
def __init__(self, net, capacity, n_actions, eps_start, eps_end, eps_decay, batch_size, gamma, lr):
self.net = net
self.target_net = copy.deepcopy(self.net)
self.buffer = ReplayBuffer(capacity)
self.n_actions = n_actions
self.device = next(net.parameters()).device
self.eps_start = eps_start
self.eps_end = eps_end
self.eps_decay = eps_decay
self.batch_size = batch_size
self.gamma = gamma
self.sample_step = 0 # for decaying epsilon-greedy policy
self.loss_fn = nn.SmoothL1Loss()
self.optim = torch.optim.Adam(self.net.parameters(), lr=lr)
def store_transition(self, state, action, next_state, reward):
self.buffer.push(state, action, next_state, reward)
def select_action(self, state):
'''
Returns:
- action: shape [bs, ]
'''
# Decaying epsilon
eps = self.eps_end + (self.eps_start - self.eps_end) * math.exp(-1 * self.sample_step / self.eps_decay)
self.sample_step += 1
if random.random() > eps: # greedy
self.net.eval()
with torch.no_grad():
return self.net(state).argmax(-1).item()
else: # random
return torch.randint(self.n_actions, (1,)).item()
def learn(self):
self.net.train()
self.target_net.eval()
transitions = self.buffer.sample(self.batch_size)
batch = Transition(*zip(*transitions))
non_final_mask = torch.tensor(tuple(map(lambda x: x is not None, batch.next_state)), device=self.device, dtype=torch.bool)
non_final_next_states = torch.tensor([state for state in batch.next_state if state is not None], dtype=torch.float32, device=self.device)
states = torch.stack(batch.state).to(self.device) # [bs, input_dim]
rewards = torch.tensor(batch.reward).to(self.device) # [bs, ]
actions = torch.tensor(batch.action).unsqueeze(-1).to(self.device) # [bs, ] --> [bs, 1]
Q = self.net(states).gather(1, actions) # [bs, n_actions] --> [bs, 1]
next_state_values = torch.zeros(self.batch_size, device=device)
next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
expected_state_action_values = (next_state_values * self.gamma) + rewards
self.optim.zero_grad()
loss = self.loss_fn(Q, expected_state_action_values.unsqueeze(1))
loss.backward()
self.optim.step()
return loss.detach().item()
def update_target_net(self):
self.target_net.load_state_dict(self.net.state_dict())
这是我的训练循环代码。
reward_history = []
for episode_i in tqdm(range(max_episodes)):
s = env.reset()
if np.mean(reward_history[-5:]) >= max_steps*0.9:
break
ep_reward = 0
while True:
s = torch.tensor(s, dtype=torch.float32, device=device)
a = agent.select_action(s)
new_s, r, done, info = env.step(a)
agent.store_transition(s, a, new_s, r)
if done:
reward_history.append(ep_reward)
break
s = new_s
ep_reward += 1
if len(agent.buffer) >= agent.batch_size:
loss = agent.learn()
print(f'Episode {episode_i} | Reward: {ep_reward}')
if episode_i % target_update_intv == 0:
agent.update_target_net()
解决方案
推荐阅读
- python-3.x - 是否可以使用 Python3 制作卸载软件?
- python - 如何使用另一个字典作为 Python 中的键来过滤字典?
- google-cloud-platform - Stackdriver 未能根据自定义指标创建警报
- python - 是否有使用相应的脚本 Python 自动旋转 .STEP 文件的解决方案?
- react-native - 与 React Native App 集成代码编辑器
- c++ - 内联函数是否在它们内联的函数中保留它们的上下文
- android - 如何从数组列表中获取数据并在 Recyclerview 适配器中显示该数据?
- reactjs - ReactJS 的圆形单选选项
- join - 如何针对 2 个非常大的数据帧优化 pyspark approxSimilarityJoin
- html - 样式问题 - CSS Ignoring white-space: nowrap;