reinforcement-learning - Applying REINFORCE to easy21
Problem description
I am trying to apply the REINFORCE algorithm (using a softmax policy, with an undiscounted Gt and a baseline) to David Silver's easy21, but I am running into problems with the actual implementation: the results it produces do not converge to Q*, unlike a pure MC approach.
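For reference, the update I believe this should implement (standard REINFORCE with a baseline, for a linear softmax policy over features ψ(s, a)) is:

\theta \leftarrow \theta + \alpha \,\nabla_\theta \log \pi_\theta(a_t \mid s_t)\,\bigl(G_t - b(s_t, a_t)\bigr),
\qquad
\nabla_\theta \log \pi_\theta(a \mid s) = \psi(s, a) - \sum_{a'} \pi_\theta(a' \mid s)\,\psi(s, a')

where the baseline is the linear Q estimate under the previous weights, b(s_t, a_t) = Q(s_t, a_t; \theta_{\text{prev}}).

Here is the relevant code: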
import numpy as np

hit = True
stick = False
actions = [hit, stick]
alpha = 0.1
theta = np.random.randn(420).reshape((420, 1))

def psi(state, action):
    # One-hot feature vector over (dealer card, player sum, action): 10 * 21 * 2 = 420
    if state.player < 1 or state.player > 21:
        return np.zeros((420, 1))
    dealers = [int(state.dealer == x + 1) for x in range(0, 10)]
    players = [int(state.player == x + 1) for x in range(0, 21)]
    acts = [int(action == hit), int(action == stick)]
    features = [1 if (i == 1 and j == 1 and k == 1) else 0
                for i in dealers for j in players for k in acts]
    return np.array(features).reshape((420, 1))

def Q(state, action, weight):
    # Linear action-value estimate: Q(s, a) = psi(s, a)^T weight
    return np.matmul(psi(state, action).T, weight)

def softmax(state, weight):
    allQ = [Q(state, a, weight) for a in actions]
    probs = np.exp(allQ) / np.sum(np.exp(allQ))
    return probs.reshape((2,))

def score_function(state, action, weight):
    # Gradient of log pi(a|s) for a linear softmax policy:
    # psi(s, a) - E_{a' ~ pi}[psi(s, a')]
    probs = softmax(state, weight)
    expected_score = (probs[0] * psi(state, hit)) + (probs[1] * psi(state, stick))
    return psi(state, action) - expected_score

def softmax_policy(state, weight):
    # Sample an action from the softmax distribution
    probs = softmax(state, weight)
    if np.random.random() < probs[1]:
        return stick
    else:
        return hit
if __name__ == "__main__":
    # ITERATIONS, game, prev_theta and generate_Q are defined in the full file
    # (see the GitHub link below).
    Q_star = np.load('Q_star.npy')
    for k in range(1, ITERATIONS):
        terminal = False
        state = game.initialise_state()
        action = softmax_policy(state, theta)
        # history is a flat list of (state, action, reward) triples
        history = [state, action]
        while not terminal:
            state, reward = game.step(state, action)
            action = softmax_policy(state, theta)
            terminal = state.terminal
            if terminal:
                state_action_pairs = zip(history[0::3], history[1::3])
                history.append(reward)
                history.append(state)
                # Undiscounted return: sum of every reward in the episode
                Gt = sum(history[2::3])
                for s, a in state_action_pairs:
                    advantage = Gt - Q(s, a, prev_theta)
                    theta += alpha * score_function(s, a, theta) * advantage
            else:
                history.append(reward)
                history.append(state)
                history.append(action)
        if k % 10000 == 0:
            print("MSE: " + str(round(np.sum((Q_star - generate_Q(theta)) ** 2), 2)))
Output:
python reinforce.py
MSE: 288.18
MSE: 248.45
MSE: 227.08
MSE: 215.46
MSE: 207.3
MSE: 202.61
MSE: 197.82
MSE: 195.96
MSE: 194.01
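One quick sanity check on the policy-gradient machinery: for any fixed state, the score function should average to exactly zero under the policy, since sum_a pi(a|s) * grad log pi(a|s) = 0 for a softmax. A small check against the functions above (reusing the hypothetical State stand-in from the generate_Q sketch):

def score_has_zero_mean(state, weight):
    # E_{a ~ pi}[score(s, a)] should be the zero vector (up to float error)
    probs = softmax(state, weight)
    mean_score = sum(probs[i] * score_function(state, a, weight)
                     for i, a in enumerate(actions))
    return np.allclose(mean_score, 0.0)

print(score_has_zero_mean(State(dealer=5, player=11), theta))  # expect True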
Update: fixed the code by using a different theta initialisation:

theta = np.zeros((420, 1))

However, the resulting value function still does not match Q* (the peak at player sum = 11 is missing).
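To see the missing peak directly, it can help to compare the greedy value V(s) = max_a Q(s, a) along the player-sum axis for the learned weights versus Q*. A rough way to eyeball this, assuming the (dealer, player, action) layout from the generate_Q sketch above:

# Greedy state values, shape (10, 21): one row per dealer card
V_learned = generate_Q(theta).max(axis=2)
V_star = Q_star.max(axis=2)

# Average over dealer cards and print both value curves over player sums;
# per the observation above, the learned curve stays flat where Q*
# peaks around player sum = 11.
for p in range(21):
    print(p + 1,
          round(float(V_learned[:, p].mean()), 3),
          round(float(V_star[:, p].mean()), 3))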
The full code is available at: https://github.com/Soundpulse/easy21-rl/blob/main/reinforce.py