Proper way to save a federated client dataset

Problem description

I want to use TFF to train two separate models on the EMNIST dataset. Each model should be trained on 1000 different participants sampled at random from the dataset.

The following code

import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# Sample 1000 distinct client IDs without replacement.
participants_ids = np.random.choice(a=emnist_train.client_ids,
                                    size=1000,
                                    replace=False)

federated_dataset = [
    emnist_train.create_tf_dataset_for_client(i) for i in participants_ids]

nested_dataset = tf.data.Dataset.from_tensor_slices(federated_dataset)

Attempting to save the dataset with

tf.data.experimental.save(nested_dataset, 'model_dataset')

produces the following warning; the save nevertheless completes.

E tensorflow/core/framework/dataset.cc:89] The Encode() method is not implemented for DatasetVariantWrapper objects.

The problem appears when loading the dataset and trying to inspect its contents:

dataset = tf.data.experimental.load(
    'model_dataset',
    element_spec=tf.data.DatasetSpec(
        collections.OrderedDict([
            ('label', tf.TensorSpec(shape=(), dtype=tf.int32)),
            ('pixels', tf.TensorSpec(shape=(28, 28), dtype=tf.float32))]),
        tf.TensorShape([])))

# verifying elements
for example in dataset:
        print(example)

raises the following error:

tensorflow.python.framework.errors_impl.DataLossError: Unable to parse tensor from stored proto.

Other approaches, such as pickle.dump and np.save, all resulted in the following error:

tensorflow.python.framework.errors_impl.InternalError: Tensorflow type 21 not convertible to numpy dtype.

Is there a good way to save the newly created dataset?

Tags: tensorflow, tensorflow-datasets, tensorflow-federated

Solution


Instead of saving a dataset of datasets, why not save the sampled client IDs and build the client datasets back up at load time?

# Create a dataset of the participating IDs.
id_ds = tf.data.Dataset.from_tensor_slices(participants_ids)
tf.data.experimental.save(id_ds, '/tmp/id_dataset')

# Load the dataset back later.
loaded_ds = tf.data.experimental.load(
  '/tmp/id_dataset',
  element_spec=tf.TensorSpec(shape=[], dtype=tf.string))

# Create a federated dataset that yields (client_id, dataset) pairs.
federated_dataset = loaded_ds.map(
    lambda id: (id, emnist_train.serializable_dataset_fn(id)))
print(f'Loaded dataset with {tf.data.Dataset.cardinality(federated_dataset)} clients')
>>> Loaded dataset with 1000 clients.

print(f'Dataset element types: {next(iter(federated_dataset))[1].element_spec}')
>>> Dataset element types: OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

for id, dataset in federated_dataset.take(5):
  print(f'Client [{id}] has {sum(1 for _ in iter(dataset))} examples')
>>> Client [b'f2174_61'] has 106 examples
>>> Client [b'f1378_08'] has 99 examples
>>> Client [b'f1550_34'] has 106 examples
>>> Client [b'f3817_22'] has 106 examples
>>> Client [b'f1000_45'] has 109 examples
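As a sanity check, the loaded ID dataset should yield exactly the IDs that were saved. A minimal, self-contained sketch of that round trip using synthetic placeholder IDs (the names below are made up, not real EMNIST clients):

```python
import tempfile

import numpy as np
import tensorflow as tf

# Synthetic stand-ins for sampled EMNIST client IDs.
participants_ids = np.array(['f0000_01', 'f0000_02', 'f0000_03'])

id_ds = tf.data.Dataset.from_tensor_slices(participants_ids)

save_dir = tempfile.mkdtemp()
tf.data.experimental.save(id_ds, save_dir)

loaded_ds = tf.data.experimental.load(
    save_dir, element_spec=tf.TensorSpec(shape=[], dtype=tf.string))

# The loaded IDs match the originals (sorted, since the on-disk
# sharding does not guarantee iteration order).
loaded = sorted(id.numpy().decode() for id in loaded_ds)
print(loaded)
```

The same check applies unchanged to the real 1000-client ID dataset saved above.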

A flat dataset can then be created by replacing tf.data.Dataset.map with tf.data.Dataset.flat_map:

flat_dataset = loaded_ds.flat_map(
    lambda id: emnist_train.serializable_dataset_fn(id))

print(f'Dataset element types: {flat_dataset.element_spec}')
>>> Dataset element types: OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

print(f'Flat dataset has {sum(1 for _ in flat_dataset):,} examples.')
>>> Flat dataset has 101,619 examples.
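Once flattened, the examples can be shuffled and batched for ordinary centralized training. A minimal sketch, using a small synthetic stand-in with the same OrderedDict element structure as the EMNIST datasets above (the sizes are illustrative):

```python
import collections

import tensorflow as tf

# Stand-in for the flattened dataset: 100 dummy examples with the same
# ('label', 'pixels') structure as the EMNIST client datasets.
flat_dataset = tf.data.Dataset.from_tensor_slices(
    collections.OrderedDict([
        ('label', tf.zeros([100], dtype=tf.int32)),
        ('pixels', tf.zeros([100, 28, 28], dtype=tf.float32)),
    ]))

train_ds = flat_dataset.shuffle(buffer_size=100).batch(32)

for batch in train_ds.take(1):
    print(batch['pixels'].shape)  # (32, 28, 28)
```

Shuffling across clients like this discards the per-client grouping, which is exactly what you want for a centralized baseline but not for federated training, where the per-client datasets from the map version should be kept separate.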
