FL:Heterogeneous model 小结
0. Motivations
- System heterogeneity. Clients have various computation and bandwidth resources, where each participant has capacity and desire to design their own unique model.
- Strong server, weak client.
1. Knowledge Distillation 入门
Ref: 【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 - 潘小小的文章 - 知乎 https://zhuanlan.zhihu.com/p/102038521
知识蒸馏使用的是Teacher—Student模型,其中teacher是“知识”的输出者,student是“知识”的接受者。知识蒸馏的过程分为2个阶段:
- 原始模型训练: 训练"Teacher模型", 简称为Net-T,它的特点是模型相对复杂。对于输入X, 其都能输出Y,其中Y经过softmax的映射,输出值对应相应类别的概率值。
- 精简模型训练: 训练"Student模型", 简称为Net-S,它是参数量较小、模型结构相对简单的单模型。同样的,对于输入X,其都能输出Y,Y经过softmax映射后同样能输出对应相应类别的概率值。
Soft Labels(Soft Targets)
最后,Net-S的目标函数有: \[ L=\alpha L_{\text {soft }}+\beta L_{\text {hard }} \]
2. HeteModel-FL with knowledge distillation
2.1 FedHe
FedHe: Heterogeneous Models and Communication-Efficient Federated Learning
2021 17th International Conference on Mobility, Sensing and Networking (MSN) 2021
Server不承担teacher模型训练,只负责聚合各个client上传的各类样本的logits,并将聚合的结果发还。
Clients端把 aggregated logits 视作 soft label 进行学习。
2.2 FedMD
Fedmd: Heterogenous federated learning via model distillation
arXiv preprint 2019
Code: https://github.com/diogenes0319/FedMD_clean
Clients提供一部分数据来构建public dataset。
各 client 求 public dataset 对应的 logits。Server 负责聚合各个 client 的 logits 并求平均。发还的 avg(logits) 用以蒸馏 client 端的 model。
由于各 client 是用private dataset + public dataset训练模型,故对public dataset算出的logits中隐性地包含client private data distribution的信息,意味着使用蒸馏可以在不侵犯隐私的情况下获得其他的client的帮助。
2.3 FML
Federated mutual learning
arXiv preprint 2020
设置 Global model 及 Personalized model。 Global model 架构相同,按照一般FL方式进行训练,作为 teacher model使用。
2.4 FedH2L
FedH2L: Federated learning with model and statistical heterogeneity
arXiv preprint 2021
需要共享部分数据作为seed data(文中未讨论seed data的选择和对模型影响)。
直接将知识蒸馏迁移到去中心化FL场景。client间互为teacher-student,进行知识蒸馏。
2.5 对比
Model | Architecture | Share |
---|---|---|
FedHe | CS | Logits with class |
FedMD | CS | Public dataset |
FML | CS | Global model weights |
FedH2L | P2P | Seed data |
如何兼顾隐私与构建强teacher模型,仍是待讨论的问题。
3. Others
Hyper networks
通过Hyper networks学习personalized model weights,模型自由度较知识蒸馏低。