tensorflow - 张量流图构建期间的ValueError
问题描述
我正在尝试使用 tensorflow 训练 GAN,但在图形构建过程中出现此错误:
ValueError: Input 0 of layer conv1_1 is incompatible with the layer: its rank is undefined, but the layer requires a defined rank.
我猜这是因为我的鉴别器和生成器在我调用的不同函数中。理想情况下,我希望避免将我的整个架构从方法中推出,因为这会导致一个冗长的 python 文件。
在通过 sess.run() 调用推送第一个训练示例之前,我尝试使用占位符的默认值,但这导致在训练的每个阶段都在图表中运行相同的示例(可能是因为这就是 tensorflow构建了图表)
我的训练循环代码如下。请让我知道查看生成器和鉴别器功能本身是否会有所帮助。
img1 = tf.placeholder(dtype = tf.float32)
img2 = tf.placeholder(dtype = tf.float32)
lr = tf.placeholder(dtype = tf.float32)
synthetic_imgs, synthetic_logits, mse, _ = self.gen(img1)
fake_result, fake_logits_hr, fake_feature_2 = self.disc(img1, synthetic_imgs)
ground_truth_result, ground_truth_logits_hr, truth_feature_2 = self.disc(img1, img2)
_, fake_logits_lr = self.disc_two(img1, synthetic_imgs)
_, ground_truth_logits_lr = self.disc_two(img1, img2)
a = tf.nn.sigmoid_cross_entropy_with_logits
dis_labels = tf.random.uniform((self.batch_size, 1), minval = -0.2, maxval = 0.3)
gen_labels = tf.random.uniform((self.batch_size, 1), minval = 0.75, maxval = 1.2)
dis_loss = #Discriminator Loss
gen_loss = #Generator Loss
#May want to change to -log(MSE)
d_vars = [var for var in tf.trainable_variables() if 'disc' in var.name]
g_vars = [var for var in tf.trainable_variables() if 'g_' in var.name]
dis_loss = tf.reduce_mean(dis_loss)
gen_loss = tf.reduce_mean(gen_loss)
with tf.variable_scope('optimizers', reuse = tf.AUTO_REUSE) as scope:
gen_opt = tf.train.AdamOptimizer(learning_rate = lr, name = 'gen_opt')
disc_opt = tf.train.AdamOptimizer(learning_rate = lr, name = 'dis_opt')
gen1 = gen_opt.minimize(gen_loss, var_list = g_vars)
disc1 = disc_opt.minimize(dis_loss, var_list = d_vars)
#Tensorboard code for visualizing gradients
#global_step = variable_scope.get_variable("global_step", [], trainable=False, dtype=dtypes.int64, initializer=init_ops.constant_initializer(0, dtype=dtypes.int64))
# gen_training = tf.contrib.layers.optimize_loss(gen_loss, global_step, learning_rate=lr, optimizer=gen_opt, summaries=["gradients"], variables = g_vars)
# disc_training = tf.contrib.layers.optimize_loss(dis_loss, global_step, learning_rate=lr, optimizer=disc_opt, summaries=["gradients"], variables = d_vars)
#summary = tf.summary.merge_all()
with tf.Session() as sess:
print('start session')
sess.run(tf.global_variables_initializer())
#print(tf.trainable_variables()) #Find variable corresponding to conv filter weights, which you can use for tensorboard visualization
#Code to load each training example
gen = self.pairs()
for i in range(self.num_epochs):
print(str(i+ 1) + 'th epoch')
for j in range(self.num_batches):
i_1 = None
i_2 = None
#Crektes batch
for k in range(self.batch_size):
p = next(gen)
try:
i_1 = np.concatenate((i1, self.load_img(p[0])), axis = 0)
i_2 = np.concatenate((i2, self.load_img(p[1])), axis = 0)
except Exception:
i_1 = self.load_img(p[0])
i_2 = self.load_img(p[1])
l_r = 8e-4 * (0.5)**(i//100) #Play around with this value
test, gLoss, _ = sess.run([img1, gen_loss, gen1], feed_dict = {i1 : i_1, i2 : i_2, learn_rate : l_r})
dLoss, _ = sess.run([dis_loss, disc1], feed_dict = {i1 : i_1, i2 : i_2, learn_rate : l_r})
print(test.shape)
cv2.imwrite('./saved_imgs/gan_test'+str(j)+'.png', np.squeeze(test, axis = 3)[0])
#Code to display gradients and other relevant stats on tensorboard
#Will be under histogram tab, labelled OptimizeLoss
# if j%500 == 0:
# writer = tf.summary.FileWriter(sess.graph, logdir = './tensorboard/1') #Change logdir for each run
# summary_str = sess.run(summary, feed_dict = {i1 : i_1, i2 : i_2, learn_rate : l_r})
# writer.add_summmary(summary_str, str(i)+': '+str(j))#Can change to one epoch only if necessary
# writer.flush()
if j % 12 == 0: #Prints loss statistics every 12th batch
#print('Epoch: '+str(i))
print('Generator Loss: '+str(gLoss))
print('Discriminator Loss: '+str(dLoss))
self.save_model(sess, i)
解决方案
推荐阅读
- python - SQLAlchemy - 按用户名和实体过滤
- php - 如何获取表格特定列中的项目
- python - TypeError:必须是str,而不是float(Python,返回值)
- javascript - 如何在给定 URL 的情况下获取类名的内部文本 - Javascript
- react-native - 设置权限后是否会刷新本机应用程序?(仅限 iOS)
- php - Symfony 4 DateTimeNormalizer 服务未应用
- python - **kwargs 和默认参数
- go - 带有通道的循环中的 goroutine
- amazon-web-services - 从 Amazon s3 存储桶中提取文件时随机 ssl 握手失败
- r - How can I identify the rows based on one string in a sentence