首页 > 解决方案 > TensorFlow 2.0 中分布式数据集的 tf.function input_signature

问题描述

我正在尝试在 TensorFlow 2.0 中构建分布式自定义训练循环,但我无法弄清楚如何为 tf.function 指定 input_signature 以避免重新追踪(retracing)。

我曾尝试使用 DatasetSpec 和 TensorSpec 元组的各种组合,但我得到了各种各样的错误。

我的问题

是否可以指定一个接受批处理分布式数据集的 tf.function 输入签名?

最少的复制代码

import tensorflow as tf
from tensorflow import keras
import numpy as np


class SimpleModel(keras.layers.Layer):
    """A minimal layer holding one trainable (1, 1) weight applied by matmul."""

    def __init__(self, name='simple_model', **kwargs):
        super(SimpleModel, self).__init__(name=name, **kwargs)
        # One scalar parameter stored as a (1, 1) matrix so it can be used
        # directly in a matrix product; initialized to 5.0.
        self.w = self.add_weight(
            shape=(1, 1),
            initializer=tf.constant_initializer(5.0),
            trainable=True,
            dtype=np.float32,
            name='w',
        )

    def call(self, x):
        # Linear map: a (batch, 1) input times the (1, 1) weight -> (batch, 1).
        return tf.matmul(x, self.w)


class Trainer:
    """Distributed custom training loop for SimpleModel under MirroredStrategy."""

    def __init__(self):
        self.mirrored_strategy = tf.distribute.MirroredStrategy()

        # Variables (model weights, optimizer slots) must be created inside the
        # strategy scope so they are mirrored across all replicas.
        with self.mirrored_strategy.scope():
            self.simple_model = SimpleModel()
            self.optimizer = tf.optimizers.Adam(learning_rate=0.01)

    def train_batches(self, dataset):
        """Run one pass over `dataset` and return the summed per-batch mean loss.

        `dataset` is a batched tf.data.Dataset of float32 tensors shaped
        (batch, 1). Returns a Python float (sum of per-batch mean losses).
        """
        dataset_dist = self.mirrored_strategy.experimental_distribute_dataset(dataset)

        with self.mirrored_strategy.scope():
            loss = self.train_batches_dist(dataset_dist)

        return loss.numpy()

    # BUG FIX: the original decorated this method with
    #   @tf.function(input_signature=(tf.data.DatasetSpec(element_spec=...),))
    # but `experimental_distribute_dataset` returns a DistributedDataset, NOT a
    # tf.data.Dataset, so tf.data.DatasetSpec does not describe the argument and
    # tracing fails with:
    #   TypeError: If shallow structure is a sequence, input must also be a
    #   sequence. Input has type: ... DistributedDataset
    # TF 2.0 exposes no public TypeSpec for DistributedDataset, so no
    # input_signature can be written for it. Decorate without one and let
    # tf.function derive the signature from the concrete argument.
    # NOTE(review): because train_batches distributes a fresh DistributedDataset
    # each call, tf.function may still retrace per epoch on TF 2.0; distributing
    # the dataset once and reusing it avoids that — confirm on the target TF
    # version.
    @tf.function
    def train_batches_dist(self, dataset_dist):
        total_loss = 0.0
        for batch in dataset_dist:
            # Run the per-replica step (experimental_run_v2 is the TF 2.0 name;
            # later releases rename it to Strategy.run).
            losses = self.mirrored_strategy.experimental_run_v2(
                Trainer.train_batch, args=(self, batch)
            )
            # Average the per-example losses across all replicas.
            mean_loss = self.mirrored_strategy.reduce(tf.distribute.ReduceOp.MEAN, losses, axis=0)

            total_loss += mean_loss
        return total_loss

    def train_batch(self, batch):
        """Per-replica step: squared error against the target function y = 2x."""
        with tf.GradientTape() as tape:
            losses = tf.square(2 * batch - self.simple_model(batch))

        gradients = tape.gradient(losses, self.simple_model.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.simple_model.trainable_weights))

        return losses


def main():
    """Build a toy dataset and train for 100 epochs, printing scaled loss."""
    # 100 uniform random samples in [0, 1), shaped (100, 1), float32.
    samples = np.random.sample((100, 1)).astype(np.float32)

    batched = tf.data.Dataset.from_tensor_slices(samples).batch(10)

    trainer = Trainer()
    for _ in range(100):
        epoch_loss = trainer.train_batches(batched)
        # Divide by the number of batches to report an average per-batch loss.
        print(epoch_loss / 10.0)


if __name__ == '__main__':
    main()

错误信息

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.distribute.input_lib.DistributedDataset'>

标签: python, tensorflow, tensorflow-2.0

解决方案


推荐阅读