tensorflow-federated - 如何使用 TFF 进行预测?
问题描述
我的问题是:如何使用 Tensorflow Federated 预测此类图像的标签?
完成模型评估后,我想预测给定图像的标签。就像在 Keras 中一样,我们这样做:
# new instance where we do not know the answer
Xnew = array([[0.89337759, 0.65864154]])
# make a prediction
ynew = model.predict_classes(Xnew)
# show the inputs and predicted outputs
print("X=%s, Predicted=%s" % (Xnew[0], ynew[0]))
输出:
X=[0.89337759 0.65864154], Predicted=[0]
以下是 state 和 model_fn 的创建方式:
def model_fn():
keras_model = create_compiled_keras_model()
return tff.learning.from_compiled_keras_model(keras_model, sample_batch)
iterative_process = tff.learning.build_federated_averaging_process(model_fn, server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),client_weight_fn=None)
state = iterative_process.initialize()
我发现这个错误:
list(self._name_to_index.keys())[:10]))
AttributeError: The tuple of length 2 does not have named field "assign_weights_to". Fields (up to first 10): ['trainable', 'non_trainable']
谢谢
解决方案
(需要 TFF0.16.0
或更新版本)
由于代码是tff.learning.Model
从 a构建的,tf.keras.Model
因此您可以assign_weights_to
在对象上使用该方法tff.learning.ModelWeights
(的类型state.model
)。此方法在联邦学习文本生成教程中使用。
这可能看起来像(靠近底部,早期部分是示例 FL 训练循环):
def create_keras_model() -> tf.keras.Model:
...
def model_fn():
...
return tff.learning.from_keras_model(create_keras_model())
training_process = tff.learning. build_federated_averaging_process(model_fn, ...)
state = training_process.initialize()
for _ in range(NUM_ROUNDS):
state, metrics = training_process.next(state, ...)
model_for_inference = create_keras_model()
state.model.assign_weights_to(model_for_inference)
一旦将权重state
分配回 Keras 模型,代码就可以使用标准的 Keras API,例如tf.keras.Model.predict_on_batch
predictions = model_for_inference.predict_on_batch(batch)
推荐阅读
- wpf - WPF ListView 列宽
- python-3.x - 将字符串转换为 matplotlib 日期以进行绘图
- javascript - 有没有办法修剪 YouTube 视频而不将其下载到本地机器?
- android-recyclerview - 在用户滑动时应用设计并在中止的滑动上删除设计
- python - Docker:env:无法执行'python3':没有这样的文件或目录
- google-colaboratory - 流输出截断到最后 5000 行
- css - 引导 btn 块宽度以适应屏幕
- javascript - 如何循环对象并将键值推送到键值对象中
- hibernate - 错误:org.hibernate.engine.jdbc.spi.SqlExceptionHelper - Exhausted Resultset,嘿,我的休眠代码中出现了这个异常?
- javascript - Express.js 应用程序错误:表单字段值不持久