python - 我可以在 Tensorflow 联邦学习 (TFF) 的 keras 模型中使用 class_weight
问题描述
我的数据集是类不平衡的,所以我想使用 class_weight 来启用分类器重权次要类。在一般情况下,我可以按如下方式分配班级权重:
weighted_history = weighted_model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(val_features, val_labels),
# The class weights go here
class_weight=class_weight)
有什么方法可以在 tensorflow 联邦学习中分配 class_weight 吗?我的联邦学习代码如下:
def create_keras_model(output_bias=None):
return tf.keras.models.Sequential([
tf.keras.layers.Dense(12, activation='relu', input_shape(5,)),
tf.keras.layers.Dense(8, activation='relu'),
tf.keras.layers.Dense(5, activation='relu'),
tf.keras.layers.Dense(3, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')])
def model_fn():
keras_model = create_keras_model()
return tff.learning.from_keras_model(
keras_model,
input_spec=preprocessed_example_dataset.element_spec,
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[tf.keras.metrics.BinaryAccuracy()])
解决方案
不是直接的。主要问题是该tf.keras.Model.fit
方法在概念上并未映射到从分散数据进行训练的想法。
如果你想让这个工作TFF
,第一步是确定应该执行什么算法。据我所知,这没有一个明显的答案——例如,class_weights
如果你不能直接访问数据,你如何确定那些是什么?
但是让我们假设您以某种方式获得了这些信息,并且只是想修改客户的本地培训程序。从开始examples/simple_fedavg
,实现它的方法是适当地修改在这个循环中计算梯度的方式。
推荐阅读
- php - 在 laravel 中显示数据库中的单个记录
- python - 字典计算关键词在文章中出现的次数
- python - 将 Pyhon 2 转换为 Python 3 时的编码问题(使用 lmdb)
- python - 从 django 的角度来看,“应用程序”是如何工作的?
- smbj - SMBJ:将包含文件的目录复制到我的本地计算机的 SMBJ api 是什么
- regex - 使用正则表达式,我需要匹配除特定日期格式之外的所有内容
- javascript - jQuery Datatables - 无法从隐藏页面获取输入值
- html - 位置设置为固定时导航栏缩小
- html - 缩放背景但不缩放儿童
- oracle - 如何在 PL/SQL 开发人员中查看模式?