연합학습/추론 성능

Non-IID 환경을 위한 병합 전략 네 가지

FedTensor 2025. 11. 16. 18:40

연합학습(Federated Learning) 환경에서 데이터가 이질적(Non-IID)일 때 발생하는 클라이언트 편향(Client Drift) 현상은 큰 문제입니다. 이는 각 클라이언트의 로컬 모델이 자신의 데이터에 과적합되어, 이를 단순 평균(FedAvg)할 경우 글로벌 모델의 성능이 저하되는 현상을 말합니다.

FedProx, FedFocal, SCAFFOLD, FedNova 는 이 문제를 각기 다른 방식으로 해결하려는 대표적인 병합(혹은 학습) 전략입니다.

FedProx (Federated Proximal)

FedProx는 클라이언트가 로컬 학습을 진행할 때, 글로벌 모델에서 너무 멀리 벗어나지 않도록 '제약'을 거는 방식입니다.

  • 핵심 아이디어: 로컬 손실 함수(Loss Function)에 '근접 항(Proximal Term)'을 추가합니다.
  • 동작 방식:
    1. 서버가 글로벌 모델($w_g$)을 클라이언트에 전송합니다.
    2. 클라이언트는 자신의 로컬 손실 함수에 패널티 항을 추가하여 로컬 모델($w_l$)을 학습시킵니다.
              > 새로운 로컬 손실 함수 = [기존 손실 함수] + $\frac{\mu}{2} \cdot \|w_l - w_g\|^2$
    3. 여기서 $\mu$ (뮤)는 하이퍼파라미터로, 이 값이 클수록 로컬 모델이 글로벌 모델에서 벗어나는 것을 강하게 억제합니다. 즉, 로컬 모델을 글로벌 모델에 '가까이' 둡니다.
  • 주요 해결 문제:
    • 통계적 이질성 (Statistical Heterogeneity): Non-IID 데이터로 인한 클라이언트 편향을 $\mu$ 항이 직접적으로 억제합니다.
    • 시스템 이질성 (Systems Heterogeneity): 일부 클라이언트는 1 에포크만, 다른 클라이언트는 10 에포크를 학습하는 등 연산량이 달라도 안정적으로 병합할 수 있도록 이론적 기반을 제공합니다.
  • 비용:
    • 통신 오버헤드: 없음 (기존 FedAvg와 동일하게 모델 가중치만 전송).
    • 서버/클라이언트 연산: $\mu$ 항 계산을 위한 약간의 추가 연산 (부담이 적음).

> 비유: FedProx는 로컬 모델이 "너무 멀리 산책 가지 못하도록" 짧은 목줄을 채우는 것과 같습니다.

FedFocal (Fed-Focal Loss)

FedFocal은 이미 충분히 학습된 쉬운 데이터(Easy Example)의 영향력은 줄이고, 판별이 어려운 데이터(Hard/Minority Example)에 집중하여 학습한다는 전략입니다.

  • 핵심 아이디어: 로컬 학습 시 손실 함수를 기존의 Cross-Entropy에서 Focal Loss로 교체합니다.
  • 동작 방식:
    1. 손실 함수 변경: 클라이언트는 서버로부터 받은 글로벌 모델을 업데이트할 때, 로컬 데이터에 대해 Focal Loss를 적용합니다.
      • Focal Loss 식: $FL(p_t) = -(1 - p_t)^\gamma \log(p_t)$
      • 예측 확률($p_t$)이 높은 샘플은 $(1-p_t)^\gamma$ 값에 의해 손실값이 급격히 작아져 모델 업데이트에 미치는 영향이 미미해집니다.
    2. 난이도 기반 업데이트: 로컬 학습 과정에서 모델이 틀리기 쉬운(확률이 낮은) 소수 클래스나 복잡한 샘플에 대해 더 큰 그래디언트(Gradient)를 생성하여 집중적으로 학습합니다.
    3. 가중치 병합: 학습된 로컬 모델들은 FedAvg와 동일한 방식으로 서버에서 평균화되어 글로벌 모델로 병합됩니다.
  • 주요 해결 문제:
    • 라벨 불균형(Label Skew): 특정 기관이 특정 질병(클래스) 데이터만 지나치게 많이 보유하거나, 전체 네트워크상에서 희귀 질환 데이터가 부족한 상황을 효과적으로 해결합니다.
    • 소수 클래스 성능 저하: 다수 데이터에 매몰되어 소수 데이터의 특징을 잡아내지 못하는 '롱테일(Long-tail)' 문제를 극복하여 전체적인 모델의 재현율(Recall)을 높입니다.
  • 비용:
    • 통신 오버헤드: 매우 낮음 (FedAvg와 동일한 수준).
    • 클라이언트 연산: 보통 (로컬 학습 시 손실 값을 계산하는 수식이 소폭 복잡해지지만, 딥러닝 연산 전체 비중으로 볼 때 무시할 수 있는 수준).
    • 서버 연산: 매우 낮음 (FedAvg와 동일한 수준).

