연합학습/추론 성능

Non-IID 환경을 위한 병합 전략: FedProx

FedTensor 2025. 11. 16. 19:24

FedProx(Federated Proximal)는 연합학습, 특히 데이터가 이질적인(Non-IID) 환경에서 발생하는 'Client Drift (클라이언트 편향)' 문제를 해결하기 위해 제안된 핵심적인 알고리즘입니다.

FedAvg(표준 연합학습)의 직접적인 개선안으로, 로컬 학습 과정에 간단하면서도 효과적인 제약 사항을 추가한 것이 특징입니다.

1. FedProx가 해결하려는 핵심 문제: Client Drift

표준 FedAvg에서는 각 클라이언트가 서버로부터 글로벌 모델을 받아, 자신의 로컬 데이터로 여러 에포크(epoch) 동안 학습합니다.

  • 문제점: 만약 클라이언트 A(예: 숫자 '3' 이미지 만 보유)와 클라이언트 B(예: 숫자 '7' 이미지 만 보유)의 데이터가 매우 이질적이라면,
    • 클라이언트 A의 로컬 모델은 '3'에 과적합됩니다.
    • 클라이언트 B의 로컬 모델은 '7'에 과적합됩니다.
  • Client Drift: 이처럼 로컬 모델이 각자의 데이터셋에 과적합되어 서버의 글로벌 모델($w_t$)로부터 너무 멀리 '표류(drift)'하는 현상이 발생합니다.
  • 결과: 이렇게 멀리 떨어진 모델들을 서버에서 단순 평균($w_{t+1}$)내면, 글로벌 모델의 성능이 불안정해지거나 수렴이 느려집니다.

2. FedProx의 핵심 아이디어: "짧은 목줄 (Proximal Term)"

FedProx는 로컬 모델이 글로벌 모델로부터 너무 멀리 도망가지 못하도록 '근접 항(Proximal Term)'이라는 일종의 페널티, 즉 '짧은 목줄'을 채웁니다.

로컬 모델($w$)이 학습할 때, 자신의 로컬 손실(Loss)만 최소화하는 것이 아니라, "글로벌 모델($w_t$)과의 거리도 가깝게 유지하도록" 강제하는 것입니다.

3. 동작 원리 및 수학적 공식

FedProx는 클라이언트의 로컬 목적 함수(Objective Function)를 수정합니다.

  • 표준 FedAvg의 로컬 목적 함수:
    • 각 클라이언트 $k$는 자신의 로컬 데이터셋에 대한 손실 함수 $H_k(w)$를 최소화하는 $w$를 찾습니다.

$$\min_w H_k(w)$$

  • FedProx의 로컬 목적 함수:
    • FedProx는 여기에 근접 항을 추가합니다.

$$\min_w L_k(w) = H_k(w) + \frac{\mu}{2} \|w - w_t\|^2$$

공식 해설:

  • $H_k(w)$: 클라이언트 $k$의 원래 손실 함수입니다. (예: Cross-Entropy Loss)
  • $w$: 현재 학습 중인 로컬 모델의 가중치입니다.
  • $w_t$: 이번 라운드($t$)에 서버로부터 다운로드한 글로벌 모델의 가중치입니다.
  • $\|w - w_t\|^2$: 근접 항(Proximal Term)입니다. 로컬 모델 $w$와 글로벌 모델 $w_t$ 사이의 L2 거리(Euclidean 거리의 제곱)를 의미합니다.
  • $\mu$ (뮤): 하이퍼파라미터입니다. 이 '목줄'의 강도를 조절합니다.

4. 하이퍼파라미터 $\mu$ (뮤)의 역할

