새소식

자연어 NLP

[RLHF] BOND (1) : Aligning LLMs with Best-of-N Distillation

  • -


BOND: Aligning LLMs with Best-of-N Distillation

 

link : https://arxiv.org/abs/2407.14622

 

 

 

 

구글 딥마인드에서 새로운 RLHF method, J-BOND에 대해 소개한 논문이다. 해당 방법은 Gemma1.1 모델 학습시 사용되어 reward/KL trade-off 방식을 사용하는 강화학습 기반 baseline들에 비해 outperform하는 성능을 보여준다고 한다. J-BOND는 보상 분위수를 추정하기 위해 Monte Carlo 샘플링을 사용하여 Best-of-N 샘플링을 emulate하는 Best-of-N Distillation 알고리즘을 도입하였다.

 

알고리즘을 간단하게 요약하면 다음과 같다.

 

1. 프롬프트와 리워드모델을 수집

2. 각 prompt마다 현재 policy로부터 1개의 샘플을 생성하고, anchor $($reference$)$ policy로부터 2개의 샘플을 생성한다.

3. anchor 샘플을 사용해 forward KL gradient를 계산하고

4. policy 샘플과 rJ-BOND 리워드 함수를 사용해 backward KL gradient를 계산한다.

5. Jeffreys divergence + KL regularization이 합쳐진 combined gradient를 사용해 policy weight를 업데이트 한다.

6. anchor$($reference$)$ model을 EMA$($Exponential Moving Average$)$를 사용해 업데이트 한다.

 

기존 RLHF 방법들은 anchor$($reference$)$이 SFT로부터 초기화되어 이후 업데이트 되지 않지만, 해당 알고리즘은 anchor 모델을 천천히 업데이트하여 RLHF 학습시 안정성이 높아진다. 또한 anchor 모델이 policy 업데이트를 위한 moving target이 된다.

 

 

 

 Abstract

 

RLHF는 대형 언어 모델의 품질과 안전을 보장하는 핵심 방법이지만, 가장 좋은 N개를 샘플링하는 방법을 사용한다. 해당 논문에서는 추론 시 계산 오버헤드 없이 Best-of-N을 emulate하는 새로운 RLHF 알고리즘인 Best-of-N Distillation$($BOND$)$를 제안한다. BOND는 policy에 따른 generation 분포를 Best-of-N 분포에 가까워지도록 하는 distribution matching 알고리즘이다. mode-covering $($분포를 고르게 해서 다양한 모드를 커버하는 것$)$과 mode-seeking $($중심 모드 주위 고밀도 데이터를 더 많이 포착하는 것$)$ 사이의 균형을 맞추기 위해 Jeffrey divergence $($forward + backward KL의 평균$)$을 사용하고 효율성을 위해 moving anchor를 활용하는 iterative formulation을 도출한다.

 

 

 

 Introduction

 

Gemini, GPT4와 같은 SOTA LLM은 주로 3가지 stage를 통해 학습된다. 먼저, next-token prediction을 사용해 large corpora를 사전학습하고, SFT를 통해 instruction을 따르도록 파인튜닝한다. 마지막으로 RLHF를 사용해 generation의 퀄리티를 높인다. RLHF 스텝에서는 일반적으로는 reward model $($RM$)$을 human preference르 학습하고 LLM이 예측된 reward가 큰 쪽으로 aligning한다.

 

 RLHF algorithms and their challenges

 

RL로 LLM을 파인튜닝하는 것은 사전 학습된 지식을 잃어버릴 수 있고(참고), Reward hacking 문제가 발생할 수 있어 어렵다. 일반적인 전략은 policy-gradient method와 KL regularization을 사용하여 낮은 KL에서 높은 보상을 제공하는 Pareto-optimal 정책을 추구함으로써 원래 모델의 일반적인 기능을 보존하고 misalignment 문제를 해결한다.

 

 

 Best-of-N sampling

 

