새소식

자연어 NLP

[RLHF] BOND (2) : Aligning LLMs with Best-of-N Distillation

  • -


BOND: Aligning LLMs with Best-of-N Distillation

 

link : https://arxiv.org/abs/2407.14622

 

(1)편

 

[RLHF] BOND: Aligning LLMs with Best-of-N Distillation (1)

BOND: Aligning LLMs with Best-of-N Distillation link : https://arxiv.org/abs/2407.14622    구글 딥마인드에서 새로운 RLHF method, J-BOND에 대해 소개한 논문이다. 해당 방법은 Gemma1.1 모델 학습시 사용되어 reward/KL tra

ebbnflow.tistory.com

 

 

 

‣ 3. Iterative BOND

 

최종적으로 parameter N을 어떻게 선택할 것인지에 대해 논의한다. 실제로 N을 결정하는 것은 아래 3가지 이유로 어려운 일이다.

 

(1) standard RLHF에서 N은 regularization 역할을 하므로 큰 N은 downstream 성능을 향상한다. 하지만 N이 너무 크다면 궁극적으로 이는 reward over optimzation을 초래할 수 있다.

(2) 큰 N일 수록 $\pi_{BoN}$의 추정치는 추정 quantile의 오류에 더 민감하다. ( $ \because \pi_{BoN}(y) \propto p_{\le}(y)^{N-1}$)

(3) forward KL을 추정하는 것은 $\pi_{BoN}$으로부터의 샘플링이 필요하다. 하지만 이는 큰 N에 대해서는 어렵다.

 

 

위와 같은 문제점을 해결하기 위해, iterative BOND라는 방법을 제안한다. 이 접근법은 Best-of-N 분포로부터 Best-of-N 샘플링을 하는 것은 원래의 분포로부터 Best-of-N 샘플링을 제곱번하는 것과 같다는데서 착안했다. $BoN(\cdot)$을 base distribution으로부터 Best-of-N 샘플링을 수행하는 operation으로 정의한다:

 

 

 

이는 iterative BOND의 핵심 아이디어를 제안한다: 만약 우리가 어떻게 Best-of-N 분포를 distill 해야하는지 안다면, 우리는 BOND를 재귀적으로 적용할 수 있다. 이는 초기 분포 $\pi_{ref}$ 의 Best-of-N$^M$을 distilling 하는 것과 같다.

 

 

이 개선된 operator는 n을 2와 같은 작은 사이즈로 고정하고 BOND를 반복적으로 수행할 수 있게 한다. 이를 위해 auxiliary anchor policy $\pi_{anchor}$를 소개한다. 이는 초기 분포인 $\pi_{ref}$가 업데이트된 분포이다. $\pi_{anchor}$에 대해 BOND를 수행할 수 있게 되고 이는 $\pi_{anchor}$의 Best-of-n version 을 distill할 수 있다는 의미다. 주어진 distillation 단계 후 $\pi_{anchor}$를 현재 policy $\pi_{t}$로 업데이트한다.

 

간단히 말해, iterative BOND는 sample complexity를 줄이고 안정적인 optimization을 유지하면서 임의의 큰 N을 exponetial scaling한다.

 

 

 

이에 대한 실험에 관한 Figure이다. Figure2에 있는 실험과 같은 셋팅을 하고 moving anchor 단계만 매 1000steps마다 시행한 점이 다르다. 왼쪽 플롯은 average reward와 중간 플롯은 average log quantile 을 time steps마다 어떻게 바뀌는지 나타내고 있다. iterative 방법을 적용하지 않은 non-iterative BOND의 경우 두 reward signal이 모두 일찍 saturate되는 반면 (특히 N이 작을 수록 이러한 현상이 심화됨), iterative BOND의 경우 지속적으로 성능을 향상시키는 것을 볼 수 있다 (N이 클 수록 더 빠르게 성능이 향상됨). 가장 오른쪽 플롯은 $\pi_{ref}$와 $\pi$에 대한 KL에 따라 log quantiles가 어떻게 변하는지 나타낸 것이다. 이 플롯은 iterative BOND가 non-iterative와 같은 reward/KL trade-off를 가지고 있지만, 더 작은 n을 사용해 점진적으로 $\pi_{ref}$와 멀어질 수 있음을 보여준다.

 

 

 

✲ The J-BOND Algorithm

 

최종적으로 BOND의 practical algorithm을 보여준다. 

 

 

알고리즘의 수도코드는 위와 같다. J-BOND는 Algorithm1의 iterative BOND 템플릿을 따르고 있으며 n=2로 설정하였다. 따라서 $\pi_{ref}$로 초기화된 moving anchor $\pi_{anchor}^t$의 Best-of-2 버전을 iteratively distill하여 policy $\pi_{t}$를 파인튜닝한다. J가 붙은 이유는 Jeffrey divergence를 distribution matching objective에 사용했기 때문이다. (i.e., it minimizeds $J^{\beta}_{effreys} (\pi || \text{Best-of-2}(\pi^{t}_{anchor}))$)

 

 

 

J-BOND의 main components의 설명은 아래와 같다.

 

Minimal sample complexity

이전 섹션에서의 BOND 알고리즘과 비교해 J-BOND는 minimal sample complexity를 가지고 있다 : 배치의 각 프롬프트에서 $\pi_{t}$에서 1개의 샘플 $\pi^{t}_{anchor}$에서 2개의 샘플을 샘플한다. 일반적으로 더 많은 anchor sample이 divergence estimation에 더 유용하지만, autoregressive sampling은 online RLHF의 병목의 주 요인이다. 따라서 작은 샘플을 사용하는 실용적은 접근 방식을 선택하였다.

 

Crude divergence estimate based on 2 anchor samples

