祝贺!顾嘉关于非中心化联邦学习统计推断的论文被接受发表

燕东数据派 2024年09月20日 09:00 北京

近日,陈松蹊教授团队的研究论文《非中心化联邦学习统计推断》(Statistical Inference for Decentralized Federated Learning)被统计学期刊《统计年刊》(The Annals of Statistics)接受发表。

 

联邦学习(Federated Learning,FL)是一种分布式机器学习方法,它允许多个参与方(如客户端、设备、组织或数据持有者)在不共享原始数据的情况下共同训练一个全局模型。与传统的集中式学习方法不同,联邦学习将模型训练的计算任务分散到各个客户端或设备上,而不是将数据集中在一个中心化的服务器上。作为一种新型分布式协同训练框架,联邦学习已在多个领域得到广泛应用(如Google Gboard和Apple Siri等,图1)。


wps7.jpg

wps8.jpg



1:谷歌Gboard输入法(左)和苹果手机Siri(右)中联邦学习的应用示意。(注:图片由DALL-E生成,提示词由ChatGPT4o改进。)

 

现有的联邦学习理论分析大多集中于基于固定步长的随机梯度下降算法(stochastic gradient descent, SGD)算法,这种情况下通常会导致渐近有偏的估计量。此外,现有的研究倾向于将参与训练的客户端数量视为固定的,从而未能反映大规模联邦学习系统中客户端数量可以随本地样本量发散这一现实。

 

文章探讨了基于最通用的非中心化联邦学习(DFL, 图2)算法的异质性M-估计的统计推断问题,该算法囊括了许多SGD算法作为特例。文章推导了估计量空间平均轨迹的均方误差(MSE)上界以及各客户端本地估计之间的共识误差(consensus error)上界,还证明了基于DFL算法的Polyak-Ruppert(PR)平均估计量的渐近正态性。研究表明,当联邦学习系统规模不是特别大的情形下(相对于每个客户端本地样本量T),PR-估计量的渐近方差和全样本M-估计量相同,这意味着PR平均估计器是统计渐近有效(efficient)的。

 

1729837064_39310.jpg

1729837078_22259.jpg

2:一个具有6个节点的连接网络(左),以及Metropolis-Hastings法则下对应的连接矩阵C。

 

为了在更大规模的非中心化联邦学习系统中进行统计推断,文章提出了一种计算高效、统计有效的一步更新(one-step update)估计量。该一步更新估计量以步长较小的PR平均估计量作为初始估计量,并通过一个修正项来改进其统计效率。对于损失函数是否光滑(具有二阶可微性)的不同情形,文章针对性地给出了置信域的构造方法,并且建立了对应的理论保证。

 

顾嘉,2024年7月博士毕业于北京大学统计科学中心,现为浙江大学数据科学研究中心助理教授,是论文的第一作者陈松蹊教授是论文的通讯作者,也是顾嘉的博士生导师。研究得到了国家自然科学基金项目Nos.12292980, 12292983 和No.92358303的资助。

 

论文原文链接:

https://songxichen.com/Uploads/Files/Publication/Main-AoS-0912.pdf