python - 如何将链式规则应用于多个 tf.GradientTape?
问题描述
我正在研究使用 TensorFlow 2 和 MPI 的管道模型并行性。但是我不知道在跨多个进程使用多个 tf.GradientTape 时如何应用链式规则。
这是我目前正在处理的代码:
import tensorflow as tf
from mpi4py import MPI
minibatch_size = 64
class Input(tf.keras.Model):
def __init__(self):
super().__init__()
self.flatten = tf.keras.layers.Flatten()
self.dense = tf.keras.layers.Dense(128, activation='relu')
def call(self, inputs, **kwargs):
x = self.flatten(inputs)
x = self.dense(x)
return x
class Block(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense_1 = tf.keras.layers.Dense(128, activation='relu')
self.dense_2 = tf.keras.layers.Dense(128, activation='relu')
def call(self, inputs, **kwargs):
x = self.dense_1(inputs)
x = self.dense_2(x)
return x
class Head(tf.keras.Model):
def __init__(self):
super().__init__()
self.dropout = tf.keras.layers.Dropout(0.2)
self.dense = tf.keras.layers.Dense(10, activation='softmax')
def call(self, inputs, **kwargs):
x = self.dropout(inputs)
x = self.dense(x)
return x
class Trainer:
def __init__(self,
comm,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
loss_fn: tf.keras.losses.Loss):
self._comm = comm
self._size = comm.Get_size()
self._rank = comm.Get_rank()
self._next_rank = self._rank + 1 if self._rank + 1 < self._size else MPI.PROC_NULL
self._prev_rank = self._rank - 1 if self._rank - 1 >= 0 else MPI.PROC_NULL
self._model = model
self._optimizer = optimizer
self._loss_fn = loss_fn
def _is_first_node(self) -> bool:
return self._rank == 0
def _is_last_node(self) -> bool:
return self._rank == self._size - 1
def _forward_pass(self, minibatch):
assert minibatch_size % self._size == 0
microbatch_size = minibatch_size // self._size
microbatches = tf.data.Dataset \
.from_tensor_slices(minibatch) \
.batch(microbatch_size)
predictions = []
tapes = []
losses = []
for microbatch in microbatches:
x, y = microbatch
with tf.GradientTape() as tape:
if self._is_first_node():
prediction = self._model(x)
self._comm.send(prediction, dest=self._next_rank)
elif self._is_last_node():
recvd = self._comm.recv(source=self._prev_rank)
prediction = self._model(recvd)
loss = self._loss_fn(y, prediction)
losses.append(loss)
else:
recvd = self._comm.recv(source=self._prev_rank)
prediction = self._model(recvd)
self._comm.send(prediction, dest=self._next_rank)
predictions.append(prediction)
tapes.append(tape)
return predictions, tapes, losses
def _backward_pass(self, predictions, tapes, losses):
grads = []
for i in range(self._size):
if self._is_first_node():
errors = self._comm.recv(source=self._next_rank)
grad = tapes[i].gradient(predictions[i],
self._model.trainable_weights,
output_gradients=errors)
elif self._is_last_node():
grad = tapes[i].gradient(losses[i], self._model.trainable_weights)
self._comm.send(grad, dest=self._prev_rank)
else:
errors = self._comm.recv(source=self._next_rank)
grad = tapes[i].gradient(predictions[i],
self._model.trainable_weights,
output_gradients=errors)
self._comm.send(grad, dest=self._prev_rank)
grads.append(grad)
grads = [tf.reduce_mean(grad, axis=0) for grad in grads]
self._optimizer.apply_gradients(zip(grads, self._model.trainable_weights))
def train_minibatch(self, minibatch):
predictions, tapes, losses = self._forward_pass(minibatch)
self._backward_pass(predictions, tapes, losses)
def main():
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
n_train = len(x_train)
n_minibatch = n_train // minibatch_size
x_train = tf.data.Dataset \
.from_tensor_slices((x_train, y_train)) \
.batch(minibatch_size, drop_remainder=True) \
.shuffle(len(x_train))
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
if rank == 0:
model = Input()
elif rank == size - 1:
model = Head()
else:
model = Block()
trainer = Trainer(comm, model, optimizer, loss_fn)
if rank == 0:
progbar = tf.keras.utils.Progbar(n_minibatch)
for minibatch in x_train:
trainer.train_minibatch(minibatch)
if rank == 0:
progbar.add(1)
if __name__ == '__main__':
main()
但是,运行此代码
mpirun -n 4 python main.py
产生以下错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Inputs to operation ReluGrad of type ReluGrad must have the same size and shape. Input 0: [128,10] != input 1: [16,128] [Op:ReluGrad]
任何专家可以告诉我如何正确地做到这一点吗?
解决方案
推荐阅读
- python - 例外:无法加载 model.bin
- api - 未处理的异常:类型 '_InternalLinkedHashMap
' 不是类型 'List 的子类型 ' 在类型转换中 - javascript - 为什么 getElementsByClassName 不选择我的所有元素?
- python - 将带有特定单词的字符串值传输到数据框 pandas、python 中的其他列
- amazon-web-services - 将线性学习器的输入类型更改为 csv
- python - 如何从这个网页抓取的 HTML 中提取某些元素
- php - 如何为自动图像优化提取图像路径及其推荐的新尺寸?
- amazon-dynamodb - 使用 DynamoDB 二级索引 AWS SDK 2 Java 异常创建 DynamoDbIndex 对象进行查询
- jquery - 数组的 HTML 元素的选择器
- c++ - Character::BeginPlay() 未被调用