실제로는 generation 품질을 향상시키기 위해 간단한 inference-time approach인 Best-of-N 샘플링이 자주 사용된다. RLHF 방법들과는 다르게 Best-of-N은 LLM의 weight를 파인튜닝하지 않고 대신 inference procedure를 수정한다. inference-time에 reference model에서 여러개를 샘플하고 그 중 보상이 가장 높은 generation을 선택한다. 이는 reward/KL trade-off를 볼 때 경험적으로 효율적인 것으로 나타났으며(참고) Pareto-optimality 측면에서 theoretical gurantees(참고)가 제공된다. 하지만 N개의 generation을 샘플하는 것은 단일 샘플하는 것보다 cost가 배로 들기 때문에 선형적으로 증가하는 높은 추론 비용이 발생한다.

 

 

 

 BOND

 

 

 

본 논문에서는 BOND $($Best-of-N Distillation$)$ RLHF 알고리즘을 제안한다. 이는 Best-of-N 샘플링의 강력한 성능은 달성하되 기존 N개의 샘플링으로 소요되는 cost를 줄여 inference time에도 single sample만 필요로 한다. policy alignment를 distribution matching 문제로 바꾸고 Best-of-N 분포를 emulate하기 위한 policy로 파인튜닝 할 수 있다. 이를 위해 먼저 Best-of-N 분포에 대한 analytical expression을 도출한다. 이를 통해 다양한 divergence metrics를 고려하고 최적화한다.

 

첫번째로 Best-of-N 샘플링을 사용해 forward KL divergence를 최소화하는 방법을 보여주고 이는 mode covering behavior를 다루는 standard imitation learning으로 이어진다.

 

두번째로 backward KL을 최소화 하는 방법을 보여주는데, 이는 reward scale에 의존하지 않고 mode seeking behavior에 해당하는 새로운 형태의 quantile-based advantage로 이어진다.

 

마지막으로 Jeffreys divergence로 알려진 forward, backward KL을 함께 최소화하는 것을 제안한다. 이는 위 두가지 접근 방식의 장점을 모두 유지할 수 있다. 그리고 sample-complexity를 줄이면서 성능을 최적화 하기 위해 moving anchor policy의 Best-of-N을 반복적으로 distillation하는 iterative BOND approach를 소개한다. 마지막으로 이 제안한 아이디어들을 종합하여 J-BOND 라는 practical RLHF 알고리즘을 제안한다.

 

 

 The BOND Approach

 

 1. The Best-of-N distribution

 

Best-of-N sampling의 exact analytical distribution을 도출하고 이 분포의 속성에 대해 살펴본다. 

 

 

Theorem 1은 Best-of-N sampling의 직관적인 설명을 위한 것이다. 원래 샘플링 분포 $\pi_{ref}(y)$가 있을 때 나머지 N-1개의 샘플이 y 샘플보다 안 좋을 확률에 대한 텀 (A), (B)를 곱해줌으로써 reweight하는 효과를 준 것이다. 

 

 

 

먼저 $\pi_{ref}$에서 N개의 generations를 랜덤 샘플하여 $y_1, y_2, .. y_N$ 가 있고 $y$는 이 중 하나라고 생각하자.

$A_i(y)$는 $y$가 best sample일 사건$($event$)$이고 i는 $y_i=y$인 가장 작은 인덱스이다. 

 

2개를 샘플하여 $r(y_1) = 2, r(y_2) = 3$ 이라면 $A_2(y_2=y)$가 되는 것이고 $r(y_1) = 20, r(y_2) = 3$ 이라면 $A_1(y_1=y)$가 되는 것이다. 5개를 샘플 했을 때 $r(y_1)=2, r(y_2)=3, r(y_3)=5, r(y_4)=5, r(y_5)=5$ 라면, $A_3(y_3=3)$이다.

 

i가 아닌 index j에 대해 i번째 샘플이 best sample이라면 j번째 샘플이 동시에 best sample일 순 없다. 따라서 ${A_i(y)} \ i=1, 2, .. N$ 각 사건은 disjoint이다.

 

event $A_i(y)$는

1. $r(y_j) < r(y), j < i $

2. $y_i=y $

3. $r(y_j) <= r(y), j >= i$

3가지 조건에 필요충분 조건이다.

 

그러므로 event $A_i(y)$의 likelihood는 위 수식처럼 표현할 수 있다. i보다 작은 index를 가진 sample들의 보상이 $r(y_i)$보다 작고, i보다 큰 index를 가진 sample들의 보상이 $r(y_i)$보다 작거나 같은 확률들은 서로 독립이므로 개별 확률들의 곱으로 표현하고 y자체의 prob을 곱한 형태로 나타내었다.

 

 

 

