https://arxiv.org/abs/2206.05507
Federated Learning with GAN-based Data Synthesis for Non-IID Clients
Federated learning (FL) has recently emerged as a popular privacy-preserving collaborative learning paradigm. However, it suffers from the non-independent and identically distributed (non-IID) data among clients. In this paper, we propose a novel framework
arxiv.org
2022 IJCAI accepted
연합학습 상황에서 클라이언트마다 데이터가 non-IID 할 때 로컬 모델 간의 불일치가 발생하여 성능이 저하되는 것을 막기 위해, Synthetic Data Aided Federated Learning 방식인 SDA-FL을 제안한다.
- 각 클라이언트는 GAN을 사전 학습하여 privacy-preserving한 합성 데이터를 생성한다.
- 생성한 합성 데이터는 서버에 업로드되어 글로벌 데이터셋으로 결합된다.
- 서버는 이러한 합성 데이터에 대해 반복적으로 수도 레이블링(pseudo-labeling) 메커니즘을 적용하여 각 데이터의 레이블을 업데이트한다.
- 클라이언트 간 데이터 분포를 유사하게 만들고, 모델 간 consistency를 높여 글로벌 모델의 학습 효율성을 향상한다. 논문에 따르면 SDA-FL은 다양한 벤치마크 데이터셋에서 기존 방법들보다 뛰어난 성능을 보여주었다.
Introduction
- 연합 학습(FL): 개인이 자신의 데이터를 공개하지 않고 글로벌 모델을 학습하기 위해 협력하는 개인 정보 보호 방식
- FedAvg 알고리즘: 클라이언트가 로컬 데이터를 사용하여 모델을 훈련하고, 업데이트된 모델을 파라미터 서버(PS)에 전송하여 집계한다.
- non-IID 문제: 클라이언트 간의 데이터 분포가 비슷하지 않을 때 성능이 저하된다. 이는 로컬 모델 간의 불일치를 초래하고, 글로벌 모델 집계의 효율성을 떨어뜨린다.
- 기존의 해결 방법: 로컬 모델의 정보로 모델을 정규화하려는 접근 방식이 여러 번 제안되었지만 극단적인 non-IID 환경에서는 크게 개선되지 않았다고 한다.
- 새로운 접근법: SDA-FL 프레임워크를 통해, 각 클라이언트가 로컬 GAN을 사용해 차별적인 비공식 합성 데이터를 생성하고, 이를 공유하여 데이터 분포를 균일하게 만든다.
Related Works
Non-IID Challenges in Federated Learning
클라이언트 간 데이터 불균형은 로컬 모델 간의 차이를 증가시키고, 이로 인해 집계된 모델의 성능이 저하된다. 또한 데이터가 비독립적일 경우, 클라이언트의 로컬 모델이 다른 클라이언트와 일치하지 않게 되어, 전체 모델 성능이 악화될 수 있다.
이를 해결하기 위해 다양한 방법론이 로컬 모델의 objective function을 수정하여 글로벌 모델이나 다른 클라이언트의 모델로부터 추가 지식을 활용하도록 제안되었다. 그러나 이러한 방법들은 극단적 non-IID 환경에서는 그다지 만족스러운 성능을 보장하지 못한다. 로컬 클라이언트에서 동일한 모델 구조를 사용하는 것 외에도, 각 클라이언트의 모델 구조를 조정하거나, 파라미터 서버(PS)에서 모델 집계, 클라이언트 선택, 클라이언트 클러스터링 등을 최적화하여 데이터의 이질성을 해결하는 연구도 진행되었다.
Data Augmentation and Privacy Preserving
클라이언트 데이터를 증강하여 데이터 분포 불일치 문제를 해결하려는 시도도 있었다. 대표적으로 Mixup 방법론은 클라이언트들이 자신들의 데이터를 혼합하여 새로운 글로벌 데이터를 구성한다. 이때 글로벌 데이터를 생성하는 과정에서 개인 정보 유출의 위험이 있다.
GAN/VAE 등의 생성 모델을 기반으로 데이터를 증강하는 방법은, 클라이언트가 초기 데이터 일부를 서버에 업로드하여 생성 모델을 훈련시킨다. 이후 클라이언트가 생성 모델을 다운로드하여 사용한다. 하지만 로컬 데이터 샘플을 서버에 보내기 때문에 프라이버시 문제가 생길 수 있다. 대표적으로 FedDPGAN은 모든 클라이언트가 FL 프레임워크를 기반으로 글로벌 생성 모델을 공동으로 학습하여 부족한 로컬 데이터를 보완하는 방법이다. GAN 훈련 과정에서는 클라이언트-서버 간 자주 생성 모델을 교환해야 하며, 이로 인해 높은 커뮤니케이션 비용이 발생한다는 단점이 있다.
Method
클라이언트로부터 생성된 synthetic data가 서버로 보내지고, 이 데이터로 서버는 글로벌 데이터셋을 구성한다. 이후 학습 라운드마다 클라이언트는 글로벌 모델을 다운로드하여, 합성 데이터의 수도 레이블을 생성한다. 수도 레이블이 서버로 보내지면 이를 집계하여 글로벌 모델을 업데이트한다.
이전 연구(Jeong et al., 2020)에서는 각 클라이언트가 다른 클라이언트의 로컬 모델을 활용하여 레이블이 없는 데이터를 라벨링했는데 이는 데이터 분포 왜곡(skewed)을 일으켜 bottleneck이 발생하는 문제가 있었다. SDA-FL은 자신의 로컬 데이터로 훈련된 로컬 모델과 해당 모델에 의해 생성된 합성 데이터(Pseudo Label)와의 높은 일관성을 유지할 수 있다.
synthetic data에서의 pseudo label은 아래 절차로 업데이트 된다.
우선 unlabeled synthetic data 각각의 인스턴스 x에 대해, PS는 각 로컬 모델 W_k(t+1)을 수신한 후 pseudo label을 할당한다. 이때 각 x에 대한 최대 클래스 확률 f_c(w_k(t+1); x)가 미리 정의된 임계값 τ보다 높은 경우에만 해당 인스턴스의 pseudo label이 할당된다. 이렇게 함으로써, FL 과정의 처음에는 훈련이 덜 된 로컬 모델들의 pseudo label 품질이 낮을 수 있지만, 각 라운드마다 모델들이 점진적으로 강력해지며 pseudo labels의 품질이 향상되는 효과가 있다.
data augmentation을 위해 GAN 말고도 Mixup 기법이 사용된다. Mixup은 실제 데이터와 합성 데이터를 혼합하여 새로운 훈련 샘플을 만드는 방식이다. 이는 데이터를 선형 보간하여 가상의 샘플을 생성하는 것으로, 모델의 일반화 능력을 높여준다. SDA-FL에서는 합성 데이터(ˆX)와 실제 데이터(Xt i,e)를 함께 사용하여 Mixup을 적용한다.
t번째 통신 라운드에서 i번째 클라이언트의 mixup 데이터와 그 레이블, 로컬 모델의 mixup loss는 다음과 같이 정의된다.
$$ \tilde{X}_{i,e}^{t} = \lambda_1 \hat{X}_e + (1 - \lambda_1) X_{i,e}^{t}, $$
$$ \tilde{Y}_{i,e}^{t} = \lambda_1 \hat{Y}_e + (1 - \lambda_1) Y_{i,e}^{t}, $$
$$ \ell_1 = \lambda_1 \ell \big(f(\tilde{X}_{i,e}^{t}; w_t), \tilde{Y}_e^{t} \big)
+ (1 - \lambda_1) \ell \big(f(\tilde{X}_{i,e}^{t}; w_t), Y_{i,e}^{t} \big). $$
위에서 정의한 loss가 FL 프로세스 초기에 pseudo label의 신뢰성 부족으로 불안정한 양상을 띄기 때문에 논문에서는 또 다른 loss 항을 도입했다. 실제 배치 샘플에 대한 cross entropy로 아래와 같이 정의된다.
$$ \ell_2 = \ell \big(f(X_{i,e}^{t}; w_t), Y_{i,e}^{t} \big). $$
마지막으로 SGD를 통해 로컬 모델이 업데이트된다. 기존의 FL 방식에서는 PS가 데이터에 접근하지 않아 고립되어 있지만, SDA-FL에서는 서버가 global synthetic dataset을 보유하여 이를 통해 글로벌 모델을 훈련할 수 있다.
$$ w_{t+1}^{k} \leftarrow w_{t}^{k} - \eta_t \nabla (\ell_1 + \lambda_2 \ell_2). $$
Experiments
실험을 위한 데이터셋으로 MNIST, FashionMNIST, CIFAR-10, SVHN 4개의 벤치마크 데이터셋을 사용하고, 모든 클라이언트에게 동일하게 훈련 샘플을 분배했다. 10개의 클라이언트를 배치하여 실험을 진행하고 각 라운드에서 모든 클라이언트를 선택한다. 훈련 과정에서 각 클라이언트는 4,000개의 합성 샘플을 생성하여 PS에 업로드하고, 모든 방법에 대해 200개의 라운드 동안 학습한다. SDA-FL을 FedAvg, FedProx, SCAFFOLD, Naivemix, FedMix, FedDPGAN과 비교하여 성능 평가를 진행하였다.
실험 결과 SDA-FL이 기존 방법들보다 우수하다는 것을 보여주었다. 여러 클래스의 데이터가 클라이언트에 배포된 경우에도 SDA-FL의 성능이 두드러지며, COVID-19 데이터셋에서도 SDA-FL이 FedDPGAN보다 1.68% 더 높은 정확도를 기록했다.
Conclusion
SDA-FL은 학습 라운드마다 synthetic data가 로컬 모델의 성능을 향상시키고, 업데이트된 로컬 모델은 pseudo labeling과 global synthetic dataset 구축에 쓰인다는 특징을 가지고 있다. 이 과정에서 라벨의 신뢰도가 향상되어 데이터 품질 손실이 발생하지 않는 것이 장점이다.
즉, 전통적인 FL 방법론에서 추가적인 데이터 생성 작업을 도입하여, 기존 방법에서 클라이언트들이 로컬 데이터를 바탕으로만 모델을 업데이트해서 발생하는 성능 저하(데이터가 non-IID 할수록 distribution skewed)를 줄여준다. 또한 PS가 단순한 aggregation만 수행하는 기존의 방법론과 달리 SDA-FL에서는 고신뢰도의 합성 데이터를 사용하여 글로벌 모델을 향상한다.
non-IID 한 클라이언트 데이터 분포를 완화해서 FL에서의 학습 성능을 높이는 방법론이 많이 연구되고 있다. 앞으로 더 공부해 보면서 연구 동향을 알아봐야겠다.