연합학습/추론 성능

Non-IID 환경을 위한 병합 전략: SCAFFOLD

FedTensor 2025. 11. 16. 20:02

SCAFFOLD (Stochastic Controlled Averaging)는 연합학습(FL)에서 데이터 이질성(Non-IID)으로 인해 발생하는 'Client Drift (클라이언트 편향)' 문제를 해결하기 위한 매우 정교한 알고리즘입니다.

FedProx가 로컬 모델이 멀리 벗어나는 것을 '억제'하는 방식이라면, SCAFFOLD는 각 클라이언트가 얼마나 편향되었는지를 '추정'하고 이를 '보정'하는, 더 적극적인 방식을 사용합니다.

1. SCAFFOLD가 해결하려는 핵심 문제: 편향된 그래디언트

데이터가 이질적(Non-IID)일 때, 각 클라이언트가 계산하는 로컬 그래디언트(모델이 나아가야 할 방향)는 글로벌 모델이 실제로 나아가야 할 방향(모든 데이터의 평균 방향)과 다릅니다.

  • Client Drift의 원인: FedAvg에서는 이렇게 편향된 방향으로 로컬에서 여러 스텝을 이동합니다. 이 편향이 누적되면서 로컬 모델이 글로벌 모델에서 멀리 벗어나게 됩니다.
  • SCAFFOLD의 접근: "각 클라이언트의 로컬 그래디언트가 '얼마나', '어느 방향으로' 편향되었는지 추적하자. 그리고 로컬 학습 시 이 편향을 강제로 보정하자."

2. 핵심 아이디어: 제어 변수 (Control Variates)

SCAFFOLD는 이 '편향'을 추정하고 보정하기 위해 '제어 변수(Control Variates)'라는 장치를 도입합니다. 이는 모델 가중치($w$)와 동일한 형태(shape)를 가집니다.

  • 서버: 글로벌 제어 변수 ($c$)를 가집니다. 이는 '이상적인 글로벌 그래디언트 방향'을 추정합니다.
  • 클라이언트: 로컬 제어 변수 ($c_i$)를 가집니다. 이는 '클라이언트 $i$의 로컬 데이터가 만드는 평균적인 그래디언트 방향(즉, 편향의 방향)'을 추정합니다.

3. SCAFFOLD의 동작 원리 (라운드별 단계)

클라이언트 제어 변수($c_i$)는 한 번에 계산되는 것이 아니라, 매 라운드마다 로컬 학습이 완료된 후 갱신됩니다.

$c_i$는 클라이언트 $i$의 고유한 데이터 분포로 인해 발생하는 '평균적인 로컬 업데이트 방향' 또는 '로컬 드리프트'를 추정하는 값이며, 모델 파라미터($w$)와 동일한 형태를 가집니다.

클라이언트 $i$가 라운드 $t$에서 자신의 제어 변수 $c_i$를 갱신하는 과정은 다음과 같습니다.

1단계: 서버로부터 정보 수신

클라이언트 $i$는 서버로부터 다음 두 가지를 수신합니다.

  1. 글로벌 모델 ($w_t$): 현재 라운드의 글로벌 모델 파라미터
  2. 서버 제어 변수 ($c_t$): 현재 라운드의 글로벌 제어 변수

또한, 클라이언트 $i$는 이전 라운드(t-1)에서 계산하여 로컬에 저장해 둔 자신의 이전 클라이언트 제어 변수 ($c_{i, t-1}$)를 가지고 있습니다. (가장 첫 라운드에는 0으로 초기화됩니다.)

2단계: 보정된 로컬 학습

클라이언트는 $K$번의 로컬 스텝 동안 학습을 진행합니다.

  • (1 에포크 당 스텝 수) = (로컬 데이터 총 개수 $N_i$) / (배치 크기 $B$)
  • $K$ (총 로컬 스텝 수) = $E$ (에포크) $\times$ (1 에포크 당 스텝 수)
  • 즉, $K = E \times \frac{N_i}{B}$ 입니다. 이것은 $K$번의 '미니배치(mini-batch) 학습'이 일어났다는 것을 의미하며, 총 $K$번의 그래디언트 업데이트가 발생했음을 나타냅니다.

