首页 > 解决方案 > 带有自定义训练步骤的 TensorFlow 分布式训练

问题描述

我面临着缓慢的训练运行,我试图通过使用 Tensorflow 的StrategyAPI 来利用所有 4 个 GPU 来扩展训练过程。

我正在使用MirrorStrategy和使用experimental_distribute_dataset来对数据集进行分区。

我的训练数据的性质是稀疏矩阵和密集矩阵的混合。我正在使用生成器来构建我的数据集从数据中选择随机索引)。但是,在我当前的 TF (2.1) 版本中,生成器不支持稀疏矩阵sparse_matrix没有静态大小并且是张Ragged量。

这一点很丑陋并且是一种解决方法,但我sparse_matrix_list直接将我传递给train函数,并通过将随机索引推入内部来填充全局队列来索引它generator

现在这种方法效果很好,但是太慢了,我想尝试使用所有 GPU 进行训练。这变得更加成问题,因为我必须手动将它们sparse_matrix_list分成多个部分num_workers

然而,目前的主要问题是训练过程似乎不是并行的,并且副本 (GPU) 似乎是按顺序运行的。. 我通过验证nvidia-smi并登录该train_process功能。

我以前没有分布式培训的经验,也不知道为什么会这样,如果有人能提供更好的方法来处理这种混合数据spare,我将不胜感激。dense我目前在获取未充分利用 GPU 的数据方面面临巨大瓶颈(在 10-30% 之间波动)

def distributed_train_step(inputs, sparse_matrix_list):
    per_replica_losses = strategy.experimental_run_v2(train_process, args=(inputs, sparse_matrix_list)
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                                axis=None)


def train_process(inputs, sparse_matrix_list):
    worker_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    replica_batch_size = inputs.shape[0]
    slice_start = replica_batch_size * worker_id
    replica_sparse_matrix = sparse_matrix_list[slice_start:slice_start + replica_batch_size]

    return train_step(inputs, replica_sparse_matrix)


def train_step(inputs, sparse_matrix_list):
     with tf.GradientTape() as tape:
           outputs, mu, sigma, feat_out, logit = model(inputs)
           loss = K.backend.mean(custom_loss(inputs, sparse_matrix_list)

     return loss

def get_batch_data(sparse_matrix_list):
    # Queue with the random indices into the training data (List of Lists with each 
    # entry len == batch_size)
    # train_indicie is a global q
    next_batch_indicies = train_indicies.get()

    batch_sparse_list = sparse_matrix_list[next_batch_indicies]
            

dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

for batch, inputs in enumerate(dist_dataset, 1):
    # sparse_matrix_list is passed to this main "train" function from outside this module.
    batch_sparse_matrix_slice = get_batch_data(sparse_matrix_list)

    loss = distributed_train_step(inputs, batch_sparse_matrix_slice)


标签: pythontensorflowtensorflow2.0sparse-matrixtensorflow-datasets

解决方案


推荐阅读