张量流联邦学习检查点 [英] tensorflow federated learning checkpoint

查看:43
本文介绍了张量流联邦学习检查点的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在研究带有 tensorflow 联合 API 的 federated_learning_for_image_classification.ipynb.

I am studying a federated_learning_for_image_classification.ipynb with tensorflow federated API.

在示例中,我可以检查每个模拟客户训练的准确度、损失和总准确度、总损失.

In the example, I could check each simulated clients train Accuracy, Loss and Total accuracy, Total loss.

但是没有检查点文件.

我想制作每个客户端检查点文件和总检查点文件.

I want to make each client checkpoint file and total checkpoint files.

然后比较客户端参数变量和总参数变量.

And then compare the client parameter variables and total parameter variables.

谁能帮我在 federated_learning_for_image_classification.ipynb 示例中制作检查点文件?

Anyone can help me to make checkpoint file in federated_learning_for_image_classification.ipynb example?

推荐答案

要问的一个问题是,您是要比较 TFF(作为联合计算的一部分)还是事后变量/outside TFF(在 Python 中分析).

One question to ask is whether you want to compare the variables within TFF (as part of the federated computation) or post-hoc/outside TFF (analyzing within Python).

修改由 tff.learning.build_federated_averaging_process 可能是一个不错的选择.事实上,我建议在 tensorflow_federated/python/research/simple_fedavg/simple_fedavg.py,而不是深入研究 tff.learning.

Modifying the tff.utils.IterativeProcess construction performed by tff.learning.build_federated_averaging_process may be a good way to go. In fact, I'd recommend forking the simplified implementation on GitHub at tensorflow_federated/python/research/simple_fedavg/simple_fedavg.py, rather than digging into tff.learning.

更改norel291ofol#La> 执行 tff.fedetated_mean 从客户端更新到 tff.federated_collect 将列出所有客户的模型,然后可以与全局模型进行比较.

Changing the line that performs a tff.fedetated_mean on the updates from the clients to a tff.federated_collect will will give a list of all the client's models that can then be compared to the global model.

示例:

client_deltas = tff.federated_collect(client_outputs.weights_delta)

@tff.tf_computation(server_state.model.type_signature,
                    client_deltas.type_signature)
def compare_deltas_to_global(global_model, deltas):
  for delta in deltas:
    # do something with delta vs global_model 

tff.federated_apply(compare_deltas_to_global, (server_state.model, client_deltas))

这篇关于张量流联邦学习检查点的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