> 비유: FedFocal은 시험 공부를 할 때 이미 100점을 맞는 쉬운 과목(Majority Class)에 시간을 쏟기보다는, 계속 틀리는 어려운 문제나 취약한 과목(Minority Class)에 오답 노트를 만들며 집중 투자하는 것과 같습니다.

SCAFFOLD (Stochastic Controlled Averaging)

SCAFFOLD는 각 클라이언트의 '편향된 방향'을 서버와 클라이언트가 모두 추적하여, 로컬 학습 시 이 편향을 보정하는 방식입니다.

  • 핵심 아이디어: '제어 변수(Control Variates)'를 사용하여 각 클라이언트의 그래디언트 편향(Client Drift)을 추정하고 보정합니다.
  • 동작 방식:
    1. 서버는 글로벌 제어 변수 ($c$)를, 각 클라이언트는 로컬 제어 변수 ($c_i$)를 가집니다. (이는 모델의 그래디언트 방향과 유사한 형태를 가집니다.)
    2. 클라이언트는 로컬 학습 시, 서버에서 받은 글로벌 제어 변수($c$)와 자신의 로컬 제어 변수($c_i$)의 차이를 이용해 그래디언트를 보정합니다.
              > 보정된 로컬 그래디언트 $\approx$ [로컬 그래디언트] - $c_i$ + $c$
    3. 이는 "당신의 로컬 데이터가 이끄는 편향된 방향($c_i$)을 빼고, 글로벌 모델이 가야 할 평균적인 방향($c$)을 더하라"는 의미입니다.
    4. 클라이언트는 학습 후 모델 업데이트($\Delta w_i$)와 제어 변수 업데이트($\Delta c_i$)를 모두 서버에 전송합니다.
    5. 서버는 이 둘을 모두 병합하여 글로벌 모델($w$)과 글로벌 제어 변수($c$)를 업데이트합니다.
  • 주요 해결 문제: 통계적 이질성(Client Drift)을 매우 직접적이고 정교하게 보정합니다.
  • 비용:
    • 통신 오버헤드: 매우 높음 (약 2배). 클라이언트는 모델 가중치뿐만 아니라 제어 변수($\Delta c_i$)도 함께 전송해야 합니다.
    • 서버 메모리: 서버가 모든 클라이언트의 로컬 제어 변수($c_i$)를 저장하고 있어야 하므로 메모리 부담이 큽니다. (수백만 대의 기기를 다루기엔 부적합)

> 비유: SCAFFOLD는 각 클라이언트에게 "당신의 나침반은 이만큼 고장 났으니(로컬 편향), 이 보정값을($c-c_i$) 적용해서" 올바른 방향(글로벌 방향)으로 가도록 별도의 보정 장치를 주는 것과 같습니다.

FedNova (Federated Normalized Averaging)

