这是博客

试图存在, 但薛三无法自证

0%

Torch-CopyModel的坑

Main reference: Copy PyTorch Model using deepcopy() and state_dict()

近日验证idea,踩中Pytorch复制模型的大坑,遂记录如下。

1. 概况

问题现象:进行联邦学习+GNN知识蒸馏的idea验证,发现在进行FedAvg后,模型loss、acc等不再变化。

image-20230216104047624

原因概述:复制模型时未对optimizer做重新初始化。

image-20230216103949066

2. 问题发现与探索(简述)

  1. FedAvg+KD时,loss、acc未如预期一般成锯齿形下降,而是在第一次聚合后便不再变化。
  2. 检查模型架构是否异常;
  3. 聚合算法实现检查。但FedAvg + GNN/MLP 正常收敛(这就是坑点,代码实现的单个分析都没有问题);
  4. 检查模型是否设置异常,未开启反向传播。但print(model.weight.grad())不为0,require_grad==True设置没有问题,但每轮训练,weights, grad都不变(接近发现原因了!);
  5. 反思是否是梯度消失。但打印的梯度及其变化情况,与梯度消失现象不一致;
  6. 设置client=1,一个模型不进行deepcopy(),只进行知识蒸馏,可以正常训练。另一个模型进行知识蒸馏,并且每一个epoch将参数复制到相同架构的模型,再通过deepcopy拷贝回来,异常!

3. BUG所在

  1. FedAvg参考他人实现,server模型下发回客户端时使用的是deepcopy

    image-20230216110124249
  2. 知识蒸馏参考他人实现,训练时optimizer作为参数传入:

    image-20230216110251502

    (而我惯常的train(),optimizer对象是在函数内创建,即每次调用train()时都会重新初始化optimizer)

上述二者,每个都能完成原始代码的既定任务,但两者结合使用时,就会发生optimizer无法更新模型参数的问题!

4. Pytorch 模型复制

Copy PyTorch Model using deepcopy() and state_dict()

  • copy.deepcopy():完整地深拷贝整个模型,创建全新的对象。该对象递归地复制原模型的内部对象的值。训练新模型需要对原optimizer重新初始化
  • new_model.load_state_dict(model.state_dict()),首先用户需要自己创建new_model,新模型的架构应与被复制模型的架构一致。load_state_dict only copies parameters and buffers。实践发现不需对原optimizer重新初始化。