tensorflow - Tensorflow - 减少 session.run() 开销
问题描述
我目前正在用 tensorflow 编写一个小的强化学习项目。在对一次训练运行进行性能分析时,我注意到超过 65% 的运行时间仅仅花在了对动作进行采样上。
注意:在强化学习中,轨迹的每一步都必须先采样出一个动作才能进入下一步,因此无法批量采样。
我在 GPU 上运行此图,因此我怀疑这种开销来自对 session.run() 的调用:每次调用都要把数据从 RAM 复制到 GPU。
我的问题是:有没有办法减少多次调用 session.run() 方法带来的开销?是否可以在 CPU 上运行 sample_op(前向传递),并在 GPU 上运行训练?
提前谢谢了!
这是我的 RL 策略的代码:
class Policy:
    """Gaussian policy trained with a PPO-style clipped surrogate objective.

    Two networks with identical architecture are built under the ``Policy``
    variable scope:

    * ``Policy/pi``     - trainable; used for sampling and optimisation.
    * ``Policy/old_pi`` - not trainable; overwritten with ``pi``'s weights
      by :meth:`update_old_pi`.

    Args:
        sess: Session used to run every op this class builds.
        state_size: Width of one state vector (second dim of ``state_ph``).
        action_size: Dimensionality of the action distribution.
        lr: Learning rate for ``tf.train.AdamOptimizer``.
        alpha_entropy: Weight of the entropy bonus in the loss.
        epsilon: PPO clip range; the probability ratio is clipped to
            ``[1 - epsilon, 1 + epsilon]``.
    """

    def __init__(self, sess, state_size, action_size, lr, alpha_entropy, epsilon):
        self.sess = sess
        self.action_size = action_size
        # NOTE(review): this device scope pins the *whole* graph - sampling
        # and training alike - to the CPU. If the intent is "sample on CPU,
        # train on GPU", this scope must be narrowed; as written nothing
        # built here is placed on a GPU.
        with tf.device("/cpu:0"):
            with tf.variable_scope("Policy"):
                self.state_ph = tf.placeholder(
                    tf.float32, [None, state_size], name="state_ph")
                self.action_ph = tf.placeholder(
                    tf.float32, [None, action_size], name="action_ph")
                # Was also named "action_ph" (copy-paste slip); TF silently
                # uniquified it, which made graph dumps misleading.
                self.advantage_ph = tf.placeholder(
                    tf.float32, [None, 1], name="advantage_ph")

                with tf.variable_scope("pi"):
                    self.pi, self.mean_action = self._create_model(trainable=True)
                    self.pi_vars = tf.get_collection(
                        tf.GraphKeys.GLOBAL_VARIABLES, scope="Policy/pi")
                    self.sample_op = self.pi.sample()

                with tf.variable_scope("old_pi"):
                    self.old_pi, _ = self._create_model(trainable=False)
                    self.old_pi_vars = tf.get_collection(
                        tf.GraphKeys.GLOBAL_VARIABLES, scope="Policy/old_pi")

                with tf.variable_scope("loss"):
                    # log_prob()/entropy() of a MultivariateNormalDiag reduce
                    # over the action (event) axis and have shape [batch].
                    # advantage_ph is [batch, 1]; multiplying the two directly
                    # would broadcast to a [batch, batch] matrix, so flatten
                    # the advantages to [batch] first.
                    advantage = tf.reshape(self.advantage_ph, [-1])
                    # Ratio in log space: prob() of a multi-dimensional
                    # Gaussian can underflow to 0, which makes p/q NaN.
                    prob_ratio = tf.exp(
                        self.pi.log_prob(self.action_ph)
                        - self.old_pi.log_prob(self.action_ph))
                    surrogate = prob_ratio * advantage
                    clipped_surrogate = tf.minimum(
                        surrogate,
                        tf.clip_by_value(prob_ratio, 1. - epsilon, 1. + epsilon) * advantage)
                    self.pi_entropy = self.pi.entropy()
                    tf.summary.scalar("entropy", tf.reduce_mean(self.pi_entropy))
                    # Negated: the objective is maximised, the optimiser minimises.
                    self.loss = -tf.reduce_mean(
                        clipped_surrogate + alpha_entropy * self.pi_entropy)
                    tf.summary.scalar("objective", self.loss)

                with tf.variable_scope("training"):
                    self.gradients = tf.gradients(self.loss, self.pi_vars)
                    # Optional gradient clipping (currently disabled):
                    # self.gradients = [tf.clip_by_value(g, -1000, 1000) for g in self.gradients]
                    # self.gradients, _ = tf.clip_by_global_norm(self.gradients, GRADIENT_NORM)

                    # Materialise the pairs: on Python 3 zip() is a one-shot
                    # iterator, and apply_gradients() would consume it before
                    # the histogram loop below saw a single element.
                    grads = list(zip(self.gradients, self.pi_vars))
                    self.optimize = tf.train.AdamOptimizer(lr).apply_gradients(grads)
                    for grad, var in grads:
                        if grad is not None:
                            # var.op.name drops the ":0" suffix, which is not
                            # a legal character in a summary name.
                            tf.summary.histogram(var.op.name, grad)

                with tf.variable_scope("update_old_policy"):
                    # Pairs variables positionally; both lists come from the
                    # same _create_model code, so they are assumed to be in
                    # the same creation order.
                    self.update_oldpi_op = [
                        oldp.assign(p) for p, oldp in zip(self.pi_vars, self.old_pi_vars)]

                self.summary_op = tf.summary.merge_all(scope="Policy")

    def _create_model(self, trainable):
        """Build the Gaussian policy network inside the current variable scope.

        Architecture: Dense(32, relu) -> Dense(64, relu) -> Dense(32, relu)
        -> Dense(action_size, tanh), giving the mean ``mu``. The log standard
        deviation is a state-independent vector initialised to 0.

        Args:
            trainable: Forwarded to every layer and to ``log_sigma``.

        Returns:
            ``(distribution, mu)``: a ``tfp.distributions.MultivariateNormalDiag``
            and its mean tensor.
        """
        layer_names = ["l1", "l2", "l3", "l4"]
        l1 = tf.layers.Dense(32, activation="relu", name=layer_names[0], trainable=trainable,
                             kernel_initializer=tf.initializers.he_normal())(self.state_ph)
        l2 = tf.layers.Dense(64, activation="relu", name=layer_names[1], trainable=trainable,
                             kernel_initializer=tf.initializers.he_normal())(l1)
        # NOTE(review): this layer's activity-regularization penalty is never
        # added to Policy.loss, so as written it does not affect training.
        # TODO: either add it to the loss or drop the regularizer.
        l3 = tf.layers.Dense(32, activation="relu", name=layer_names[2], trainable=trainable,
                             activity_regularizer=tf.contrib.layers.l2_regularizer(scale=0.001),
                             kernel_initializer=tf.initializers.he_normal())(l2)
        mu = tf.layers.Dense(self.action_size, activation="tanh", name=layer_names[3],
                             trainable=trainable,
                             kernel_initializer=tf.initializers.he_normal())(l3)
        # Free log-std parameter; exp() keeps the scale strictly positive.
        log_sigma = tf.Variable(initial_value=tf.fill((self.action_size,), 0.), trainable=trainable)
        distribution = tfp.distributions.MultivariateNormalDiag(
            loc=mu, scale_diag=tf.exp(log_sigma))
        tf.summary.histogram("log_sigma", log_sigma)
        tf.summary.histogram("mu", mu)
        # Re-enter each layer's scope to fetch the variables Dense created.
        for name in layer_names:
            with tf.variable_scope(name, reuse=True):
                tf.summary.histogram("kernel", tf.get_variable("kernel"))
                tf.summary.histogram("bias", tf.get_variable("bias"))
        return distribution, mu

    def sample_action(self, state):
        """Run one forward pass of ``pi``.

        Args:
            state: Value fed to ``state_ph``, shape ``[batch, state_size]``
                (a single state needs a leading batch axis of 1).

        Returns:
            ``[mean_action, sampled_action]`` as produced by ``sess.run``.
        """
        return self.sess.run([self.mean_action, self.sample_op], feed_dict={
            self.state_ph: state,
        })

    def train(self, states, actions, advantages):
        """Apply one Adam step on the clipped surrogate loss.

        Args:
            states: Shape ``[batch, state_size]``.
            actions: Shape ``[batch, action_size]``.
            advantages: Shape ``[batch, 1]``.

        Returns:
            The serialized merged summary for this step.
        """
        _, summaries = self.sess.run([self.optimize, self.summary_op], feed_dict={
            self.state_ph: states,
            self.action_ph: actions,
            self.advantage_ph: advantages,
        })
        return summaries

    def update_old_pi(self):
        """Copy the current ``pi`` weights into ``old_pi``."""
        self.sess.run([self.update_oldpi_op])
分析结果:
解决方案
推荐阅读
- php - 看看我是否误解了 numberformatter 或者我是否错误地使用它来解析和验证
- gitlab-ci - 如何在 GitLab 管道的各个阶段之间保留 docker 映像实例?
- java - 如何停止警报声
- css - Wordpress Menu CSS code only works when hovering over, how can i fix
- html - 100% 的 Div 百分比不适合一行
- websocket - 从 NestJS 控制器或服务发出 websocket 事件
- java - file_paths.xml 无法从字符串资源中读取
- c# - 获取使用自定义属性标记的方法的 MethodInfo
- python - 来自两个文本列的熊猫 to_datetime
- java - IllegalStateException 和 NullPointer 异常