FedNova는 Client Drift 자체를 수정하기보다, 병합 과정의 '불공평함'을 해결하는 데 초점을 맞춥니다.

  • 핵심 아이디어: 시스템 이질성으로 인해 각 클라이언트가 수행한 로컬 학습 스텝(step) 수가 다를 때, 이들의 기여도를 '정규화(Normalize)'하여 공평하게 평균냅니다.
  • 동작 방식:
    1. FedAvg에서는 1 에포크 학습한 클라이언트의 업데이트와 10 에포크 학습한 클라이언트의 업데이트를 동일한 가중치(데이터 수 기준)로 평균냅니다. 10 에포크 학습한 모델이 훨씬 더 많이 이동했기 때문에 평균에 과도한 영향을 미칩니다.
    2. FedNova는 각 클라이언트가 '얼마나 많은 로컬 스텝을 수행했는지'($\tau_i$)를 서버에 함께 보고하도록 합니다.
    3. 서버는 클라이언트의 모델 업데이트($\Delta w_i$)를 로컬 스텝 수($\tau_i$)로 나누어 '단위 스텝당 평균 업데이트'를 계산합니다.
    4. 이 '정규화된 업데이트'를 가중 평균하여 글로벌 모델을 업데이트합니다.
  • 주요 해결 문제: 시스템 이질성 (Systems Heterogeneity)을 직접적으로 해결합니다. 시스템 이질성으로 인해 발생하는 불안정성을 제거함으로써, 간접적으로 통계적 이질성 문제 완화에도 도움을 줍니다.
  • 비용:
    • 통신 오버헤드: 거의 없음 (모델 가중치 외에 로컬 스텝 수($\tau_i$)라는 스칼라 값 하나만 추가 전송).
    • 서버 연산: 정규화를 위한 간단한 나눗셈 연산 (매우 가벼움).

> 비유: FedNova는 "어떤 클라이언트는 1보폭, 어떤 클라이언트는 10보폭 걸었다고 해서 10보폭 걸은 사람 말을 더 들어주지 말고", "한 보폭(1-step)당 평균적으로 어디로 갔는지"를 계산하여 공평하게 평균내는 것과 같습니다.

비교 요약표

특징 FedProx (근접 항) FedFocal (Focal Loss) SCAFFOLD (제어 변수)
FedNova (정규화 평균)
핵심 전략 로컬 학습 제약 (Client-Side) 어려운 데이터에 집중 (Client-Side) 클라이언트 편향 보정 (Client+Server-Side)
서버 병합 방식 수정 (Server-Side)
주요 해결 문제 통계적 이질성, 시스템 이질성 통계적 이질성 (소수 클래스 성능 저하) 통계적 이질성 (클라이언트 편향) 시스템 이질성
작동 위치 클라이언트 (손실 함수 수정) 클라이언트 (손실 함수 수정) 클라이언트 (그래디언트 보정) + 서버 (제어 변수 저장/병합)
서버 (병합 로직 수정)
통신 오버헤드 없음 (FedAvg와 동일) 없음 (FedAvg와 동일) 매우 높음 (약 2배) (모델 + 제어 변수 전송)
매우 낮음 (모델 + 로컬 스텝 수)
서버 메모리 낮음 낮음 매우 높음 (모든 클라이언트의 제어 변수 저장) 낮음
주요 신규 파라미터 $\mu$ (근접성 계수) $\gamma$ (집중도 계수) 없음 (대신 상태 값 $c_i$​ 저장) 없음

결론

  • FedProx는 구현이 간단하고 통신 오버헤드가 없으며, 통계적/시스템 이질성 모두에 준수한 성능을 보여 가장 실용적인 대안으로 많이 사용됩니다.
  • FedFocal은 구현이 간단하고 통신 오버헤드가 없으며, 소수 클래스와 같이 학습이 어려운 데이터에 집중하는 방식입니다.
  • SCAFFOLD는 Non-IID로 인한 클라이언트 편향을 보정하는 성능은 매우 뛰어나지만, 통신 및 서버 메모리 오버헤드가 매우 커서 대규모 환경(수천~수백만 클라이언트)에서는 비현실적일 수 있습니다.
  • FedNova는 특히 클라이언트들의 연산 능력이 크게 차이 나는(시스템 이질성) 환경에서 FedAvg보다 훨씬 안정적인 수렴을 보장하는 가벼운 해결책입니다.