그리고 이 event $A_i(y)$의 likelihood의 Union은 Best-of-N sampling에 선택되는 y에 해당한다. 3 -> 4번째 줄에서 $p_{<=}(y)^{N-1}$로 곱하고 나눠준 뒤 summation과 관련 없는 term을 앞으로 빼면 위와 같은 Best-of-N sampling으로 y가 선택될 likelihood 식이 도출된다. 

 

 

 

(A)는 동일 프롬프트에 대해 y보다 나쁘거나 같은 generation의 비율을 기반으로 하는 penalty exponential에 해당한다. 직관적으로 이는 N을 늘릴 때 bad generation에서 샘플링하는 횟수가 exponentially 줄어드는 것을 보장한다.

 

(B)는 generation간 충돌 가능성에 따른 추가 correction factor 이다. 이는 항상 [1, N] 사이로 제한되기 때문에 기껏해야 선형적이라는 것이다. 이는 정의에 따라 worst generation y-에 대해 최솟값인 1이 된다. 정확히 연속적으로 N번 샘플링 해야하고 $\pi_{BoN}(y\text{-})=\pi_{ref}(y\text{-})^N$가 된다. 반면에 개별 y의 likelihood가 낮고, 이러한 y가 좋다면 $p_{<}(y)$는 거의 $p_{<=}(y)$가 되어 (B)는 N에 가까워진다. 직관적으로 이는 한 generation을 여러 번 샘플링할 가능성이 거의 없는 경우에 해당한다. 

 

 

 2. The BOND objective

 

저자들은 Best-of-N distribution의 analytical 특성을 통해 BOND를 distribution-matching 문제로 formulation 하기 위한 objective를 아래와 같이 제안한다.

 

 

 

$D(.||.)$은 $\pi_{BoN}$ 쪽으로 training policy $\pi$를 조정하는 divergence metric이다. online, offline sample을 통해 D를 추정할 수 있고 이를 위한 적절한 divergence와 그에 따른 BOND 알고리즘을 선택하는 것은 section 4에 있다.

 

 

 3. Connection with standard RLHF

 

DPO에서 한 것처럼 BOND도 제안한 objective $($6$)$에 대해 standard RLHF 수식과의 연결성을 보였다. 

 

 

먼저 잘 알려진 policy maximizing RLHF objective는 $($7$)$과 같다. referece policy와 exponential $\beta$로 scale된 reward에 비례하는 policy이다. Theorem 1의 $\pi_{BoN}$ 수식에서 다음과 같은 특정 BOND reward를 사용할 때 Best-of-N 샘플링 분포가 standard RLHF optimal solution과 일치함을 알 수 있다. $($7$)$의 $r(y)$에 $\pi_{BoN}$의 reweight term인 $($A$)$와 $($B$)$를 사용하면 RLHF objective와 같다고 볼 수 있는 것이다. 그리고 $\beta$는 $\frac{1}{N-1}$이 된다. correction factor인 term $($B$)$는 $[0, \log \frac{N}{N-1}]$로 바운드 되어있지만 (A)는 (-inf, 0]이다. 

 

이는 Best-of-N sampling의 2가지 interesting insights를 제공한다. 

 

1. Best-of-N 샘플링은 N의 선택에 따라 KL regularization 수준이 결정되는 standard KL-regularized RLHF 문제의 솔루션에 해당한다.

 

2. Best-of-N 샘플링은 예상되는 log reward quantile $($즉 reference 분포로부터 랜덤 생성된 샘플보다 더 큰 보상을 받을 log likelihood$)$을 최적화 하는 것과 같다. 흥미롭게도 log의 concavity 때문에 $r_{BOND}(y)$는 모델의 good ones를 생성하는 것을 장려하는 것 보다 bad generations를 강력히 피하게 한다. 또한 $r_{BOND}(y)$는 generations간 순위에만 의존하기 때문에 보상 $r(.)$의 invariant to monotone transformations에 불과하다. 저자들은 이 두 가지 기능이 $r_{BOND}(y)$가 standard RLHF에 비해 reward hacking에 더 robust하게 만든다고 추측한다.

 

