Main reference: Copy PyTorch Model using deepcopy() and state_dict()
近日验证idea,踩中Pytorch复制模型的大坑,遂记录如下。
1. 概况
问题现象:进行联邦学习+GNN知识蒸馏的idea验证,发现在进行FedAvg后,模型loss、acc等不再变化。
原因概述:复制模型时未对optimizer做重新初始化。
2. 问题发现与探索(简述)
- FedAvg+KD时,loss、acc未如预期一般成锯齿形下降,而是在第一次聚合后便不再变化。
- 检查模型架构是否异常;
- 聚合算法实现检查。但FedAvg + GNN/MLP 正常收敛(这就是坑点,代码实现的单个分析都没有问题);
- 检查模型是否设置异常,未开启反向传播。但
print(model.weight.grad())
不为0,require_grad==True
设置没有问题,但每轮训练,weights, grad
都不变(接近发现原因了!); - 反思是否是梯度消失。但打印的梯度及其变化情况,与梯度消失现象不一致;
- 设置client=1,一个模型不进行
deepcopy()
,只进行知识蒸馏,可以正常训练。另一个模型进行知识蒸馏,并且每一个epoch将参数复制到相同架构的模型,再通过deepcopy
拷贝回来,异常!
3. BUG所在
FedAvg参考他人实现,server模型下发回客户端时使用的是
deepcopy
:知识蒸馏参考他人实现,训练时
optimizer
作为参数传入:(而我惯常的
train()
,optimizer对象是在函数内创建,即每次调用train()
时都会重新初始化optimizer)
上述二者,每个都能完成原始代码的既定任务,但两者结合使用时,就会发生optimizer无法更新模型参数的问题!
4. Pytorch 模型复制
见Copy PyTorch Model using deepcopy() and state_dict()
copy.deepcopy()
:完整地深拷贝整个模型,创建全新的对象。该对象递归地复制原模型的内部对象的值。训练新模型需要对原optimizer重新初始化;new_model.load_state_dict(model.state_dict())
,首先用户需要自己创建new_model
,新模型的架构应与被复制模型的架构一致。load_state_dict
only copies parameters and buffers。实践发现不需对原optimizer重新初始化。