Effective Heterogeneous Federated Learning via Efficient Hypernetwork-based Weight Generation
While federated learning leverages distributed client resources, it faces challenges due to heterogeneous client capabilities. This necessitates allocating models suited to clients' resources and careful parameter aggregation to accommodate this heterogene
arxiv.org
Effective Heterogeneous Federated Learning via Efficient Hypernetwork-based Weight Generation (ACM SenSys 2024)
이 논문에서는 연합학습의 client heterogeneity에 효율적인 연산 프레임워크 HypeMeFed를 제안한다. 기존 FL 모델들은 모든 클라이언트가 동일한 신경망 아키텍처를 공유한다고 가정하지만, 실제 환경에서는 각 클라이언트가 보유한 기기의 성능이 상이하여 모델 학습 과정에서 불균형이 발생한다.
HypeMeFed에서 제안하는 아키텍처는
1. 멀티-익싯(Multi-Exit) 네트워크
2. 하이퍼네트워크 기반 모델 가중치 생성(Hypernetwork-based Weight Generation)
이 둘을 결합하여 heterogeneous 환경에서 효과적으로 학습할 수 있도록 한다. main contribution은 다음과 같다.
- 클라이언트의 하드웨어 성능 차이로 인해 FL에서는 학습 효율성 및 공정성 문제가 발생한다. HypeMeFed는 multi-exit 네트워크를 활용하여 모델의 깊이를 조정하고, 클라이언트별 적절한 서브모델을 제공하여 학습 성능을 최적화한다.
- 일반적으로 클라이언트마다 다른 네트워크 깊이를 가질 경우, 레이어 간 정보의 불균형이 생길 수 있다. 하이퍼네트워크는 부족한 레이어의 가중치를 보완하고 서버에서 일괄적으로 학습해서, 최적화된 모델을 클라이언트 별로 제공한다.
- 하이퍼네트워크의 메모리 및 연산 부담을 줄이기 위해 SVD 기반 저순위 행렬 분해 기법(Low-Rank Factorization)을 적용한다. 이를 통해 하이퍼네트워크 메모리 요구량을 98.22% 감소시키고, 연산 속도를 1.86배 가속화했다.
Multi-Exit Network
네트워크를 여러 개의 출구(출력 레이어)로 구성하고, 모든 클라이언트는 첫 번째 출구까지 학습한다. 그 후에는 클라이언트 스펙에 따라 학습 정도, 즉 학습이 종료되는 출구를 다르게 할 수 있다. 이로써 발생하는 문제점은, 서버의 각 exit마다 사용할 수 있는 샘플 양이 달라진다는 것이다(information disparity). 논문에서는 이를 HyperNetwork로 해결한다.
HyperNetwork
하이퍼네트워크란 네트워크를 위한 네트워크라고 보면 된다. 다른 네트워크(target network)의 가중치를 예측하여 생성하는 방식으로 동작한다. 일반적인 딥러닝 모델에서 학습을 통해 직접 가중치를 업데이트하며 생기는 문제들(메모리 및 연산 오버헤드, personalization 등)을 해결할 수 있다.
또한 FL 학습에서 클라이언트의 리소스 한계로 일부 레이어의 가중치를 생성하지 못했을 때, 하이퍼네트워크가 기존 학습된 레이어 정보를 토대로 부족한 가중치를 생성할 수 있다.
📌 HypeMeFed의 Hypernetwork 설계
- MLP 기반 구조
- 하이퍼네트워크는 입력으로 이전 레이어의 가중치(Preceding Layer Weights)를 벡터화하여 사용.
- 이를 MLP 모델을 통해 변환하여 다음 레이어의 가중치를 생성하는 방식.
- 두 개의 선형 레이어(fully connected layers)와 ReLU 활성화를 포함하여 비선형성을 확보함.
- Low-Rank Factorization (LRF) 기반 최적화
- 모델 가중치 전체를 직접 예측하는 것이 아닌, Singular Value Decomposition (SVD) 기반 LRF를 활용.
- 이를 통해 신경망의 주요 정보만을 유지하면서 하이퍼네트워크의 연산량을 감소시킴.
- 압축된 Left/Right Singular Vectors 예측하는 방식으로 학습.
- 결과적으로, 하이퍼네트워크는 원래의 전체 모델보다 훨씬 작은 크기의 파라미터를 다룰 수 있음.
- 연합 학습(Federated Learning)에서의 하이퍼네트워크 학습
- 하이퍼네트워크는 서버에서 학습하며, 클라이언트가 학습한 모델의 가중치를 활용하여 업데이트됨.
- 특정 클라이언트가 학습한 서브모델이 부족한 경우, 이전 레이어 정보를 이용해 부족한 레이어의 가중치를 생성.
- 모든 클라이언트의 학습 데이터를 통합하여 점진적으로 향상됨
물론 하이퍼네트워크가 서버에서 작동하더라도 요구되는 리소스가 점점 커져서, 큰 메모리 사용량과 컴퓨팅 타임을 필요로 할 수 있다. 하이퍼네트워크가 메모리 요구 사항으로 인해 서버에서 효과적으로 실행되기 어려울 수 있기 때문에, 저순위 분해(Low-Rank Factorization, LRF) 기반의 신경망 매개변수 압축 방식을 제안하여 하이퍼네트워크 작업을 최적화한다.
전체적인 아키텍처를 표현하면 위 그림과 같다.
① Multi-exit Model Split : 서버는 클라이언트 컴퓨팅 리소스에 맞게 depth-wise로 모델 파라미터를 나눈다.
② Model Distribution : 클라이언트 각각에게 full, medium, small로 나눈 모델(각각 full 모델의 subset)을 분배한다.
③ Local Training : 클라이언트의 local data를 이용한 모델 훈련이 진행된다. 이때 여러 exit에서 발생한 loss의 총합인 joint loss를 이용해 학습한다는 특징이 있다.
④ Model Update : 훈련을 마친 모델 파라미터를 서버에 전송한다.
⑤ HyperNetwork Training, ⑥ Weight Generation : 클라이언트의 학습된 모델에서 추출한 파라미터를 바탕으로 더 깊은 레이어의 가중치를 생성하는 단계. 이로써 information disparity를 극복한다.
⑦ Weight Aggregation : 서버는 글로벌 모델을 업데이트하기 위해 클라이언트에게 수신한 파라미터와, 하이퍼네트워크로 생성된 가중치 파라미터를 집계한다.
Experiments
- HypeMeFed는 기존 FedAvg 대비 모델 정확도를 5.12% 향상시켰으며, FedAvg-S(가장 작은 모델만 사용하는 경우)보다 높은 성능을 유지.
- GPU 기반 시뮬레이션 및 임베디드 장치(Raspberry Pi, Jetson Nano 등)에서 실험을 수행하여 실제 환경에서도 연합 학습 성능을 개선함을 확인.
- 다양한 데이터셋(SVHN, STL10, UniMiB SHAR)을 활용한 실험에서 HypeMeFed가 기존 HeteroFL, ScaleFL 대비 높은 정확도를 유지하며, 연산 비용을 최적화함.
Conclusion
HypeMeFed는 FL의 heterogeneous한 환경에서도 효율적인 연합 학습을 가능하게 하는 프레임워크로, multi-exit 네트워크를 통해 모델 깊이를 조정하고, 하이퍼네트워크를 이용하여 불균형한 학습 데이터를 보완하는 방법을 제안한다. 이를 통해 연산 비용을 낮추면서도 높은 학습 성능을 유지하는 것이 가능하며, 실제 모바일 및 임베디드 환경에서도 실용적으로 적용될 수 있음을 입증하였다. 추가로 디벨롭할 수 있는 방안으로는
1. 현재 Hypernetwork는 LRF 기법을 통해 메모리 연산량을 효율적으로 줄이고 있지만, 여전히 최적화의 여지가 존재한다. 가령 레이어 별로 다른 Hypernetwork 구조를 적용하여 성능을 최적화할 수 있다.
2. 현재는 CNN-based model로 실험이 진행되었으나 LSTM, Transformer 같은 다양한 NN들에도 HypeMeFed를 적용하여 볼 수 있다.
3. 기존 FedAvg와 같은 aggregation method와 병합하여 사용할 수 있다.
설 연휴동안 눈앞에 보이는 건 다 입으로 넣었더니 살이 많이 쪘다... 이제 다이어트 해야지!