이러한 RLHF와 BOND의 connection은 또한 해당 논문에서 제안된 방법에 대해 영감을 준다.우리가 BOND reward 또는 동일하게 Best-of-N distribution $\pi_{BoN}$을 계산할 수 있다면 distribution matching을 통해 Best-of-N 방향으로 policy를 조정할 수 있다는 것을 의미한다. 따라서 4 section에서는 이러한 과제를 해결하기 위한 다양한 알고리즘에 대해 살펴본다.

 

 

 

 BOND Challenges and Algorithms

 

앞에서 살펴본 BOND reward, $\pi_{BoN}$을 구현하기 위해서는 다음과 같은 3가지 과제가 발생한다.

 

1. how to estimate the reward quantiles

2. which is the appropriate divergence metric touse

3. how to choose the hyperparameter $N$

 

 

 1. Monte-Carlo quantile estimation

 

 

$\pi_{BoN}$ 분포를 추정하기 어려운 이유 중 하나는 generation $y$에 대한 quantile $p_{<=}(y)$을 추정해야 하기 때문이다. quantile $p_{<=}(y)$는 같은 프롬프트 $x$가 주어질 때 $\pi_{ref}$에서 생성된 샘플들과 비교해 y의 품질을 측정한다 즉, y가 얼마나 높은 순위를 가졌는지 나타내는 지표다. $($notation에서는 $x$가 모두 편의상 생략되었다.$)$

 

 

 

