How to apply the chain rule across multiple tf.GradientTape?

Problem Description

I am working on pipeline model parallelism with TensorFlow 2 and MPI, but I do not know how to apply the chain rule when multiple tf.GradientTape instances are spread across multiple processes.
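
In case it clarifies what I am after, here is my understanding of how the chain rule is stitched together across two tapes via the output_gradients argument, as a minimal single-process sketch (the layers and shapes below are made up purely for illustration):

import tensorflow as tf

# Two tapes standing in for two pipeline stages.
w1 = tf.Variable(tf.random.normal([4, 3]))
w2 = tf.Variable(tf.random.normal([3, 2]))
x = tf.random.normal([5, 4])
y_true = tf.random.normal([5, 2])

with tf.GradientTape() as tape1:
    h = tf.matmul(x, w1)              # stage 1 output (activation)

with tf.GradientTape() as tape2:
    tape2.watch(h)                    # h arrives "from outside" on stage 2
    y_pred = tf.matmul(h, w2)
    loss = tf.reduce_mean(tf.square(y_true - y_pred))

# Stage 2: gradients w.r.t. its own weights and w.r.t. its input h.
dw2, dh = tape2.gradient(loss, [w2, h])

# Stage 1: chain rule via output_gradients, with dh playing the role of dL/dh.
dw1 = tape1.gradient(h, w1, output_gradients=dh)

What I cannot figure out is how to do the equivalent when each tape lives in a different MPI process.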

Here is the code I am currently working with:

import tensorflow as tf

from mpi4py import MPI

minibatch_size = 64


class Input(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(128, activation='relu')

    def call(self, inputs, **kwargs):
        x = self.flatten(inputs)
        x = self.dense(x)
        return x


class Block(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.dense_1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(128, activation='relu')

    def call(self, inputs, **kwargs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        return x


class Head(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.dropout = tf.keras.layers.Dropout(0.2)
        self.dense = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, **kwargs):
        x = self.dropout(inputs)
        x = self.dense(x)
        return x


class Trainer:

    def __init__(self,
                 comm,
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 loss_fn: tf.keras.losses.Loss):
        self._comm = comm
        self._size = comm.Get_size()
        self._rank = comm.Get_rank()
        self._next_rank = self._rank + 1 if self._rank + 1 < self._size else MPI.PROC_NULL
        self._prev_rank = self._rank - 1 if self._rank - 1 >= 0 else MPI.PROC_NULL
        self._model = model
        self._optimizer = optimizer
        self._loss_fn = loss_fn

    def _is_first_node(self) -> bool:
        return self._rank == 0

    def _is_last_node(self) -> bool:
        return self._rank == self._size - 1

    def _forward_pass(self, minibatch):
        assert minibatch_size % self._size == 0
        microbatch_size = minibatch_size // self._size
        microbatches = tf.data.Dataset \
            .from_tensor_slices(minibatch) \
            .batch(microbatch_size)
        predictions = []
        tapes = []
        losses = []
        for microbatch in microbatches:
            x, y = microbatch
            with tf.GradientTape() as tape:
                if self._is_first_node():
                    prediction = self._model(x)
                    self._comm.send(prediction, dest=self._next_rank)
                elif self._is_last_node():
                    recvd = self._comm.recv(source=self._prev_rank)
                    prediction = self._model(recvd)
                    loss = self._loss_fn(y, prediction)
                    losses.append(loss)
                else:
                    recvd = self._comm.recv(source=self._prev_rank)
                    prediction = self._model(recvd)
                    self._comm.send(prediction, dest=self._next_rank)
            predictions.append(prediction)
            tapes.append(tape)
        return predictions, tapes, losses

    def _backward_pass(self, predictions, tapes, losses):
        grads = []
        for i in range(self._size):
            if self._is_first_node():
                errors = self._comm.recv(source=self._next_rank)
                grad = tapes[i].gradient(predictions[i],
                                         self._model.trainable_weights,
                                         output_gradients=errors)
            elif self._is_last_node():
                grad = tapes[i].gradient(losses[i], self._model.trainable_weights)
                self._comm.send(grad, dest=self._prev_rank)
            else:
                errors = self._comm.recv(source=self._next_rank)
                grad = tapes[i].gradient(predictions[i],
                                         self._model.trainable_weights,
                                         output_gradients=errors)
                self._comm.send(grad, dest=self._prev_rank)
            grads.append(grad)
        grads = [tf.reduce_mean(grad, axis=0) for grad in grads]
        self._optimizer.apply_gradients(zip(grads, self._model.trainable_weights))

    def train_minibatch(self, minibatch):
        predictions, tapes, losses = self._forward_pass(minibatch)
        self._backward_pass(predictions, tapes, losses)


def main():
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    n_train = len(x_train)
    n_minibatch = n_train // minibatch_size

    x_train = tf.data.Dataset \
        .from_tensor_slices((x_train, y_train)) \
        .batch(minibatch_size, drop_remainder=True) \
        .shuffle(len(x_train))

    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    if rank == 0:
        model = Input()
    elif rank == size - 1:
        model = Head()
    else:
        model = Block()
    trainer = Trainer(comm, model, optimizer, loss_fn)

    if rank == 0:
        progbar = tf.keras.utils.Progbar(n_minibatch)
    for minibatch in x_train:
        trainer.train_minibatch(minibatch)
        if rank == 0:
            progbar.add(1)


if __name__ == '__main__':
    main()

However, running this code with

mpirun -n 4 python main.py

produces the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Inputs to operation ReluGrad of type ReluGrad must have the same size and shape.  Input 0: [128,10] != input 1: [16,128] [Op:ReluGrad]

Could any expert tell me how to do this correctly?

Tags: python, tensorflow, keras, hpc

Solution
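
The shape mismatch in the ReluGrad error ([128,10] vs. [16,128]) comes from the backward pass: the last node sends the gradients of the loss with respect to its own weights back to the previous node, and that node then feeds them into output_gradients, which expects tensors shaped like that node's own output (the activation it sent forward), not like the next stage's weights.

Below is a sketch of one way to make the chain rule line up across stages. It reuses the names from the question and is not a verified drop-in fix (in particular, it assumes mpi4py can pickle the tensors exactly as the question's forward pass already does). The key changes are listed first, followed by the code:

1. Every stage that receives an activation watches the received tensor inside its tape, so it can later differentiate through it.
2. What travels backwards is the gradient of the loss with respect to the activation a stage received, not the stage's weight gradients.
3. output_gradients on the previous stage then has exactly the shape of that stage's output, which is what closes the chain rule.
4. Weight gradients stay local to each rank and are combined weight by weight across microbatches (the original tf.reduce_mean(grad, axis=0) over a list of differently shaped weight gradients would not do that).

    def _forward_pass(self, minibatch):
        assert minibatch_size % self._size == 0
        microbatch_size = minibatch_size // self._size
        microbatches = tf.data.Dataset \
            .from_tensor_slices(minibatch) \
            .batch(microbatch_size)
        predictions, tapes, losses, inputs = [], [], [], []
        for microbatch in microbatches:
            x, y = microbatch
            recvd = None
            with tf.GradientTape() as tape:
                if self._is_first_node():
                    prediction = self._model(x)
                    self._comm.send(prediction, dest=self._next_rank)
                else:
                    recvd = tf.convert_to_tensor(self._comm.recv(source=self._prev_rank))
                    tape.watch(recvd)  # so we can differentiate through the received activation
                    prediction = self._model(recvd)
                    if self._is_last_node():
                        losses.append(self._loss_fn(y, prediction))
                    else:
                        self._comm.send(prediction, dest=self._next_rank)
            predictions.append(prediction)
            tapes.append(tape)
            inputs.append(recvd)
        return predictions, tapes, losses, inputs

    def _backward_pass(self, predictions, tapes, losses, inputs):
        grads_per_microbatch = []
        for i in range(self._size):
            if self._is_last_node():
                # Local weight gradients plus dL/d(received activation) to send back.
                grad, errors = tapes[i].gradient(
                    losses[i], [self._model.trainable_weights, inputs[i]])
                self._comm.send(errors, dest=self._prev_rank)
            elif self._is_first_node():
                errors = tf.convert_to_tensor(self._comm.recv(source=self._next_rank))
                grad = tapes[i].gradient(predictions[i],
                                         self._model.trainable_weights,
                                         output_gradients=errors)
            else:
                errors = tf.convert_to_tensor(self._comm.recv(source=self._next_rank))
                # Weight gradients for this stage and the error signal for the previous
                # stage, both from the same vector-Jacobian product.
                grad, new_errors = tapes[i].gradient(
                    predictions[i],
                    [self._model.trainable_weights, inputs[i]],
                    output_gradients=errors)
                self._comm.send(new_errors, dest=self._prev_rank)
            grads_per_microbatch.append(grad)
        # Average the microbatch gradients weight by weight before applying them.
        grads = [tf.add_n(list(gs)) / len(gs) for gs in zip(*grads_per_microbatch)]
        self._optimizer.apply_gradients(zip(grads, self._model.trainable_weights))

    def train_minibatch(self, minibatch):
        predictions, tapes, losses, inputs = self._forward_pass(minibatch)
        self._backward_pass(predictions, tapes, losses, inputs)

With these changes each rank only applies gradients for its own sub-model, and the error signal that crosses a process boundary always has the shape of the activation that crossed it in the forward direction, which is what the ReluGrad shape check expects.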