policy, anchor 샘플은 $J^{\beta}_{effreys} (\pi || \text{Best-of-2}(\pi^{t}_{anchor}))$에서 forward, backward KL을 crude estimate하기 위해 사용된다.

우리는 2개의 best anchor samples에 대해 SFT를 수행함으로써 forward KL을 minimze할 수 있으며,
Equation (15)의 $r_{BOND}(y)$를 $r_{J-BOND}(y)$로 대치한 policy gradient-style loss를 사용해 backward KL을 minimize할 수 있다. 그 이유는 오직 2개의 anchor samples만 이용할 수 있는 경우, $r_{BOND}(y) = \log \hat{p}_{\le}(y)$는 $\hat{p}_{\le}(y)$가 매우 noisy한 MC estimate이기 때문에 꽤나 uninformative할 수 있기 때문이다. $y$를 policy sample로,{ $y_{1}', y_{2}'$ }
를 anchor samples로 두고 대신 $r_{J-BOND}(y)$를 아래와 같이 정의한다.

 

 

즉 generation $y$는 2 anchor samples보다 더 나쁜 보상을 받으면 $- \log (16)$의 음수 보상을 받고 그렇지 않으면 0을 받는다. 

 

True reward function $r_{BOND}(y) = \log p_{\le}(\cdot)$ 에서 $p_{\le}(\cdot)$ 을 알기 위해서는 $\pi^{t}_{anchor}$의 reward distribution을 알아야하기 때문에 구할 수 없다. 따라서 이를 approximation하고 이 값을 $r_{J-BOND}(\cdot)$라 하였다. 이는 식 (17)과 같이 디자인하였다. 두 앵커샘플보다 나쁜 경우가 아닌 중간 case에 대해 보상을 부여했을때에 대한 이익을 관찰하지 못했기 때문에,

 

(1) 우리는 이상적인 unknown reward function $r_{BOND}=\log p_{\le}(\cdot)$의 concavity를 모방하기 위해 두 anchor samples보다 나쁜 경우에만 부정적인 보상을 부여한다.

 

그리고,

(2) $- \log (16)$라는 값을 설정한 이유는 : $p_{le}(y) = 0.5$ 일때 $\mathbb{E}_{y_1', y_2' \sim \pi^{t}_{anchor}} [r_{J-BOND}(y)] = \log p_{\le}(y)$ 가 되게 하기 위함이다.  $- \log (16)$는 $r_{J-BOND}(\cdot)$을 보정하여 2개의 앵커샘플에 대한 expectation에서 median reward를 받는 generation y에 대한 ideal reward $\log p_{\le}(\cdot)$와 일치한다. (i.ei, when $p_{\le}=0.5$

 

*위 내용에 대한 Derivation(Appendix A.4)

 

anchor rewards distribution과 비교했을 때 sample y가 median reward를 가지도록 하고자 했다 (i.e., $\p_{\le}(y) = 0.5$). 그러면 true reward $r_{BOND}(y) = \log p_{\le}(\cdot)=\log (0.5)$ 가 되어 $\r_{J-BOND}(y)$와 일치하게 된다. 이를 위한 $\r_{J-BOND}(y)$의 값을 $\alpha$로 파라미터화 한 함수와 이를 구하는 과정은 아래와 같다.

 

 

 

 

 

 

Exponential Moving Average (EMA) anchor

EMA를 통해 anchor policy를 업데이트 하는 것도 J-BOND의 중요한 컴포넌트 중 하나이다. 주기적으로 anchor를 업데이트하는 대신, anchor weight $\theta^{t}_{anchor}$를 각 파인튜닝 단계에서 policy weight $\theta^{t}$ 의 moving average로 업데이트한다.

 

 

 

이 weight average procedure는 variance를 줄이고 학습을 안정적으로 만드는 긍정적인 효과임을 실험적으로 관찰할 수 있었다. 또한 이는 전반적인 J-BOND의 reward/KL trade-off를 향상시켰다.

 

 

Additional KL regularization

policy를 moving anchor와 더 가깝게 머물게 하기 위해 추가적인 KL term을 사용했다. 이는 policy update를 더 안정적으로 만들고 conatrained optimization의 하나로써 overall opterator를 바라보게 한다.

 

 

 

 

 

✲ Experiments

 

실험은 J-BOND에서 중요한 측면들, EMA anchor, anchor speed 효과, additional KL term 이점을 확인하는 실험이 있다. 그리고 REINFORCE 와 비교해 J-BOND가 기존 RLHF 방법들에 비해 효율적이고 좋은 성능을 낸다는 것을 보여준다.

 

 

 

 

 

 

 


 

 

BOND는 Best-of-N sampling 분포를 online distillation 함으로써 모델을 파인튜닝하는 RLHF 알고리즘이다. Best-of-N sampling 기법을 inference 단계가 아닌 training 단계에 적용함으로써 inference cost를 없이 효율적으로 generation 성능을 높일 수 있다. training 단계에 Best-of-N distillation을 위해 Monte-Carlo quantile estimation 기법과 forward/backward KL의 장점만을 혼합하는 Jeffrey divergence를 objective에 도입하였다. 또한 N의 사이즈를 키울 때 발생하는 문제를 iterative procedure로 해결하고, EMA를 사용해 reference policy를 점진적으로 업데이트하여 practical한 알고리즘을 제시하였다. 기존 RLHF나 RL 파인튜닝 기법들은 reference policy를 업데이트하지 않기 때문에 학습의 불안정성을 야기하거나, pretraining 지식을 잊어버릴 수 있는 점, 리워드 해킹과 같은 문제가 있다. J-BOND는 이와 같은 단점들을 해결하면서도 practicality와 efficiency를 동시에 잡은 알고리즘이다.

 

 

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.