首页 > 解决方案 > 是否可以将不同的权重子集发送给不同的客户?

问题描述

我正在尝试使用 tensorflow-federated 在服务器上选择不同的权重子集并将它们发送给客户端。然后,客户将训练并发回训练后的权重。服务器汇总结果并开始新一轮通信。

主要问题是我无法访问权重的 numpy 版本,因此我不知道如何访问每一层的权重子集。我尝试使用 tf.gather_nd 和 tf.tensor_scatter_nd_update 来执行选择和更新,但它们仅适用于张量,而不适用于张量列表(因为 server_state 处于 tensorflow-federated 中)。

有没有人有任何提示来解决这个问题?甚至可以向每个客户发送不同的权重吗?

标签: tensorflowtensorflow2.0tensorflow-federated

解决方案


如果我遵循正确,编写 TFF 类型速记中描述的高级计算的方法是:

@tff.federated_computation(...)
def run_one_round(server_state, client_datasets):
  weights_subset = tff.federated_map(subset_fn, server_state)
  clients_weights_subset = tff.federated_broadcast(weights_subset)
  client_models = tff.federated_map(client_training_fn, 
                                    (clients_weights_subset, client_datasets))
  aggregated_update = tff.federated_aggregate(client_models, ...)
  new_server_state = tff.federated_map(apply_aggregated_update_fn, server_state)
  return new_server_state

如果这是真的,似乎大部分工作都需要进行,subset_fn其中需要服务器状态并返回全局模式权重的子集。通常,模型是 的结构(listdict,可能是嵌套的)tf.Tensor,正如您所观察到的,它不能用作tf.gather_ndor的参数tf.tensor_scatter_nd_update。但是,它们可以逐点应用于张量 uses 的结构tf.nest.map_structure。例如,从三个张量的嵌套结构中选择 [0, 0] 处的值:

import tensorflow as tf
import pprint
struct_of_tensors = {
    'trainable': [tf.constant([[2.0, 4.0, 6.0]]), tf.constant([[5.0]])],
    'non_trainable': [tf.constant([[1.0]])],
}
pprint.pprint(tf.nest.map_structure(
    lambda tensor: tf.gather_nd(params=tensor, indices=[[0, 0]]),
    struct_of_tensors))

>>> {'non_trainable': [<tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>],
     'trainable': [<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.], dtype=float32)>,
                   <tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>]}

推荐阅读