Tensorflow Federated: why can't my iterative process train a round?

Problem description

I wrote TFF code for my own dataset, and all of it runs fine except for one line.

For train_data, I built 4 datasets loaded as tf.data.Dataset; their type is "DatasetV1Adapter":

def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).map(map_fn).shuffle(500).batch(20)

federated_train_data = [client_data(n) for n in range(4)]

batch = tf.nest.map_structure(lambda x: x.numpy(), iter(federated_train_data[0]).next())

def model_fn():
  model = tf.keras.models.Sequential([
    .........
  ])
  return tff.learning.from_compiled_keras_model(model, batch)

All of this runs correctly, and I get the trainer and the state:

trainer = tff.learning.build_federated_averaging_process(model_fn)
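
(The state mentioned above presumably comes from initializing this process, i.e. something like:)

state = trainer.initialize()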

But when I try to start the training with this code:

state, metrics = trainer.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))

I can't: an error is raised. So where might the error come from? From the type of the datasets, or from the way I federated my data?
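
(For reference: the client_data function above calls map_fn, which is not shown in the question. It is assumed to be a small preprocessing function along these lines, e.g. normalizing the images and one-hot encoding the labels for 2 classes; the body below is a guess, not the asker's code:)

import tensorflow as tf

def map_fn(image, label):
  # hypothetical preprocessing: scale pixels to [0, 1] and one-hot the label (2 classes)
  image = tf.cast(image, tf.float32) / 255.0
  label = tf.one_hot(tf.cast(label, tf.int32), depth=2)
  return image, label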

Tags: tensorflow2.0, tensorflow-federated

Solution


Here is my code; I am using TensorFlow v2.1.0 and TFF 0.12.0:

img_height = 200
img_width = 200
num_classes = 2
batch_size = 10

input_shape = (img_height, img_width, 3)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator()
gen0 = img_gen.flow_from_directory(par1_train_data_dir,
                                   target_size=(img_height, img_width),
                                   color_mode='rgb',
                                   batch_size=batch_size)

# from_generator expects a callable that returns the generator, hence the lambda;
# both simulated clients read from the same generator in this example
ds_par1 = tf.data.Dataset.from_generator(
    lambda: gen0,
    output_types=(tf.float32, tf.float32),
    output_shapes=([None, img_height, img_width, 3], [None, num_classes])
)
ds_par2 = tf.data.Dataset.from_generator(
    lambda: gen0,
    output_types=(tf.float32, tf.float32),
    output_shapes=([None, img_height, img_width, 3], [None, num_classes])
)
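
# Sanity check (not in the original answer): pull one batch to confirm the element
# structure matches the declared output_types/output_shapes.
check_images, check_labels = next(iter(ds_par1))
print(check_images.shape, check_labels.shape)   # e.g. (10, 200, 200, 3) (10, 2)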

dataset_dict={}
dataset_dict['1'] = ds_par1
dataset_dict['2'] = ds_par2

def create_tf_dataset_for_client_fn(client_id):
    return dataset_dict[client_id]

source = tff.simulation.ClientData.from_clients_and_fn(['1', '2'], create_tf_dataset_for_client_fn)

def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds


train_data = [client_data(n) for n in range(1)]

images, labels = next(img_gen.flow_from_directory(par1_train_data_dir,
                                                  target_size=(img_height, img_width),
                                                  batch_size=batch_size))
sample_batch = (images, labels)
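
# Note (added): sample_batch is a tuple of numpy arrays; from_compiled_keras_model
# uses its dtypes and shapes to build the model's input spec, so it must match what
# the client datasets yield.
print(images.shape, labels.shape)   # e.g. (10, 200, 200, 3) (10, 2)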

def create_compiled_keras_model():
  .....
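
# The body above is elided in the original answer. A hypothetical version that fits
# input_shape = (200, 200, 3), num_classes = 2 and the categorical labels produced by
# ImageDataGenerator (an assumption, not the author's actual model) could be:
def create_compiled_keras_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, 3, activation='relu', input_shape=input_shape),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    # from_compiled_keras_model expects a model that has already been compiled
    model.compile(loss=tf.keras.losses.CategoricalCrossentropy(),
                  optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
                  metrics=[tf.keras.metrics.CategoricalAccuracy()])
    return model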

def model_fn():
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn)
state = iterative_process.initialize()


state, metrics = iterative_process.next(state, train_data)
print('round 1, metrics={}'.format(metrics))
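
To train for more than one round, the same next call is repeated with the returned state fed back in. A minimal sketch (the round count is arbitrary; the .take cap per client is a suggested addition, since the generator-backed datasets above loop forever and a round only finishes once each client dataset is exhausted):

# cap each client's data per round so a round terminates
capped_train_data = [ds.take(10) for ds in train_data]

NUM_ROUNDS = 10
for round_num in range(1, NUM_ROUNDS + 1):
    state, metrics = iterative_process.next(state, capped_train_data)
    print('round {:2d}, metrics={}'.format(round_num, metrics))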
