python - 将展开的 GAN 更新到 TF2
问题描述
我正在尝试参照(此处链接的)示例代码实现展开 GAN(Unrolled GAN)模型。但该代码是用 TF1 实现的,我一直在尽力把它升级到 TF2,不过我对 Python 和 TF 都还比较陌生(大约只用了 6 个月)。
目前我无法完成迁移的是下面这两行(之后可能还有更多):
# Collect each network's trainable variables by name-prefix scope.
# NOTE: TF documents collections as unsupported when eager execution is
# enabled (the TF2 default), so in eager mode both lookups come back empty.
gen_vars = tf.compat.v1.get_collection(
    tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope="generator")
disc_vars = tf.compat.v1.get_collection(
    tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope="discriminator")
这两行都返回空列表,我看不出自己遗漏了什么。即使不指定作用域(scope),`get_collection()` 也返回 `[]`。此前,我已将生成器和判别器分别定义在各自的变量作用域中,如下所示:
def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    """Map noise ``z`` to ``output_dim``-dimensional samples.

    Stacks ``n_layer`` fully-connected tanh layers of width ``n_hidden``
    followed by a linear output layer, all inside the "generator" scope.
    """
    widths = [n_hidden] * n_layer
    with tf.compat.v1.variable_scope("generator"):
        hidden = slim.stack(z, slim.fully_connected, widths,
                            activation_fn=tf.nn.tanh)
        out = slim.fully_connected(hidden, output_dim, activation_fn=None)
    return out
def discriminator(x, n_hidden=128, n_layer=2, reuse=False):
    """Score ``x`` with a single linear output unit (used as a logit).

    ``n_layer`` fully-connected tanh layers of width ``n_hidden`` feed the
    output unit. Everything is built in the "discriminator" variable scope;
    ``reuse`` is forwarded to ``variable_scope`` so a later call can reuse
    the variables of an earlier one (graph-mode semantics).
    """
    widths = [n_hidden] * n_layer
    with tf.compat.v1.variable_scope("discriminator", reuse=reuse):
        features = slim.stack(x, slim.fully_connected, widths,
                              activation_fn=tf.nn.tanh)
        logit = slim.fully_connected(features, 1, activation_fn=None)
    return logit
是作用域的定义有问题吗?
这是我更新的完整代码,以防我在其他地方遗漏了什么:
%pylab inline
from collections import OrderedDict
import tensorflow as tf
import tensorflow_probability as tfp
ds = tfp.distributions
# slim = tf.contrib.slim
import tf_slim as slim
from keras.optimizers import Adam
# Movie generation is optional: if moviepy cannot be imported, warn and
# disable it instead of failing. `except Exception` is deliberately broad
# (moviepy may fail at import time for reasons other than ImportError), but
# unlike a bare `except:` it no longer swallows KeyboardInterrupt/SystemExit.
try:
    from moviepy.video.io.bindings import mplfig_to_npimage
    import moviepy.editor as mpy
    generate_movie = True
except Exception:
    print("Warning: moviepy not found.")
    generate_movie = False
def remove_original_op_attributes(graph):
    """Reset ``_original_op`` to ``None`` on every operation of ``graph``.

    Args:
        graph: object exposing ``get_operations()`` (the default TF graph
            when called from ``graph_replace``).

    Returns:
        None; the operations are modified in place.
    """
    operations = graph.get_operations()
    for operation in operations:
        operation._original_op = None
def graph_replace(*args, **kwargs):
    """Monkey patch graph_replace so that it works with TF 1.0.

    Clears ``_original_op`` on every operation of the default graph, then
    forwards all arguments unchanged to ``_graph_replace`` and returns its
    result.

    NOTE(review): ``_graph_replace`` is not defined anywhere in this code.
    Presumably it was bound to ``tf.contrib.graph_editor.graph_replace`` in
    the TF1 original, and ``tf.contrib`` no longer exists in TF2. TODO:
    provide a replacement before calling this.
    """
    # `tf.get_default_graph` was removed from the TF2 top-level namespace;
    # the graph-mode default graph is reachable via `tf.compat.v1`.
    remove_original_op_attributes(tf.compat.v1.get_default_graph())
    return _graph_replace(*args, **kwargs)
def extract_update_dict(update_ops):
    """Extract variables and their new values from Assign and AssignAdd ops.

    Args:
        update_ops: list of Assign and AssignAdd ops, typically computed using
            Keras' ``opt.get_updates()``. NOTE(review): TF2 Keras optimizers
            may return a single apply-gradients op from ``get_updates``
            instead of per-variable Assign/AssignAdd ops; TODO confirm.

    Returns:
        OrderedDict mapping each target variable's ``value()`` tensor to its
        updated value: the assigned tensor for ``Assign``, or
        ``var + delta`` for ``AssignAdd``. Order follows ``update_ops``.

    Raises:
        KeyError: if an op's target is not among the global variables.
        ValueError: if an op is neither ``Assign`` nor ``AssignAdd``.
    """
    name_to_var = {v.name: v for v in tf.compat.v1.global_variables()}
    updates = OrderedDict()
    for update in update_ops:
        op = update.op
        # inputs[0] is the variable being written, inputs[1] the new value
        # (Assign) or the increment (AssignAdd).
        var = name_to_var[op.inputs[0].name]
        value = op.inputs[1]
        if op.type == 'Assign':
            updates[var.value()] = value
        elif op.type == 'AssignAdd':
            updates[var.value()] = var + value
        else:
            # Report the type of the op actually being inspected; the message
            # must not reference an undefined name, or NameError masks it.
            raise ValueError(
                "Update op type (%s) must be of type Assign or AssignAdd" % op.type)
    return updates
def sample_mog(batch_size, n_mixture=8, std=0.01, radius=1.0):
    """Sample ``batch_size`` 2-D points from a ring-shaped mixture of Gaussians.

    Args:
        batch_size: number of samples to draw.
        n_mixture: number of mixture components (distinct modes).
        std: per-axis standard deviation of each component.
        radius: radius of the circle the component means lie on.

    Returns:
        The sample tensor produced by ``Mixture.sample(batch_size)``.
    """
    # endpoint=False: with numpy's default endpoint=True both 0 and 2*pi are
    # included, so the first and last means coincide and only n_mixture - 1
    # distinct modes are produced.
    thetas = np.linspace(0, 2 * np.pi, n_mixture, endpoint=False)
    xs, ys = radius * np.sin(thetas), radius * np.cos(thetas)
    # All-zero logits -> equally weighted components.
    cat = ds.Categorical(tf.zeros(n_mixture))
    comps = [ds.MultivariateNormalDiag([xi, yi], [std, std])
             for xi, yi in zip(xs.ravel(), ys.ravel())]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size)
def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    """Map the noise batch ``z`` to ``output_dim``-dimensional samples.

    Builds ``n_layer`` fully-connected tanh layers of width ``n_hidden`` and a
    final linear layer, all created in the "generator" variable scope.
    """
    hidden_widths = [n_hidden] * n_layer
    with tf.compat.v1.variable_scope("generator"):
        features = slim.stack(
            z, slim.fully_connected, hidden_widths, activation_fn=tf.nn.tanh)
        samples_out = slim.fully_connected(
            features, output_dim, activation_fn=None)
    return samples_out
def discriminator(x, n_hidden=128, n_layer=2, reuse=False):
    """Return the discriminator's raw (pre-sigmoid) score for ``x``.

    ``n_layer`` fully-connected tanh layers of width ``n_hidden`` feed one
    linear output unit; the result is used as logits by the loss. Variables
    live in the "discriminator" scope, and ``reuse`` is passed straight to
    ``variable_scope`` so the second call can share the first call's weights
    (graph-mode semantics).
    """
    hidden_widths = [n_hidden] * n_layer
    with tf.compat.v1.variable_scope("discriminator", reuse=reuse):
        features = slim.stack(
            x, slim.fully_connected, hidden_widths, activation_fn=tf.nn.tanh)
        score = slim.fully_connected(features, 1, activation_fn=None)
    return score
# Hyper-parameters for the unrolled-GAN experiment.
params = {
    "batch_size": 512,
    # Adam settings: the generator's learning rate is 10x the discriminator's.
    "disc_learning_rate": 1e-4,
    "gen_learning_rate": 1e-3,
    "beta1": 0.5,
    "epsilon": 1e-8,
    # Training length and visualization cadence (in iterations).
    "max_iter": 25000,
    "viz_every": 5000,
    # Noise and data dimensionality.
    "z_dim": 256,
    "x_dim": 2,
    # Number of discriminator steps unrolled for the generator's objective.
    "unrolling_steps": 5,
}
# Build everything in TF1-style graph mode. TF2 runs eagerly by default, and
# TF documents collections as unsupported under eager execution. That is why
# the `get_collection(...)` calls below returned [] and why
# `variable_scope(..., reuse=True)` did not share the discriminator's
# variables. This must run before any graph, op or tensor is created.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

data = sample_mog(params['batch_size'])
noise = ds.Normal(tf.zeros(params['z_dim']),
                  tf.ones(params['z_dim'])).sample(params['batch_size'])

# Construct generator and discriminator nets; the second discriminator call
# reuses the variables of the first.
with slim.arg_scope([slim.fully_connected],
                    weights_initializer=tf.keras.initializers.Orthogonal(gain=1.4)):
    samples = generator(noise, output_dim=params['x_dim'])
    real_score = discriminator(data)
    fake_score = discriminator(samples, reuse=True)

# Saddle objective: the discriminator minimizes this loss, and the generator
# maximizes it (it minimizes -unrolled_loss below).
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=tf.cast(real_score, dtype=tf.float32),
        labels=tf.cast(tf.ones_like(real_score), dtype=tf.float32)) +
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=tf.cast(fake_score, dtype=tf.float32),
        labels=tf.cast(tf.zeros_like(fake_score), dtype=tf.float32)))

gen_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, "generator")
disc_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, "discriminator")

# Vanilla discriminator update. `lr` is a deprecated alias of `learning_rate`.
# The second argument of `get_updates` is the list of parameters to update;
# the previous `[]` meant no discriminator variable was ever trained.
d_opt = Adam(learning_rate=params['disc_learning_rate'], beta_1=params['beta1'],
             epsilon=params['epsilon'])
updates = d_opt.get_updates(loss, disc_vars)
d_train_op = tf.group(*updates, name="d_train_op")

# NOTE(review): the unrolling path below is still TF1-specific and not
# verified on TF2. `graph_replace` delegates to `_graph_replace`, which this
# code never defines, and `extract_update_dict` expects per-variable
# Assign/AssignAdd ops, which TF2 Keras' `get_updates` may not return. TODO.
if params['unrolling_steps'] > 0:
    # Map each discriminator variable to its value after one optimizer step.
    update_dict = extract_update_dict(updates)
    cur_update_dict = update_dict
    for _ in range(params['unrolling_steps'] - 1):
        # Re-express the one-step update in terms of the previous step's
        # updated variables, unrolling one more step each iteration.
        cur_update_dict = graph_replace(update_dict, cur_update_dict)
    # Final unrolled loss uses the parameters at the last time step.
    unrolled_loss = graph_replace(loss, cur_update_dict)
else:
    unrolled_loss = loss

# Optimize the generator on the unrolled loss (TF1 optimizer via compat.v1).
g_train_opt = tf.compat.v1.train.AdamOptimizer(
    params['gen_learning_rate'], beta1=params['beta1'], epsilon=params['epsilon'])
g_train_op = g_train_opt.minimize(-unrolled_loss, var_list=gen_vars)

sess = tf.compat.v1.InteractiveSession()
sess.run(tf.compat.v1.global_variables_initializer())
解决方案
`get_collection` 的实现:
def get_collection(key, scope=None):
    """Wrapper for `Graph.get_collection()` using the default graph.

    See `tf.Graph.get_collection` for more details.

    Args:
      key: The key for the collection. For example, the `GraphKeys` class
        contains many standard names for collections.
      scope: (Optional.) If supplied, the resulting list is filtered to
        include only items whose `name` attribute matches `scope` using
        `re.match`. Items without a `name` attribute are never returned if a
        scope is supplied. The choice of `re.match` means that a `scope`
        without special tokens filters by prefix.

    Returns:
      The list of values in the collection with the given `name`, or an empty
      list if no value has been added to that collection. The list contains
      the values in the order under which they were collected.

    @compatibility(eager)
    Collections are not supported when eager execution is enabled.
    @end_compatibility
    """
    default_graph = get_default_graph()
    return default_graph.get_collection(key, scope)
看起来这段代码中的 `key` 和 `scope` 参数被交换了。如果只把 `"generator"` 或 `"discriminator"` 作为 `key` 传入、不指定 scope,即:
# Look the collections up by key alone (no scope filter). These return only
# what was explicitly added to collections named "generator" and
# "discriminator", e.g. via tf.compat.v1.add_to_collection.
gen_vars = tf.compat.v1.get_collection(key="generator")
disc_vars = tf.compat.v1.get_collection(key="discriminator")
你应该就能得到结果(我使用 TensorFlow 2.2.0 在本地复现了这一点)。我唯一无法完全确定的问题是:只要提供了 `scope`,无论取什么值,该函数都会再次返回空列表。例如,传入 `tf.compat.v1.GraphKeys.GLOBAL_VARIABLES` 应该返回所有内容,但实际并非如此:
# A GraphKeys constant passed as `scope` acts as a regex filter on item names
# (not as a collection key), which is presumably why these two come back empty.
gen_vars = tf.compat.v1.get_default_graph().get_collection(
    'generator', scope=tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)  # returns []
gen_vars = tf.compat.v1.get_default_graph().get_collection(
    'generator', scope=tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)  # returns []
disc_vars = tf.compat.v1.get_collection(
    'generator')  # returns a list of tensors
更新
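看起来即使在变量作用域(上下文管理器)中创建变量,也不会将它们添加到图的集合中。我必须在各自的函数中分别调用 `tf.compat.v1.add_to_collection('generator', x)` 和 `tf.compat.v1.add_to_collection('discriminator', log_d)`,才能得到上述结果。(注意:这样加入集合的是网络的输出张量 `x` 和 `log_d`,而不是可训练变量;此外,上面引用的 `get_collection` 文档指出,启用 eager execution(TF2 的默认模式)时不支持集合,这很可能才是原代码返回空列表的根本原因。)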
更新#2
我四处搜索了一下,似乎没有现成的上下文管理器能把其中声明的变量添加到 TensorFlow 集合中。为了让这个回答更完整,我自己实现了一个:
from contextlib import contextmanager


@contextmanager
def collection_scope(collection_name):
    """Add eager tensors bound inside a ``with`` block to a TF collection.

    On exit, every ``EagerTensor`` bound to a local name of the function that
    contains the ``with`` statement is appended to the collection
    ``collection_name``. Tensors that the caller of that function also holds
    in its own locals (e.g. the input argument) are excluded.

    This is a frame-inspection hack. It only sees tensors bound to local
    names (not unnamed intermediates, and not ``tf.Variable`` objects). If the
    ``with`` body raises, nothing is added.

    Args:
        collection_name: name of the collection, as passed to
            ``tf.compat.v1.get_collection_ref``.
    """
    import inspect
    from tensorflow.python.framework.ops import EagerTensor

    def _eager_locals(frame):
        # Map ref() -> tensor for each EagerTensor among `frame`'s locals.
        # A missing frame (e.g. no caller at script top level) contributes
        # nothing instead of raising AttributeError.
        if frame is None:
            return {}
        return {value.ref(): value for value in frame.f_locals.values()
                if isinstance(value, EagerTensor)}

    collection = tf.compat.v1.get_collection_ref(collection_name)
    yield

    # Two hops up from this generator frame: contextlib's __exit__, then the
    # function whose `with` statement used this manager.
    frame = inspect.currentframe().f_back.f_back
    try:
        created = _eager_locals(frame)
        preexisting = _eager_locals(frame.f_back)
    finally:
        # Drop the frame reference to avoid a reference cycle
        # (recommended by the `inspect` docs).
        del frame
    # Store the tensors themselves (ref() wrappers are only for hashing), in
    # the order they appear among the function's locals.
    collection.extend(tensor for ref, tensor in created.items()
                      if ref not in preexisting)
然后,可以在函数中用它来代替变量作用域(`tf.compat.v1.variable_scope`)上下文管理器。例如,不要这样写:
def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    """Generator net built inside a 'generator' variable scope.

    ``n_layer`` fully-connected tanh layers of width ``n_hidden`` followed by
    a linear layer producing ``output_dim`` outputs.
    """
    layer_sizes = [n_hidden] * n_layer
    with tf.compat.v1.variable_scope('generator'):
        hidden = slim.stack(
            z, slim.fully_connected, layer_sizes, activation_fn=tf.nn.tanh)
        result = slim.fully_connected(hidden, output_dim, activation_fn=None)
    return result
而是这样写:
def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    """Generator net whose locally-bound eager tensors go to 'generator'.

    ``n_layer`` fully-connected tanh layers of width ``n_hidden`` followed by
    a linear layer producing ``output_dim`` outputs.
    """
    # collection_scope inspects this function's locals on exit, so no extra
    # tensor-valued locals are introduced here (the width list is plain data).
    layer_sizes = [n_hidden] * n_layer
    with collection_scope('generator'):
        hidden = slim.stack(
            z, slim.fully_connected, layer_sizes, activation_fn=tf.nn.tanh)
        result = slim.fully_connected(hidden, output_dim, activation_fn=None)
    return result
通过此更改,在上下文管理器中声明的所有张量都会被添加到名为 `'generator'` 的集合中,`tf.compat.v1.get_collection('generator')` 将返回正确的张量列表。
推荐阅读
- ios - 使用 UITableView 的 1 行中的 2 列
- java - 为什么调用 Statement.close() 不会立即释放 Statement 对象创建的 ResultSet 对象?
- arrays - NumPy Array: Minesweeper - 替换随机项
- whatsapi - 使用点击聊天发送 whatsapp 消息
- c++ - SDL2 图像 - IMG_Init() 返回 0,IMG_GetError() 为空
- c++ - 如何创建一个接受二维数组的构造函数
- reactjs - 反应组件更新/渲染 - 未调用 componentDidUpdate
- javascript - 渲染之间调试数组的不同符号 React
- ios - 如何为自定义单元格内的按钮设置侦听器以快速获取文本字段数据?
- apache-spark - 有没有办法在 pyspark 中获取列数据类型?