매우 간단하지만 효과적인 quantile estimation 방법 중 하나는 Monte-Carlo sampling이다. k개의 샘플을 생성하여 위와 같이 empirical estimate를 할 수 있다. Monte carlo sampling은 분포나 기댓값, 분산 등을 추정할 때 해를 구하는 것이 아니고 랜덤 샘플로 근사하는 방법이다. $($10$)$ 수식을 보면 여기서 각 sample에 대해 $r(y') <= r(y)$인지 확인하고 이를 만족하는 샘플의 수를 세고$($만족하면 1 그렇지 않으면 0$)$ 그 수를 전체 샘플의 k로 나눈 것이다.

 

그리고 저자들은 제한된 수의 샘플을 사용하여 해당 방법이 실험에서 매우 효과적이라는 것을 관찰했다. $($Appendix B.1에 자세히 있음$)$ 그러나 원칙적으로는 학습된 quantile model과 같은 대체 접근 방식을 사용할 수도 있다. 

 

 

 2. Jeffreys divergence as a robust objective

 

그리고 BOND에 사용될 divergence metric을 선택하는 것은 매우 중요한 문제이다. 서로 다른 divergences는 매우 다른 방향으로 policy를 조정할 수 있기 때문이다. BOND에서는 robust 한 distribution matching objective로 Jeffreys divergence를 제안한다.

 

 

 

$($generalized$)$ Jeffreys divergence는 forward/backward KL divergence의 beta weighted average이다.

policy $p$를 파인튜닝할 때 forward $KL(q || p)$는 $q$에 있을 가능성이 높은 generations 또한 $p$에 있을 가능성이 높다고 장려하므로 mode-covering behavior를 장려하게 된다. 

반면, reverse $KL(p || q)$는 q에 high likelihood를 가지는 generation을 생성하도록 policy $p$를 조정하는 mode-seeking 효과를 준다고 알려져 있다. (참고)

forward KL이 over-spread 분포를 생성하게 될 수 있는 반면 backward KL은 policy와 entropy collapses를 초래할 수 있다. 저자들은 이 두 분포의 장점을 모두 가져가는 Jeffreys divergence를 사용하는 것이 더 좋은 policy를 낸다는 것을 empirically 보여준다.

 

BOND의 맥락에서 이는 다음과 같이 reference policy와 training policy의 샘플을 사용해여 추정할 수 있는 Jeffreys divergence를 최소화하는 것으로 해석된다.

 

 

 

먼저 forward KL은 $\pi_{BoN}$ $($N번 샘플하고 best one selecting$)$로 생성된 샘플로 직접적으로 추정할 수 있고 Best-of-N samples에 대한 SFT loss로 볼 수 있다.

 

 

 

backward KL의 경우 policy samples $($즉 $\pi에 대한 기댓값$)$과 $\pi_{BoN}$의 log-likelihood로부터 추정될 수 있다. 그리고 이 gradient는 policy gradient $($REINFORCE$)$와 일치한다는 것을 아래 Appendix에서 보여준다.

 

 

 

$\pi$와 $\pi_{BoN}$의 KL gradient의 expectation을 $\pi$에서 샘플된 y로부터 구하기 위해 expectation을 summation으로 바꾸고 $($2번째 줄$)$, $\pi(y)$와 $(\log \pi(y) - \log \pi_{BoN} (y))$ 두 부분에 대한 미분으로 나눈다$($3번째 줄$)$.

 

그리고 product rule을 적용해 4번째 줄과 같이 전개한다.

product rule 참고

 

다시 Expectation 형태로 바꾼뒤 $($5번째 줄$)$ 맨 마지막 term은 확률 분포의 기대값 정의에 따라 0이므로 $( \sum_y \pi(y) = 1, \nabla_\pi 1 = 0)$ 생략하면 맨 아래와 같은 수식으로 표현할 수 있다.

 

REINFORCE Policy Gradient

 

예상한 대로 위 KL gradient를 descending하는 것은 $r = r_{BOND}, \beta_{RL} = \beta_{BOND}$인 Equation 1의 RL objective에 대해 일정 스케일까지 RL policy gradient REINFORCE 알고리즘을 실행하는 것과 동일하다는 것을 확인할 수 있다. $R(y)$를 $\log\pi(y) - \log\pi_{BoN}(y)$로 보면 된다.

 

 

 

그리고 $\pi_{BoN}$ expression을 사용해 위 gradient를 위와 같이 분해할 수 있다.

 

 

 

위에서 구한 $\pi_{BoN}$을 gradient 전개식에 대입하여 2번째 줄로 표현하고 $r_{BOND}$와 $\beta_{BOND}$로 표현하면 $($15$)$식과 같이 전개할 수 있다. 이렇게 되면 REINFORCE에서 $R(y)$가 reward가 $r_{BOND}$와 regularization $\beta_{BOND}$와 같아진다. $r_{BOND}$는 unknown true quantile $p_{<=}(y)$과 correction factor $($B$)$에 의존한다. 실제 구현은 quantile을 estimation으로 대체 했지만 correction factor는 크게 중요한 역할을 하지 않는 것으로 관찰됐다. 그리고 variance를 줄이기 위해 배치의 generations에 대한 average return을 계산한 policy gradient baseline을 사용했다.

 

또한, 제안된 (11)식의 loss $J_{effreys}^\beta$는 SFT와 policy gradient loss의 linear weighted combination이다.

 

 

위에서 쭉 언급된 저자들의 intuition을 확인한 실험이다. summarization XSum task에서 T5모델을 사용하여 제안한 loss를 적용하고 beta를 0, 0.5, 1을 적용하였다. 각 프롬프트마다 quantile을 estimate하기 위해 16 MC samples을 사용하였고 500 training steps마다 eval하였는데, 이때는 32 MC samples를 사용하여 backward, forward KL을 estimate하였다. 위 결과는 N=8로 설정한 결과이고 4와 16의 결과는 Appendix B.2에 있다. 

 

Jeffreys divergence $($beta=0.5$)$를 사용하면 backward $($beta=1$)$ 또는 forward $($beta=0$)$ KL 만 최소화 하는 경우보다 $\pi_{BoN}$ $($왼쪽과 가운데 plot$)$에서 두 divergence를 모두 최소화 할 수 있음을 실험적으로 확인했다.

 

가장 오른쪽 플롯은 eval batch에서 평균한 training policy의 reward log quantiles를 나타낸 것이다. 흥미롭게도 BOND beta=0.5는 mode-seeking$($reverse KL$)$ beta=1 과 유사하게 quantiles를 maximize하는 반면 mode-covering$($forward KL$)$ beta=0은 뒤쳐지는 것을 볼 수 있다.

 

 


 

 

 

내용이 너무 길어져서 (2)편으로 나눕니다.

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.