이때, FedAvg처럼 단순히 로컬 그래디언트 $g_i(w)$를 사용하는 것이 아니라, 보정된 그래디언트 ($g'_i$)를 사용합니다.

$$g'_i = g_i(w) - c_{i, t-1} + c_t$$

  • $g_i(w)$: 현재 로컬 데이터 배치가 말하는 업데이트 방향
  • $- c_{i, t-1}$: 나의 지난 라운드까지의 '평균 편향'을 빼서 보정
  • $+ c_t$: 서버가 원하는 '글로벌 평균 방향'을 더해서 보정

이 보정된 그래디언트 $g'_i$를 사용하여 $\eta_l$(로컬 학습률)만큼 $K$번 모델을 업데이트합니다.
그 결과 새로운 로컬 모델 ($w_{i, t}$)이 완성됩니다.

3단계: 새로운 $c_i$ 계산

로컬 학습이 완료된 후, 클라이언트는 다음 라운드($t+1$)에 사용할 새로운 클라이언트 제어 변수 ($c_{i, t}$)를 계산합니다.

SCAFFOLD 논문에 따른 $c_{i, t}$의 갱신 공식은 다음과 같습니다.

$$c_{i, t} = c_{i, t-1} - c_t + \frac{1}{K \cdot \eta_l} (w_t - w_{i, t})$$

이 공식의 각 항목은 다음과 같습니다.

  • $c_{i, t}$: (이번에 계산할) 새로운 클라이언트 제어 변수
  • $c_{i, t-1}$: (로컬에 저장되어 있던) 이전 클라이언트 제어 변수
  • $c_t$: (서버에서 받은) 서버 제어 변수
  • $K \cdot \eta_l$: 로컬 스텝 수($K$)와 로컬 학습률($\eta_l$)의 곱
  • $(w_t - w_{i, t})$: (서버에서 받은) 글로벌 모델과 (내가 학습한) 새 로컬 모델의 차이

이 공식의 의미

이 복잡한 식은 사실상 "이번 로컬 학습 $K$ 스텝 동안 관찰된 '순수 로컬 그래디언트의 평균값'"을 효율적으로 계산한 것입니다.

즉, $c_{i, t} \approx \frac{1}{K} \sum_{k=1}^{K} g_i(w_k)$ 와 같습니다.

$c_{i, t}$는 클라이언트 $i$의 데이터가 '진짜' 원하는 방향(편향)을 나타내며, 이 값을 계산하여 로컬에 저장해 둡니다.

4단계: 서버로 전송

클라이언트는 서버로 다음 두 가지 정보를 전송합니다.

1.  새로운 로컬 모델 ($w_{i, t}$) (또는 모델 업데이트 값 $\Delta w_i = w_{i, t} - w_t$)
2.  제어 변수 '업데이트' 값 ($\Delta c_i = c_{i, t} - c_{i, t-1}$)

서버는 $\Delta c_i$ 값들을 평균 내어 $c_t$를 갱신하는 데 사용하고, 클라이언트는 방금 3단계에서 계산한 $c_{i, t}$ 값을 다음 라운드를 위해 로컬에 저장합니다.

4. 장점 및 단점 요약

장점

  1. 뛰어난 Non-IID 성능: Client Drift의 원인인 '편향된 그래디언트'를 매우 직접적이고 정교하게 보정합니다. 따라서 데이터 이질성이 매우 심한 환경에서도 빠르고 안정적으로 수렴합니다.
  2. 이론적 기반: FedAvg나 FedProx보다 더 빠른 수렴 속도를 이론적으로 증명합니다.
  3. 공평성(Fairness) 향상: 특정 클라이언트의 데이터 편향이 글로벌 모델에 과도한 영향을 미치는 것을 막아, 모델의 공평성을 높이는 데 기여할 수 있습니다.

단점

  1. 높은 통신 오버헤드 (치명적 단점):
    • 클라이언트는 서버에 모델 업데이트($\Delta w_i$)와 제어 변수 업데이트($\Delta c_i$)를 모두 전송해야 합니다.
    • 모델 가중치와 동일한 크기의 데이터가 추가로 전송되어야 하므로, 통신량이 FedAvg/FedProx 대비 약 2배가 됩니다.
  2. 높은 서버 메모리 오버헤드:
    • SCAFFOLD는 서버가 모든 클라이언트의 로컬 제어 변수($c_i$)를 개별적으로 저장하고 있어야 합니다. (다음 라운드에 참여하지 않더라도 유지해야 함)
    • 클라이언트 수가 수천, 수백만 대(Cross-device FL)인 환경에서는 사실상 사용이 불가능합니다. (주로 클라이언트 수가 고정된 Cross-silo FL 환경을 가정합니다.)
  3. 구현 복잡성: FedAvg나 FedProx에 비해 구현이 훨씬 복잡합니다.