python - 如何实现线性萨尔萨
问题描述
你如何在 Python 中实现“ Linear Sarsa ”?
我为那些不熟悉该算法的人提供了一个伪代码示例,以及我个人在 Python 中实现它的尝试。
线性 Sarsa 算法的伪代码示例:
任意初始化 θ
对于 {1, . . . , N} 做
s ← 第 i 集的初始状态
而状态 s 不是终端做
a ← π(s)
r ← 在状态 s 观察到动作 a 的奖励
s0 ← 在状态 s 观察到动作 a 的下一个状态
θ ← θ + α[r + γV(s0; θ) - V(s; θ)]∇θV(s; θ)
s ← s0
结束时
结束
尝试的 Python 实现
def linear_sarsa(env, max_episodes, eta, gamma, epsilon, seed=None):
random_state = np.random.RandomState(seed)
eta = np.linspace(eta, 0, max_episodes)
epsilon = np.linspace(epsilon, 0, max_episodes)
theta = np.zeros(env.n_features)
for i in range(max_episodes):
features = env.reset()
q = features.dot(theta)
gradientThetaFeatures = np.gradient(q[s]) #∇θV(s; θ)
done = False
while not done:
s2, reward, done, info = env.step(action)
action1 = policy
#Theta Value
#θ ← θ + α[r + γV(s2; θ) − V(s; θ)]∇θV(s; θ)
theta = theta + eta * (reward + (gamma * q[s2]) - q[s]) * gradientThetaFeatures
s = s2
return theta
我这样做正确吗?
解决方案
推荐阅读
- python - 如何避免为相同的命令但不同的输入创建多个函数?
- java - Firestore deadlock when using transaction.get(query)
- python - 如何从现有的数据框字符串列创建单词标记的熊猫数据框?
- delphi - 如何从 TIdHTTP get 中读取标头
- python - uint8 CIE Luv Mat 的取值范围是多少?
- jquery - jQuery 选择器格式化
- javascript - 如何使用 JavaScript 将一个输入字段中收集的数组数据传递给另一个输入字段而不丢失数组属性?
- java - com.android.builder.dexing.DexArchiveMergerException:合并 dex 档案时出错:
- aws-cli - AWS CodeDeploy :bucket 选项不得包含正斜杠 (/)
- java - 文件夹位置的字符串中的字符串?