首页 > 解决方案 > TensorFlow Federated:如何调整联合数据集中的非独立同分布?

问题描述

我正在 TensorFlow Federated (TFF) 中测试一些算法。在这方面,我想在具有不同“级别”数据异质性(即非独立同分布)的同一联邦数据集上测试和比较它们。

因此,我想知道是否有任何方法可以自动或半自动方式控制和调整特定联合数据集中的非独立同分布的“级别”,例如通过 TFF API 或传统的 TF API (可能在 Dataset utils 中)。

更实用的是:例如,TFF 提供的 EMNIST 联合数据集有 3383 个客户端,每个客户端都有自己的手写字符。然而,这些本地数据集在本地示例的数量和表示的类方面似乎相当平衡(所有类或多或少都在本地表示)。如果我想要一个联合数据集(例如,从 TFF 的 EMNIST 开始),那就是:

我应该如何在 TFF 框架内继续准备具有这些特征的联合数据集?

我应该手工做所有的事情吗?或者你们中的一些人有一些建议来自动化这个过程吗?

另一个问题:在 Hsu 等人的这篇论文“Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification”中,他们利用 Dirichlet 分布来合成一组不同的客户,并使用浓度参数控制客户端之间的相同性。这似乎是一种易于调整的方式来生成具有不同异质性水平的数据集。任何关于如何在 TFF 框架内或仅在 TensorFlow (Python) 中实施此策略(或类似策略)的建议(考虑到 EMNIST 等简单数据集)也将非常有用。

十分感谢。

标签: pythontensorflowtensorflow2.0tensorflow-federated

解决方案


对于联邦学习模拟,在实验驱动程序中使用 Python 设置客户端数据集以实现所需的分布是非常合理的。在某些高层次上,TFF 处理建模数据位置(类型系统中的“放置”)和计算逻辑。尽管您发现了一些有用的库,但重新混合/生成模拟数据集并不是该库的核心。tf.data.Dataset通过操作然后将客户端数据集“推送”到 TFF 计算中直接在 python 中执行此操作似乎很简单。

标签非 IID

是的,tff.simulation.datasets.build_single_label_dataset旨在用于此目的。

它需要一个并且基本上过滤掉所有与值tf.data.Dataset不匹配的示例(假设数据集产生类似结构)。desired_labellabel_keydict

对于 EMNIST,要创建所有数据集(无论用户如何),这可以通过以下方式实现:

train_data, _ = tff.simulation.datasets.emnist.load_data()
ones = tff.simulation.datasets.build_single_label_dataset(
  train_data.create_tf_dataset_from_all_clients(),
  label_key='label', desired_label=1)
print(ones.element_spec)
>>> OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
print(next(iter(ones))['label'])
>>> tf.Tensor(1, shape=(), dtype=int32)

数据不平衡

使用tf.data.Dataset.repeat和的组合tf.data.Dataset.take可用于创建数据不平衡。

train_data, _ = tff.simulation.datasets.emnist.load_data()
datasets = [train_data.create_tf_dataset_for_client(id) for id in train_data.client_ids[:2]]
print([tf.data.experimental.cardinality(ds).numpy() for ds in datasets])
>>> [93, 109]
datasets[0] = datasets[0].repeat(5)
datasets[1] = datasets[1].take(5)
print([tf.data.experimental.cardinality(ds).numpy() for ds in datasets])
>>> [465, 5]

推荐阅读