首页 > 解决方案 > 如何在 TensorFlow 2 中重置初始化

问题描述

如果我在初始化 a 后尝试更改 TensorFlow 2 中的并行性tf.Variable

import tensorflow as tf
_ = tf.Variable([1])
tf.config.threading.set_inter_op_parallelism_threads(1)

我收到一个错误

RuntimeError:初始化后无法修改操作间并行度。

我理解为什么会这样,但它(可能还有其他因素)导致我的测试相互干扰。例如

def test_model():  # this test
   v = tf.Variable([1])
   ...

def test_threading():  # is breaking this test
   tf.config.threading.set_inter_op_parallelism_threads(1)
   ...

如何重置 TensorFlow 状态以便设置线程?

标签: pythontensorflowpytesttensorflow2.0

解决方案


这是可以通过“hacky”方式实现的。但我建议以正确的方式执行此操作(即在开始时设置配置)。

import tensorflow as tf
from tensorflow.python.eager import context

_ = tf.Variable([1])

context._context = None
context._create_context()

tf.config.threading.set_inter_op_parallelism_threads(1)

编辑:一开始就设置配置是什么意思,

import tensorflow as tf
from tensorflow.python.eager import context

tf.config.threading.set_inter_op_parallelism_threads(1)
_ = tf.Variable([1])

但在某些情况下,您不能总是这样做。仅指出在tf. 因此,如果您的情况不允许您tf.config在一开始就进行修复,则必须tf.eager.context按照上面的解决方案重新设置。


推荐阅读