tensorflow - Proper way to save client federated datasets
Problem description
I want to train two separate models with TFF using the emnist dataset. Each model should be trained on 1000 different participants sampled at random from the dataset.
The following code samples the participants and builds a dataset of datasets:
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
participants_ids = np.random.choice(a=emnist_train.client_ids,
                                    size=1000,
                                    replace=False)
federated_dataset = [emnist_train.create_tf_dataset_for_client(i)
                     for i in participants_ids]
nested_dataset = tf.data.Dataset.from_tensor_slices(federated_dataset)
Attempting to save the dataset with
tf.data.experimental.save(nested_dataset, 'model_dataset')
produces the following warning; the save nevertheless completes:
E tensorflow/core/framework/dataset.cc:89] The Encode() method is not implemented for DatasetVariantWrapper objects.
The problem appears when loading the dataset back and trying to inspect its contents:
import collections

dataset = tf.data.experimental.load(
    'model_dataset',
    element_spec=tf.data.DatasetSpec(
        collections.OrderedDict([
            ('label', tf.TensorSpec(shape=(), dtype=tf.int32)),
            ('pixels', tf.TensorSpec(shape=(28, 28), dtype=tf.float32))]),
        tf.TensorShape([])))

# verifying elements
for example in dataset:
    print(example)
This raises the following error:
tensorflow.python.framework.errors_impl.DataLossError: Unable to parse tensor from stored proto.
Trying other approaches, such as pickle.dump and np.save, results in the following error:
tensorflow.python.framework.errors_impl.InternalError: Tensorflow type 21 not convertible to numpy dtype.
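The underlying issue is that a dataset of datasets has variant-typed elements, which the save op cannot encode. This can be reproduced with plain tf.data and no TFF at all; the snippet below is an illustrative sketch (not from the original question) using window(), which also yields a dataset whose elements are themselves datasets:

```python
import tensorflow as tf

# window() produces a dataset whose elements are themselves datasets,
# analogous to from_tensor_slices over a list of client datasets.
nested = tf.data.Dataset.range(6).window(2)

# The element spec is a DatasetSpec (backed by a variant tensor),
# which is exactly what the save op cannot serialize.
print(nested.element_spec)
print(isinstance(nested.element_spec, tf.data.DatasetSpec))  # -> True
```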
Is there a good way to save the newly created dataset?
Solution
Instead of saving a dataset of datasets, how about saving the sampled client IDs and rebuilding the dataset at load time?
# Create a dataset of the participating IDs.
id_ds = tf.data.Dataset.from_tensor_slices(participants_ids)
tf.data.experimental.save(id_ds, '/tmp/id_dataset')

# Load the dataset back later.
loaded_ds = tf.data.experimental.load(
    '/tmp/id_dataset',
    element_spec=tf.TensorSpec(shape=[], dtype=tf.string))

# Create a federated dataset that yields (client_id, dataset) pairs.
federated_dataset = loaded_ds.map(
    lambda id: (id, emnist_train.serializable_dataset_fn(id)))
print(f'Loaded dataset with {tf.data.Dataset.cardinality(federated_dataset)} clients')
>>> Loaded dataset with 1000 clients.
print(f'Dataset element types: {next(iter(federated_dataset))[1].element_spec}')
>>> Dataset element types: OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
for id, dataset in federated_dataset.take(5):
  print(f'Client [{id}] has {sum(1 for _ in iter(dataset))} examples')
>>> Client [b'f2174_61'] has 106 examples
>>> Client [b'f1378_08'] has 99 examples
>>> Client [b'f1550_34'] has 106 examples
>>> Client [b'f3817_22'] has 106 examples
>>> Client [b'f1000_45'] has 109 examples
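Note that np.save only failed above because it was applied to the nested dataset; the sampled IDs themselves are a flat array of strings, so they could alternatively be persisted with plain NumPy instead of a tf.data pipeline. A minimal sketch, using a stand-in ID array and an illustrative file path:

```python
import numpy as np

# Stand-in for the sampled ids; in the real code this would be the
# participants_ids array returned by np.random.choice.
participants_ids = np.array(['f2174_61', 'f1378_08', 'f1550_34'])

np.save('/tmp/participant_ids.npy', participants_ids)
loaded_ids = np.load('/tmp/participant_ids.npy')
print((loaded_ids == participants_ids).all())  # -> True
```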
A flat dataset can then be created by replacing tf.data.Dataset.map with tf.data.Dataset.flat_map:
flat_dataset = loaded_ds.flat_map(
    lambda id: emnist_train.serializable_dataset_fn(id))
print(f'Dataset element types: {flat_dataset.element_spec}')
>>> Dataset element types: OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
print(f'Flat dataset has {sum(1 for _ in flat_dataset):,} examples.')
>>> Flat dataset has 101,619 examples.
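The map/flat_map distinction can be seen with a toy stand-in for the per-client dataset function (client_dataset below is hypothetical, not a TFF API): map keeps one nested dataset per ID, while flat_map concatenates every client's examples into a single stream.

```python
import tensorflow as tf

# Hypothetical stand-in for serializable_dataset_fn: maps a client id
# to that client's examples (here, three numbers derived from the id).
def client_dataset(cid):
    return tf.data.Dataset.range(3).map(lambda x: cid * 10 + x)

ids = tf.data.Dataset.range(3)

# map yields a dataset of datasets (the problematic nested structure)...
nested = ids.map(client_dataset)
print(isinstance(nested.element_spec, tf.data.DatasetSpec))  # -> True

# ...while flat_map flattens all client examples into one dataset.
flat = ids.flat_map(client_dataset)
print(list(flat.as_numpy_iterator()))  # -> [0, 1, 2, 10, 11, 12, 20, 21, 22]
```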