$\mu$ 값은 FedProx의 동작을 결정하는 매우 중요한 요소입니다.

  • $\mu = 0$ 일 때:
    • $\frac{\mu}{2} \|w - w_t\|^2$ 항이 0이 됩니다.
    • FedProx의 목적 함수는 표준 FedAvg의 목적 함수 ($H_k(w)$)와 완전히 같아집니다.
    • 즉, FedProx는 FedAvg를 일반화한 버전이며, FedAvg는 $\mu=0$인 특수한 경우(special case)로 볼 수 있습니다.
  • $\mu > 0$ (값이 클 때):
    • '목줄'이 매우 짧고 강하게 당겨지는 것과 같습니다.
    • 로컬 모델 $w$가 글로벌 모델 $w_t$에서 조금만 벗어나도 큰 페널티($\frac{\mu}{2} \cdot \text{거리}$)를 받습니다.
    • 결과: 로컬 모델은 자신의 데이터에 맞게 거의 학습하지 못하고 $w_t$ 근처에 머무릅니다. Client Drift는 확실히 줄어들지만, 로컬 데이터의 특성을 반영하지 못해 성능이 낮아질 수 있습니다.
  • $\mu > 0$ (값이 적절할 때):
    • 로컬 모델은 $H_k(w)$를 최소화(로컬 데이터 학습)하면서, 동시에 $w_t$와 너무 멀어지지 않도록(페널티) 균형을 맞춥니다.
    • Client Drift를 효과적으로 억제하면서 Non-IID 환경에서도 안정적인 수렴을 가능하게 합니다.

5. FedProx가 해결하는 또 다른 문제: 시스템 이질성

FedProx는 통계적 이질성(Non-IID)뿐만 아니라 시스템 이질성(Systems Heterogeneity) 문제에도 강점을 보입니다.

  • 시스템 이질성: 클라이언트마다 하드웨어 성능(CPU, 배터리)이 달라, 어떤 클라이언트는 로컬 에포크를 1번만, 어떤 클라이언트는 20번 수행할 수 있는 상황을 말합니다.
  • FedAvg의 문제: 로컬 에포크를 20번 수행한 클라이언트는 Client Drift가 심하게 발생한 모델을 서버에 제출하게 되어, 1번 수행한 클라이언트의 기여를 무시하고 글로벌 모델을 오염시킬 수 있습니다.
  • FedProx의 해결: $\mu$ 항이 일종의 '안전장치' 역할을 합니다. 로컬 에포크를 20번 돌리더라도 $w_t$로부터 너무 멀리 벗어날 수 없도록 억제합니다. 따라서 연산 능력이 다른 기기들이 학습에 참여해도 안정성이 유지됩니다.

장점 및 단점 요약

장점

  1. Non-IID 환경 성능: Client Drift를 직접적으로 제어하여 Non-IID 데이터 환경에서 FedAvg보다 훨씬 안정적이고 높은 성능을 보입니다.
  2. 시스템 이질성 보완: 로컬 연산량이 다른(E.g., $E=1$ vs $E=20$) 클라이언트들이 혼재된 상황에서도 안정적으로 동작합니다.
  3. 이론적 수렴 보장: 원 논문에서는 FedAvg와 달리 Non-IID 환경에서도 수렴을 이론적으로 보장합니다.
  4. 통신 오버헤드 없음: SCAFFOLD와 같은 다른 알고리즘과 달리, 클라이언트는 오직 모델 가중치($w$)만 서버에 전송합니다. 추가적인 통신 비용이 발생하지 않습니다.
  5. 구현 용이성: 기존 FedAvg 코드에서 클라이언트의 로컬 손실 함수만 수정하면 되므로 구현이 매우 간단합니다.

단점

  1. 하이퍼파라미터 튜닝: $\mu$ 값을 데이터셋과 모델에 맞게 잘 조정해야 하는 추가적인 튜닝 부담이 있습니다.
  2. 약간의 연산 오버헤드: 로컬 학습 시 근접 항을 계산하고 그래디언트를 전파해야 하므로, FedAvg 대비 아주 약간의 연산량이 추가됩니다. (하지만 통신 오버헤드에 비하면 무시할 만한 수준입니다.)