python - tensorflow while循环减慢
问题描述
问题是一个 tensorflow while 循环(tf.while_loop)随着时间的推移而减慢。该循环应该返回一些矩阵。我通过字典提供所有输入。
我知道这个问题很可能是由于一遍又一遍地添加操作而污染了图表。我是 TF 初学者,对我来说,这里的图表受到污染的原因并不明显。非常感谢任何帮助。
def predict(self, actions, ...):
feed_dict = {
self.agent.actions: actions.reshape(-1, self.kwargs["dim_actions"]),
...
}
states_mu, states_var = self.session.run(self.agent.predict_states(), feed_dict=feed_dict)
return states_mu, states_var
def predict_states(self):
...
def loop_cond(i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov):
return i < self.episode_length
def loop_body(i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov):
state_mu_i = state_mus[-1][None, :]
...
state_var_tf = state_vars_tf[-1][None, :, :]
#Some math operations
...
new_state_mu = state_mu_i + delta_mu
new_state_var = state_var_i + delta_var + inp_out_cov
new_mu_tf, new_var_tf, inp_tf_cov = some_transform(
new_state_mu, ....)
state_mus = tf.concat([state_mus, new_state_mu], 0)
...
state_vars_tf = tf.concat([state_vars_tf, new_var_tf], 0)
i += 1
return i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov
loop_step = tf.constant(0, tf.int32)
init_mus_tf, init_vars_tf, inp_tf_cov = some_transform(
self.state_mu, self.state_var, self.dim_angles)
loop_vars = [
loop_step,
self.state_mu,
self.state_var,
init_mus_tf,
init_vars_tf,
inp_tf_cov]
shapes = [loop_step.get_shape(),
tf.TensorShape([None, self.dim_states]),
tf.TensorShape([None, self.dim_states, self.dim_states]),
tf.TensorShape([None, self.dim_states_tf]),
tf.TensorShape([None, self.dim_states_tf, self.dim_states_tf]),
inp_tf_cov.get_shape()]
_, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov = tf.while_loop(
loop_cond,
loop_body,
loop_vars=loop_vars,
shape_invariants=shapes)
return state_mus_tf[1:], state_vars_tf[1:]
该循环被多次调用。它在一次运行中减慢,即在每次迭代之后,甚至在重复调用之后更慢。每次运行的迭代速度从上次运行结束的地方开始。例如,在第一次运行开始时,每次迭代需要 1 秒,在第一次运行结束时,每次迭代需要 3 秒。在第二次运行开始时,每次迭代需要 3 秒,......直到让它运行变得不可行(例如,每次迭代需要 100 秒)。
解决方案
代码看起来大部分都很好,但是您应该predict_states
只调用一次,当您创建类的实例时(或在其他一些初始化步骤中),并将返回值存储在类属性中。例如:
def __init__(self, ...):
# ...
self.states_mu_tf, self.states_var_tf = self.agent.predict_states()
然后,您使用这些属性predict
:
states_mu, states_var = self.session.run((self.states_mu_tf, self.states_var_tf),
feed_dict=feed_dict)
这样您就不会在图中重新创建操作。
推荐阅读
- oracle - 是否有用于在 oracle erp 中重新分配采购订单/申请的 API?
- amazon-web-services - 有没有办法为在 AWS RDS 数据库实例中创建的每个数据库单独指定加密密钥(AWS KMS 密钥)?
- firebase-realtime-database - 无法在我的项目中获取用户访问令牌 facebook(使用 firebase),或者可以获取但无法使其工作
- python - 以最佳性能将 Pandas 数据帧转换为 csv
- asp.net-core - 当 Azure 应用服务横向扩展我的 Web API 时,SemaphoreSlim 会保护代码不同时运行两次吗?
- php - 当使用父子 foreach 循环添加新数组时,在子 foreach 循环中保留选中的复选框 - Laravel livewire
- firebase - 使用外部 IdP (github) 登录到 Firebase 应用程序后存储用户的刷新令牌
- php - php eval() 获取对象的属性
- sql - 如何在一个点(lon,lat)postgis附近快速获得13条记录线
- javascript - 从 DIV 中的网页和 JavaScript 中的文件获取 base64_decode($artwork) 并比较它们