首页 > 解决方案 > 为什么 tf.get_variable('test') 返回一个名为 test_1 的变量?

问题描述

我用 tf.Variable 创建了一个名为 "test" 的 TensorFlow 变量。我想知道:为什么随后用相同的名称调用 tf.get_variable 时没有引发异常,而是创建了一个名称递增(test_1)的新变量?

import tensorflow as tf

class QuestionTest(tf.test.TestCase):
    """Reproduce the question: ``tf.get_variable`` after ``tf.Variable``.

    A variable created with ``tf.Variable`` under the name ``'test'`` does not
    make a following ``tf.get_variable('test')`` raise; that call instead
    returns a new variable named ``test_1``. Only a *second*
    ``tf.get_variable('test')`` call raises ``ValueError``.
    """

    def test_version(self):
        """Pin the TensorFlow version this behavior was observed with."""
        self.assertEqual(tf.__version__, '1.10.1')

    def test_variable(self):
        """Show that the name clash is resolved by renaming, not by raising."""
        a = tf.Variable(0., trainable=False, name='test')
        self.assertEqual(a.name, "test:0")

        # `tf.get_variable` does not treat `a` as an existing "test" variable,
        # so it creates a new one whose graph name is suffixed to "test_1".
        b = tf.get_variable('test', shape=(), trainable=False)
        self.assertEqual(b.name, "test_1:0")

        # Identity check: `a` and `b` are distinct variable objects. Using
        # `assertIsNot` avoids relying on how TF overloads `==`/`!=`.
        self.assertIsNot(a, b, msg='`a` is not `b`')

        # A second `tf.get_variable('test')` now finds "test" already
        # registered by the previous `tf.get_variable` call and raises.
        with self.assertRaises(ValueError) as ecm:
            tf.get_variable('test', shape=(), trainable=False)
        exception = ecm.exception
        self.assertStartsWith(str(exception),
                              "Variable test already exists, disallowed.")

标签: python, tensorflow

解决方案


这是因为 tf.Variable 是一个低级方法:它只把创建的变量加入 GLOBAL_VARIABLES(或 LOCAL_VARIABLES)集合;而 tf.get_variable 除了同样加入这些集合之外,还会把它创建的变量记录在变量存储(variable store)中。

当您第一次调用 tf.Variable 时,创建的变量不会被添加到变量存储中,因此变量存储会认为尚未创建过名为 "test" 的变量。

因此,当您稍后调用 tf.get_variable("test") 时,它会查看变量存储,发现其中没有名为 "test" 的变量。
于是它会调用 tf.Variable 创建一个新变量;由于图中已经存在名为 "test" 的操作,新变量会被自动重命名为 "test_1",并以键 "test" 存入变量存储。此后再调用 tf.get_variable("test") 时,变量存储中已存在键 "test",因此会抛出 ValueError。

import tensorflow as tf

class AnswerTest(tf.test.TestCase):
    """Explain the question by inspecting the variable store and collections.

    ``tf.Variable`` only adds the new variable to the global-variables
    collection. ``tf.get_variable`` additionally records the variables it
    creates in a variable store, reachable through the
    ``("__variable_store",)`` graph collection, and that store is what it
    checks for an existing name.
    """

    def test_version(self):
        """Pin the TensorFlow version this behavior was observed with."""
        self.assertEqual(tf.__version__, '1.10.1')

    def test_variable_answer(self):
        """Show the behavior under the default variable scope."""
        # Start from an empty graph: no variable store, no global variables.
        self.assertListEqual(tf.get_collection(("__variable_store",)), [],
                             "No variable store.")
        self.assertListEqual(tf.global_variables(), [],
                             "No global variables")

        # `tf.Variable` registers `a` as a global variable only; the variable
        # store has still not been created.
        a = tf.Variable(0., trainable=False, name='test')
        self.assertEqual(a.name, "test:0")
        self.assertListEqual(tf.get_collection(("__variable_store",)), [],
                             "No variable store.")
        self.assertListEqual(tf.global_variables(), [a],
                             "but `a` is in global variables.")

        # `tf.get_variable` does not find "test" in the (new) store, creates
        # `b`, and the graph suffixes its name to "test_1".
        b = tf.get_variable('test', shape=(), trainable=False)
        # Identity check; avoids relying on TF's `==`/`!=` overloading.
        self.assertIsNot(a, b, msg='`a` is not `b`')
        self.assertEqual(b.name, "test_1:0", msg="`b`'s name is not 'test'.")
        stores = tf.get_collection(("__variable_store",))
        self.assertGreater(len(stores), 0, "There is now a variable store.")
        var_store = stores[0]
        # NOTE: `_vars` is a private attribute of the variable store; it is
        # inspected here only for illustration and may differ across versions.
        self.assertDictEqual(var_store._vars, {"test": b},
                             "and variable `b` is in it.")
        self.assertListEqual(tf.global_variables(), [a, b],
                             "while `a` and `b` are in global variables.")

        # "test" is now a key in the store, so a repeated request raises.
        with self.assertRaises(ValueError) as exception_context_manager:
            tf.get_variable('test', shape=(), trainable=False)
        exception = exception_context_manager.exception
        self.assertStartsWith(str(exception),
                              "Variable test already exists, disallowed.")

使用显式的变量作用域(tf.variable_scope)时也是如此。

    def test_variable_answer_with_variable_scope(self):
        """Show the same behavior inside an explicit ``tf.variable_scope``."""
        self.assertListEqual(tf.get_collection(("__variable_store",)), [],
                             "No variable store.")

        with tf.variable_scope("my_scope"):
            # Entering the variable scope creates the store, but it is empty.
            stores = tf.get_collection(("__variable_store",))
            self.assertGreater(len(stores), 0,
                               "There is now a variable store.")
            var_store = stores[0]
            # NOTE: `_vars` is a private attribute of the variable store; it
            # is inspected here only for illustration.
            self.assertDictEqual(var_store._vars, {},
                                 "but with no variable in it.")

            # `tf.Variable` gets the scope prefix in its name but is still
            # not registered in the variable store.
            a = tf.Variable(0., trainable=False, name='test')
            self.assertEqual(a.name, "my_scope/test:0")
            var_store = tf.get_collection(("__variable_store",))[0]
            self.assertDictEqual(var_store._vars, {},
                                 "Still no variable in the store.")

            # The store key is the requested scoped name "my_scope/test",
            # while the graph name was suffixed to "my_scope/test_1".
            b = tf.get_variable('test', shape=(), trainable=False)
            self.assertEqual(b.name, "my_scope/test_1:0")
            var_store = tf.get_collection(("__variable_store",))[0]
            self.assertDictEqual(
                var_store._vars, {"my_scope/test": b},
                "`b` is in the store, but notice the difference between "
                "its name and its key in the store.")

            # "my_scope/test" is now a key in the store, so a repeat raises.
            with self.assertRaises(ValueError) as exception_context_manager:
                tf.get_variable('test', shape=(), trainable=False)
            exception = exception_context_manager.exception
            self.assertStartsWith(
                str(exception),
                "Variable my_scope/test already exists, disallowed.")

推荐阅读