python - TensorFlow 2.0 中分布式数据集的 tf.function input_signature
问题描述
我正在尝试在 TensorFlow 2.0 中构建分布式自定义训练循环,但我无法弄清楚如何注释签名 tf.function 签名以避免回溯。
我曾尝试使用 DatasetSpec 和 TensorSpec 元组的各种组合,但我得到了各种各样的错误。
我的问题
是否可以指定一个接受批处理分布式数据集的 tf.function 输入签名?
最少的复制代码
import tensorflow as tf
from tensorflow import keras
import numpy as np
class SimpleModel(keras.layers.Layer):
def __init__(self, name='simple_model', **kwargs):
super(SimpleModel, self).__init__(name=name, **kwargs)
self.w = self.add_weight(shape=(1, 1),
initializer=tf.constant_initializer(5.0),
trainable=True,
dtype=np.float32,
name='w')
def call(self, x):
return tf.matmul(x, self.w)
class Trainer:
def __init__(self):
self.mirrored_strategy = tf.distribute.MirroredStrategy()
with self.mirrored_strategy.scope():
self.simple_model = SimpleModel()
self.optimizer = tf.optimizers.Adam(learning_rate=0.01)
def train_batches(self, dataset):
dataset_dist = self.mirrored_strategy.experimental_distribute_dataset(dataset)
with self.mirrored_strategy.scope():
loss = self.train_batches_dist(dataset_dist)
return loss.numpy()
@tf.function(input_signature=(tf.data.DatasetSpec(element_spec=tf.TensorSpec(shape=(None, 1), dtype=tf.float32)),))
def train_batches_dist(self, dataset_dist):
total_loss = 0.0
for batch in dataset_dist:
losses = self.mirrored_strategy.experimental_run_v2(
Trainer.train_batch, args=(self, batch)
)
mean_loss = self.mirrored_strategy.reduce(tf.distribute.ReduceOp.MEAN, losses, axis=0)
total_loss += mean_loss
return total_loss
def train_batch(self, batch):
with tf.GradientTape() as tape:
losses = tf.square(2 * batch - self.simple_model(batch))
gradients = tape.gradient(losses, self.simple_model.trainable_weights)
self.optimizer.apply_gradients(zip(gradients, self.simple_model.trainable_weights))
return losses
def main():
values = np.random.sample((100, 1)).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(values)
dataset = dataset.batch(10)
trainer = Trainer()
for epoch in range(0, 100):
loss = trainer.train_batches(dataset)
print(loss / 10.0)
if __name__ == '__main__':
main()
错误信息
TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.distribute.input_lib.DistributedDataset'>