Persistent GradientTape after construction

Problem description

Is it possible to make a GradientTape persistent after it has been constructed?

The use case: code I don't control hands me a non-persistent tape, and I need to take two gradients from it, with a dependency between them.

# I don't control this
with tf.GradientTape() as tape:
    y = f(x)
    z = g(y)

# I control this
dzdx = tape.gradient(z, x)
result = tape.gradient(z, y, output_gradients=dzdx)  # fails: tape is not persistent
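For reference, a minimal reproduction of the failure (with toy stand-ins for f and g, since the original functions aren't shown): a non-persistent tape releases its backward functions during the first gradient call, so the second call raises a RuntimeError.

```python
import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:  # non-persistent, as in the question
    y = x * x        # stand-in for f(x)
    z = tf.sin(y)    # stand-in for g(y)

dzdx = tape.gradient(z, x)  # first call succeeds

second_call_error = None
try:
    tape.gradient(z, y)     # second call on a non-persistent tape
except RuntimeError as e:
    second_call_error = e

print(second_call_error)
```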

I considered simply setting tape._persistent before calling gradient, but the persistent flag is passed all the way down to the C++ code at construction time, and the C++ side may not tolerate being out of sync with the Python attribute:

def push_new_tape(persistent=False, watch_accessed_variables=True):
  """Pushes a new tape onto the tape stack."""
  tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables)
  return Tape(tape)
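So the flag has to be supplied up front. For contrast, here is a sketch of the intended usage when you do control construction (again with toy stand-ins for f and g):

```python
import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape(persistent=True) as tape:
    y = x * x        # stand-in for f(x)
    z = tf.sin(y)    # stand-in for g(y)

dzdx = tape.gradient(z, x)  # first gradient: d(sin(x^2))/dx = 2x*cos(x^2)
# Second gradient, weighted by the first via output_gradients; this is
# only legal because the tape was constructed persistent.
result = tape.gradient(z, y, output_gradients=dzdx)
del tape  # release the retained backward functions
```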

which ultimately reaches:

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
 public:
  // If `persistent` is true, GradientTape will not eagerly delete backward
  // functions (and hence the tensors they keep alive). Instead, everything
  // is deleted in ~GradientTape. Persistent GradientTapes are useful when
  // users want to compute multiple gradients over the same tape.
  explicit GradientTape(bool persistent) : persistent_(persistent) {}
  ~GradientTape() {
    for (const auto& pair : op_tape_) {
      pair.second.backward_function_deleter(pair.second.backward_function);
    }
  }

  // ...

  bool IsPersistent() const { return persistent_; }

where IsPersistent is the only public accessor for persistent_. I don't know exactly how persistent_ is used internally, but flipping it after the fact sounds like a memory leak waiting to happen.
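The comment in the C++ header matches the documented Python behavior: a persistent tape keeps every backward function (and the tensors they reference) alive until the tape object itself is destroyed, which is why the docs recommend deleting it promptly. A small illustration:

```python
import tensorflow as tf

x = tf.Variable(2.0)

with tf.GradientTape(persistent=True) as tape:
    y = x * x * x

# Each gradient call replays the retained backward functions.
g1 = tape.gradient(y, x)  # 3 * x**2 = 12.0
g2 = tape.gradient(y, x)  # same result; the tape state is intact
del tape  # triggers ~GradientTape, freeing the backward functions
```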

Tags: python, c++, tensorflow

Solution

