为了解决联邦学习中异构数据导致的客户端漂移以及灾难性遗忘产生的针对下一篇CCF A的idea。
联邦学习痛点分析
现在来看联邦学习面对的主要问题是:
- 数据异构性
- 模型异构性
- 灾难性遗忘(本地训练过拟合,覆盖了全局模型的知识)
其中数据异构性会导致客户端漂移的情况发生
客户端漂移:
是指在联邦学习中,参与训练的客户端(设备或节点)在不同轮次的训练中可能产生模型参数的漂移或变化,导致其模型的性能逐渐变差。这种漂移可能是由于不同客户端的数据分布不同,或者在训练过程中客户端本身的模型更新不稳定等原因引起的。
Example:
医院A的患者主要是老年人,而医院B的患者主要是年轻人。在训练过程中,医院A的模型可能会逐渐更好地适应老年人的健康情况,导致模型在老年人数据上的准确性提高,但在年轻人数据上性能下降。相反,医院B的模型可能会更好地适应年轻人数据,导致在年轻人数据上的准确性提高,但在老年人数据上性能下降。
灾难性遗忘:
当一个模型在学习新任务时,会导致其在之前已经学习过的任务上性能下降的现象。这种情况可能发生在机器学习中,尤其是在深度学习领域,当一个模型被迫忘记之前学习的知识,以适应新的任务时。
Example:
考虑一个语言模型,最初它被训练用于生成英文文本。随后,这个模型被重新调整为执行另一个任务,比如图像分类。在调整过程中,模型需要学习识别图像中的不同对象和特征。然而,由于新任务的训练数据不同于之前的文本数据,模型可能会不可避免地"忘记"如何生成文本。结果是,即使在图像分类任务上性能提高了,模型在生成文本方面的能力可能会受到影响,导致其在之前的任务上的性能下降。
解决客户端漂移的一些方法
联邦交叉相关学习
采用自监督学习得到一个泛化模型,采用的是无标签的公共数据,增加相同类别特征的不变性和不同类别特征的差异性。
核心思想是通过该矩阵捕捉不同设备之间的特征关联性。帮助模型更好地理解不同设备的数据分布和特征分布之间的差异。从而提升模型性能。
通过过去全局平均下降方向预测全局下降方向,对漂移进行修正
文章在服务器端和每个客户端中都设置了变量\(c\)
服务器端的参数是全局模型往最优模型的梯度下降方向,客户端中的参数是客户端的梯度更新方向
其实就是服务器端利用客户端在K次训练的梯度下降方向的平均值作为下次梯度下降方向的预测方向。
每次客户端参数在聚合的时候,要先去除掉之前的平均下降梯度,然后再补充上预测的全局梯度下降方向。
定义局部漂移变量并将其放入损失函数中
为客户端定义了局部漂移变量,该变量应该为全局模型与局部模型的差值,将其放入到目标函数中。
这个方法和SCAFFOLD的区别就是,将惩罚项放入到目标函数中,会在局部模型训练的时候就向全局模型靠拢,然后再加上SCAFFOLD的修正项
在全局模型聚合后利用本地模型提取的知识微调全局模型
在每一轮通信中,FedFTG随机选择一组客户端,向他们广播全局模型。每个客户端使用全局模型初始化本地模型,并使用本地优化器对其进行培训。服务器收集本地模型并将其聚合为一个初步的全局模型。FedFTG没有将聚合的模型直接广播回每个客户机,而是使用从本地模型中提取的知识在服务器中微调这个初步的全局模型。
维护一个条件生成器来生成和真实数据分布一样的伪数据,并通过定义损失函数来获得硬样本(不容易区分的样本)作为知识蒸馏所用的训练集。
利用每个客户端的伪数据和客户端模型进行知识蒸馏微调。
利用公共数据集作为客户端沟通桥梁
利用客户端在公共数据集的输出来获得该客户端的知识分布,利用KL散度对知识分布差异进行量化。差异越大越说明这两个客户端之间相互学习的越多。每个客户端都保存其与其他所有客户端的知识分布差异。在训练的时候,保证其与其他客户端的知识分布差异最小。
解决灾难性遗忘的一些方法
预训练模型和全局模型双重蒸馏
为了避免在本地训练的时候模型参数逐渐覆盖之前的参数,该方法通过兼顾本地学习的知识和其他客户端学习知识的蒸馏方法。
文章将全局模型当做领域间的教师模型,因为这个模型包含了所有客户端的知识,可以用来减少过拟合。
文章将客户端在本地数据集上预训练的模型当做本地的教师模型,用来传授本地的知识。
利用Null Space空间,清除对过往模型的干扰
Training networks in null space of feature covariance for continual learning. CVPR. 2021.
Learning Federated Visual Prompt in Null Space for MRI Reconstruction. CVPR. 2023.
为什么在零空间内进行梯度下降可以帮助避免灾难性遗忘呢?
不影响先前任务的方向: 如果你的更新在之前任务的零空间内,那么这意味着你正在在一个方向上更新权重,这个方向不会改变之前任务的输出。因此,通过限制梯度下降的方向在零空间内,你确保了新任务的学习不会“干扰”或“忘记”之前的任务。
权重空间的分离: 通过限制更新到零空间,你实际上是在权重空间中为不同的任务创建了分离的“子空间”。这意味着每个任务都有自己的特定方向或子空间进行权重更新,而不会影响其他任务。
充分利用网络容量: 而不是让所有任务共享同一个权重空间,使用零空间技术使得不同任务可以在不同的子空间内找到其最优权重。这有效地利用了整个网络的容量,同时减少了任务间的互相干扰。
2023.8暑假想法
20230727:
让服务器保存一个N*N的矩阵,矩阵内部保存每个服务端对应的聚合参数。初始的聚合参数可以通过两个客户端数据分布的相似性(KL散度)进行赋值,相似性高的客户端的权重更高。客户端可以生成一些符合自身数据集分布的伪数据(保证隐私)。为了简化计算,可以进行分组,将尽量相似的客户端分成一组。
借鉴:Layer-wised Model Aggregation for Personalized Federated Learning
20230805:
我们要解决的问题是什么:异构的数据+异构的模型 -> 非IID导致模型漂移+灾难性遗忘;异构的模型无法直接聚合
初步的想法:利用20230727的idea构造specialist models,然后借鉴《Distilling……》论文中的方式,通过知识蒸馏将specialist model的知识统一起来。
如果我们引入了Specialist model,就避免了灾难性遗忘的问题,我们直接就让他遗忘所有知识。
20230807:
可不可以将硬样本的数据挖掘出来,通过硬样本的模型蒸馏进一步提升模型性能。
模型蒸馏:需要的是专才模型,通过专家的指导,可以提升模型在这个层面的能力。
联邦学习聚合:需要的是相似的模型,通过聚合的形式提升能力。
借鉴:Fine-tuning Global Model via Data-Free Knowledge Distillation for Non-IID Federated Learning
20230808:
利用specialist models 将全局每个类别性能最好的模型进行蒸馏出一个全局模型
在SCAFFOLD论文中,作者预测了每个客户端的梯度下降方向,用的是平均值,能否在预测上进行改进?
20230904
感觉还是可以通过图分割,每个客户端会根据KL散度构建全连通图,然后利用图分割将设备分成各个组,保证每个组内的数据分布是相似的。(多粒度也可以)。然后每个组内的模型通过模型聚合形成专才模型。专才模型会对global模型进行一个蒸馏微调操作。