Jointly training multiple different datasets in TensorFlow

Problem description

I am trying to create a single network that is updated by several different datasets.

The code is as follows:

# Network
import tensorflow as tf

def UNet(args):
  with tf.variable_scope('UNet',reuse=args.reuse):

    input_imgs_ph=tf.placeholder(
      tf.float32,
      [None,
      int(args.input_size),
      int(args.input_size),
      3])

    # ============================================================
    # conv2d
    o_conv1=tf.layers.conv2d(
      inputs=input_imgs_ph,
      filters=100,
      kernel_size=[3,3],
      strides=1,
      padding="same",
      kernel_regularizer=tf.contrib.layers.l2_regularizer(reg))  # reg: L2 strength, defined elsewhere
    o_conv1=tf.nn.leaky_relu(o_conv1,alpha=0.2,name='leaky_relu_after_conv1')  
    o_conv1=tf.keras.layers.AveragePooling2D((2,2),name='avgpool_after_conv1')(o_conv1)

    # ... (remaining layers; end_conv_r and end_conv_s are produced here)

    return input_imgs_ph,end_conv_r,end_conv_s

# ============================================================
# train_dataset1.py
# Train algorithm over dataset1
def train(mpisintel_d_bs_pa,args):
  # ============================================================
  with tf.variable_scope("UNet"):
    # Perform augmentation
    # ...

    # ============================================================
    # Create placeholders
    # ...

    # ============================================================
    # Construct network graph structure
    input_imgs_ph,end_conv_r,end_conv_s=networks.UNet(args)
    # ...

    # ============================================================
    # Loss function
    # ...

    # ============================================================
    # optimizer: AdamOptimizer node
    optimizer=tf.train.AdamOptimizer(0.001).minimize(r_data_loss)

    # ============================================================
    # init: initializer node which initializes the global Variables
    init=tf.global_variables_initializer()

    # ============================================================
    feed_dict={
      input_imgs_ph:cgmit_tr_3c_imgs,
      cgmit_gt_R_3c_imgs_ph:cgmit_gt_R_3c_imgs,
      cgmit_gt_S_1c_imgs_ph:cgmit_gt_S_1c_imgs,
      cgmit_mask_3c_imgs_ph:cgmit_mask_3c_imgs}

  return init,feed_dict,r_data_loss,optimizer

# ============================================================
# train_dataset2.py
# Train algorithm over dataset2
def train(iiw_d_bs_pa,args):
  # ============================================================
  with tf.variable_scope("UNet"):
    # Perform augmentation
    # ...

    # ============================================================
    # Create placeholders
    # ...

    # ============================================================
    # Construct network graph structure
    input_imgs_ph,end_conv_r,end_conv_s=networks.UNet(args)
    # ...

    # ============================================================
    # Loss function
    # ...

    # ============================================================
    # optimizer: AdamOptimizer node
    optimizer=tf.train.AdamOptimizer(0.001).minimize(r_data_loss)

    # ============================================================
    # init: initializer node which initializes the global Variables
    init=tf.global_variables_initializer()

    # ============================================================
    feed_dict={
      input_imgs_ph:cgmit_tr_3c_imgs,
      cgmit_gt_R_3c_imgs_ph:cgmit_gt_R_3c_imgs,
      cgmit_gt_S_1c_imgs_ph:cgmit_gt_S_1c_imgs,
      cgmit_mask_3c_imgs_ph:cgmit_mask_3c_imgs}

  return init,feed_dict,r_data_loss,optimizer
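
Note that tf.global_variables_initializer() only covers variables that already exist at the moment it is called, so the init op returned by the first train() will not initialize anything that the second train() adds to the graph afterwards. A minimal sketch demonstrating this (the toy variables and names are illustrative):

import tensorflow as tf

v1 = tf.get_variable('v1', shape=[1])
init_early = tf.global_variables_initializer()  # snapshot: covers only v1

v2 = tf.get_variable('v2', shape=[1])
init_late = tf.global_variables_initializer()   # snapshot: covers v1 and v2

with tf.Session() as sess:
  sess.run(init_late)        # both variables initialized; running
  print(sess.run([v1, v2]))  # init_early instead would leave v2 uninitialized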


# ============================================================
# train_jointly.py

# --------------------------------------------------------------------------------
# Create iterator for dataset1 (mpisintel dataset)
path_data_mpisintel=tf.data.Dataset.from_tensor_slices(mpisintel_d_list)
path_data_mpisintel=path_data_mpisintel.prefetch(
    buffer_size=int(args.batch_size)).batch(int(args.batch_size)).repeat()
path_data_mpisintel_iter=path_data_mpisintel.make_one_shot_iterator()

# Construct train algorithm graph for dataset1
init_dense,feed_dict_dense,r_data_loss_dense,optimizer_dense=\
train_over_dense_dataset.train(mpisintel_d_bs_pa,args)

# --------------------------------------------------------------------------------
# Create iterator for dataset2 (iiw dataset)
path_data_iiw=tf.data.Dataset.from_tensor_slices(iiw_d_list)
path_data_iiw=path_data_iiw.prefetch(
    buffer_size=int(args.batch_size)).batch(int(args.batch_size)).repeat()
path_data_iiw_iter=path_data_iiw.make_one_shot_iterator()

# Construct train algorithm graph for dataset2
init_iiw,feed_dict_iiw,r_data_loss_iiw,optimizer_iiw=\
train_over_iiw_dataset.train(iiw_d_bs_pa,args)

# ============================================================
# Train network over epochs
for one_ep in range(epoch):  # epoch: number of epochs, defined elsewhere
  # Train network over batch of dataset1
  with tf.Session() as sess_dense:
    # Initialize Variables
    sess_dense.run(init_dense)

    mpisintel_d_bs_pa=path_data_mpisintel_iter.get_next()
    loaded_input_img=load_image_from_paths(mpisintel_d_bs_pa)

    loss_val,optim=sess_dense.run(
      [r_data_loss_dense,optimizer_dense],
      feed_dict=loaded_input_img)
    # print("loss_val",loss_val)
    # loss_val 9.534523

  # ============================================================
  # Train network over batch of dataset2
  with tf.Session() as sess_iiw:
    # Initialize Variables
    sess_iiw.run(init_iiw)

    iiw_d_bs_pa=path_data_iiw_iter.get_next()
    loaded_input_img=load_image_from_paths(iiw_d_bs_pa)

    loss_val,optim=sess_iiw.run(
      [r_data_loss_iiw,optimizer_iiw],
      feed_dict=loaded_input_img)
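
A side note on the loop above: make_one_shot_iterator().get_next() returns a graph tensor rather than Python values, and calling get_next() inside the loop adds a new op to the graph on every iteration. A minimal sketch of the usual pattern, creating the op once and evaluating it per step (the file names are illustrative):

import tensorflow as tf

paths = tf.data.Dataset.from_tensor_slices(['a.png','b.png','c.png','d.png'])
paths = paths.batch(2).repeat()
paths_iter = paths.make_one_shot_iterator()

# Create the get_next op ONCE, outside the training loop.
next_paths = paths_iter.get_next()

with tf.Session() as sess:
  for step in range(3):
    # sess.run evaluates the tensor into actual byte strings.
    batch_paths = sess.run(next_paths)  # e.g. [b'a.png' b'b.png']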

Questions:
1. I tried to construct the following network multiple times,

def UNet(args):
  with tf.variable_scope('UNet',reuse=args.reuse):

once inside each training-algorithm graph:

train_over_dense_dataset.train(mpisintel_d_bs_pa,args)
train_over_iiw_dataset.train(iiw_d_bs_pa,args)

This produces an error.
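
In TF 1.x this typically happens because the second UNet(args) call tries to create variables that already exist in the 'UNet' scope. A minimal sketch of the usual sharing pattern with reuse=tf.AUTO_REUSE, which creates the variables on the first call and reuses them on later calls (the build_unet helper, shapes, and layer name are illustrative, not from the original code):

import tensorflow as tf

def build_unet(x):
  # AUTO_REUSE: create the variables on the first call, reuse the
  # very same ones on every later call instead of raising an error.
  with tf.variable_scope('UNet', reuse=tf.AUTO_REUSE):
    return tf.layers.conv2d(x, filters=100, kernel_size=[3,3],
                            padding='same', name='conv1')

x1 = tf.placeholder(tf.float32, [None, 224, 224, 3])
x2 = tf.placeholder(tf.float32, [None, 224, 224, 3])
out1 = build_unet(x1)  # creates UNet/conv1/kernel and UNet/conv1/bias
out2 = build_unet(x2)  # reuses those weights; no "already exists" error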

2. So I tried tf.reset_default_graph(), but then it is hard to build the training setup I want: for example, I end up having to recreate the entire graph inside the for loop, and I cannot find any other way to achieve my intent.
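
One common way to get this architecture without tf.reset_default_graph() is to build both per-dataset losses and optimizers once, in a single graph that shares the network weights (e.g. via AUTO_REUSE as above), create one initializer after everything is built, and alternate the two update ops inside a single Session. A minimal, self-contained sketch under those assumptions (the tiny linear net, toy losses, and dummy batches are illustrative stand-ins for the original UNet and datasets):

import tensorflow as tf

def shared_net(x):
  # One set of weights shared by both dataset branches.
  with tf.variable_scope('UNet', reuse=tf.AUTO_REUSE):
    w = tf.get_variable('w', shape=[4, 1])
    b = tf.get_variable('b', shape=[1], initializer=tf.zeros_initializer())
    return tf.matmul(x, w) + b

x_dense = tf.placeholder(tf.float32, [None, 4])   # dataset1 input
x_iiw   = tf.placeholder(tf.float32, [None, 4])   # dataset2 input
loss_dense = tf.reduce_mean(tf.square(shared_net(x_dense)))
loss_iiw   = tf.reduce_mean(tf.abs(shared_net(x_iiw)))

opt_dense = tf.train.AdamOptimizer(0.001).minimize(loss_dense)
opt_iiw   = tf.train.AdamOptimizer(0.001).minimize(loss_iiw)
init = tf.global_variables_initializer()  # created after both branches

with tf.Session() as sess:
  sess.run(init)            # initialize ONCE, not once per epoch
  for one_ep in range(10):  # epochs
    # Alternate one update per dataset; both updates move the same
    # shared weights, so the network is trained jointly.
    sess.run(opt_dense, feed_dict={x_dense: [[1., 2., 3., 4.]]})
    sess.run(opt_iiw,   feed_dict={x_iiw:   [[4., 3., 2., 1.]]})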

Tags: tensorflow, machine-learning, keras, deep-learning
