python - 为什么 tf.get_variable('test') 返回一个名为 test_1 的变量?
问题描述
我创建了一个 tensorflow 变量tf.Variable
。我想知道为什么如果我tf.get_variable
用相同的名称调用没有引发异常并且使用递增的名称创建一个新变量?
import tensorflow as tf
class QuestionTest(tf.test.TestCase):
def test_version(self):
self.assertEqual(tf.__version__, '1.10.1')
def test_variable(self):
a = tf.Variable(0., trainable=False, name='test')
self.assertEqual(a.name, "test:0")
b = tf.get_variable('test', shape=(), trainable=False)
self.assertEqual(b.name, "test_1:0")
self.assertNotEqual(a, b, msg='`a` is not `b`')
with self.assertRaises(ValueError) as ecm:
tf.get_variable('test', shape=(), trainable=False)
exception = ecm.exception
self.assertStartsWith(str(exception), "Variable test already exists, disallowed.")
解决方案
这是因为这tf.Variable
是一种低级方法,它将创建的变量存储在 GLOBALS(或 LOCALS)集合中,同时tf.get_variable
通过将它们存储在变量存储中来记录它创建的变量。
当您第一次调用tf.Variable
时,创建的变量不会添加到变量存储中,让您认为没有"test"
创建具有名称的变量。
因此,当您稍后调用tf.get_variable("test")
它时,它会查看变量存储区,发现其中没有带有名称"test"
的变量。
它将因此调用tf.Variable
,这将创建一个具有递增名称的变量,该名称"test_1"
存储在键下的变量存储中"test"
。
import tensorflow as tf
class AnswerTest(tf.test.TestCase):
def test_version(self):
self.assertEqual(tf.__version__, '1.10.1')
def test_variable_answer(self):
"""Using the default variable scope"""
# Let first check the __variable_store and the GLOBALS collections.
self.assertListEqual(tf.get_collection(("__variable_store",)), [],
"No variable store.")
self.assertListEqual(tf.global_variables(), [],
"No global variables")
a = tf.Variable(0., trainable=False, name='test')
self.assertEqual(a.name, "test:0")
self.assertListEqual(tf.get_collection(("__variable_store",)), [],
"No variable store.")
self.assertListEqual(tf.global_variables(), [a],
"but `a` is in global variables.")
b = tf.get_variable('test', shape=(), trainable=False)
self.assertNotEqual(a, b, msg='`a` is not `b`')
self.assertEqual(b.name, "test_1:0", msg="`b`'s name is not 'test'.")
self.assertTrue(len(tf.get_collection(("__variable_store",))) > 0,
"There is now a variable store.")
var_store = tf.get_collection(("__variable_store",))[0]
self.assertDictEqual(var_store._vars, {"test": b},
"and variable `b` is in it.")
self.assertListEqual(tf.global_variables(), [a, b],
"while `a` and `b` are in global variables.")
with self.assertRaises(ValueError) as exception_context_manager:
tf.get_variable('test', shape=(), trainable=False)
exception = exception_context_manager.exception
self.assertStartsWith(str(exception),
"Variable test already exists, disallowed.")
使用显式变量范围时也是如此。
def test_variable_answer_with_variable_scope(self):
"""Using now a variable scope"""
self.assertListEqual(tf.get_collection(("__variable_store",)), [],
"No variable store.")
with tf.variable_scope("my_scope") as scope:
self.assertTrue(len(tf.get_collection(("__variable_store",))) > 0,
"There is now a variable store.")
var_store = tf.get_collection(("__variable_store",))[0]
self.assertDictEqual(var_store._vars, {},
"but with variable in it.")
a = tf.Variable(0., trainable=False, name='test')
self.assertEqual(a.name, "my_scope/test:0")
var_store = tf.get_collection(("__variable_store",))[0]
self.assertDictEqual(var_store._vars, {},
"Still no variable in the store.")
b = tf.get_variable('test', shape=(), trainable=False)
self.assertEqual(b.name, "my_scope/test_1:0")
var_store = tf.get_collection(("__variable_store",))[0]
self.assertDictEqual(
var_store._vars, {"my_scope/test": b},
"`b` is in the store, but notice the difference between its name and its key in the store.")
with self.assertRaises(ValueError) as exception_context_manager:
tf.get_variable('test', shape=(), trainable=False)
exception = exception_context_manager.exception
self.assertStartsWith(str(exception),
"Variable my_scope/test already exists, disallowed.")
推荐阅读
- java - 为什么我的第二种方法的模拟会影响第一种方法?
- azure - 从 Azure 表存储中获取实体(限制 1)
- json - 无法使用 MasterPage 访问 webform aspx 中的隐藏字段值
- symfony - Symfony:扩展实体而不创建表
- blazor - 如何在 blazor Web 程序集中运行捆绑的 javascript 文件?
- php - SMTP 通知:在检查是否已连接时捕获到 EOF
- node.js - WSL 和 WSL2 上的 NodeJs 服务器无法在浏览器上访问
- php - 未定义变量:Blade Laravel 7.x 中的横幅
- java - 需要特定时间格式的 XMLGregorianCalendar| 爪哇
- python - 将文本格式的 PNG 转换回文件对象