새소식

자연어 NLP

[RLHF] dDPO (Zephyr) : Direct Distillation of LM Alignment

  • -

ZEPHYR: DIRECT DISTILLATION OF LM ALIGNMENT

 

 

link : https://arxiv.org/pdf/2310.16944

 

 

 

✲ Abstract

 

해당 논문의 목적은 더 작은 언어 모델을 user intent에 맞게 align 하는 것이다. 이전 방법들은 distilled supervised fine-tuning (dSFT)를 사용해 큰 모델의 task accuracy를 향상시켰다. 하지만 이러한 방법은 "unaligned"이다. 즉, task prompt 한해 학습을 했기 때문에 natural prompts에는 잘 반응하지 않는다. 이러한 문제를 해결하는 distillation 방법을 제안하며 dDPO (distilled DPO) 라 한다.

 

 

 

먼저 teacher model에 의해 데이터셋의 outputs들은 rank되고, chat model에 distilled DPO를 적용해 학습하면 intent alignment 성능이 향상된다. 이 제안하는 방법은 파인 튜닝 동안 추가적인 샘플링 없이 몇 시간만에 학습이 가능하다. 또한 저자들은 AI가 생성한 Feedback (AIF) preference dataset을 사용했기 떄문에 human annotation이 필요 없다는 점도 강조한다. Mistral 7b에 dDPO 방법을 적용하여 학습하여 저자들이 공개한 모델은 Zephyr7b이며 MT bench에서 Llama2-chat-70b의 성능을 능가했다고 한다. 해당 논문은 허깅페이스 팀에서 공개한 Tech report인 만큼 코드와 모델, 데이터부터 튜토리얼까지 모두 친절하게 제공된다. https://github.com/huggingface/alignment-handbook

 

 

 

✲ Introduction

 

Smaller open LLM은 초기 GPT-2와 유사한 모델로부터 Chincilla scaling laws에서 제안한 "compute-optimal" 한 토큰의 양보다 훨씬 많은 토큰을 학습하는데 사용하였다. 이전 연구에서 이러한 모델은 distilled supervised fine-tuning (dSFT)를 통해 추가로 학습하면 SFT accuracy를 높일 수 있음을 보여주었다. 이 접근 방식에서는 more capable teacher model의 출력이 student model의 supervised data로 사용된다.

 

Distillation은 다양한 task에서 open model을 효과적으로 향상시킬 수 있다는 것이 입증되었지만, teacher model의 성능까지는 도달하지 못한다. 사용자들은 이러한 모델이 "intent aligned"되지 않았음을 지적했다. 즉 모델이 사람의 선호대로 행동하지 않는다는 것이다. 이는 쿼리에 대해 올바른 응답을 제공하지 않는 출력으로 이어지는 경우가 많다.

 

해당 논문에서는 distillation을 통해 small open LLM을 완전히 aligning하는 문제에 대해 다룬다. 메인 스텝은 teacher 모델의 앙상블로부터 얻은 AI Feedback (AIF)을 선호 데이터로 활용한다는 것이다. 그리고 dpo objective를 통해 distilled DPO를 적용한다. human annotation과 샘플링이 필요없다는 점에서 PPO와 다른 RLHF 접근법들과는 다르다고 강조한다. 또한 small LM을 활용함으로써 chat model 학습에 16개의 A100으로 몇 시간만에 학습할 수 있다.

 

이를 검증하기 위한 실험은 Mistral-7B를 사용하여 학습하고 align된 모델은 ZEPHYR-7B라고 한다. 먼저 UltraChat 데이터를 기반으로 dSFT를 사용하고, UltraFeedback 데이터에서 수집된 AI Feedback을 사용한다. 마지막으로 이 피드백을 기반으로 dDPO를 적용한다. 실험에 따르면 7B 모델은 사람의 피드백에 맞춰 조정된 70B 모델과 비슷한 성능을 달성할 수 있다. 이 실험은 standard academic benchmarks 뿐 아니라 conversational capabilities 또한 향상된다는 것을 보여준다. 그리고 preference learning이 이러한 결과를 달성하는데 중요하다는 것을 보여준다.

 

 

 

✲ Method

 

 

 

해당 논문의 목표는 오픈 소스 LLM을 사용자의 의도에 맞게 align하는 것이다. 작업 전반에 걸쳐 prompted generation을 통해 쿼리하는 더 큰 teacher model $\pi_{T}$ 에 대한 접근을 가정한다. 목표는 student model $\pi_{\theta}$를 생성하는 것이며, 이 과정은 위 Figure 2처럼 InstructGPT와 유사한 단계를 따른다.

 

 

‣ Step 1 - dSFT (distilled Supervised Fine-tuning)

 

Law LLM으로부터 시작하여 먼저 user prompts에 응답하도록 학습한다. 이 단계는 원래 고품질의 instruction, responses 데이터셋에 대한 SFT를 통해 수행되었다. 이 대신 teacher LLM이 주어지면 모델이 instruction과 responses를 생성하도록 하고, 이에 대해 모델을 직접 훈련할 수 있다. 이 과정을 distilled SFT라 한다.

 

이 dSFT 접근법은 self-instruction protocol (참고) 을 따른다. $x_1^0, ... , x_J^0$ 을 다양한 주제 영역으로 구성된 a set of seed prompts라 한다. 데이터셋은 teacher가 instruction에 응답하고, 응답에 따라 instruction을 개선하는 데 사용되는 반복적인 self-prompting을 통해 구성된다. 각 $x^0$로부터 먼저 응답을 샘플하고 $y^0 \sim \pi_{T}(\cdot|x^0)$, 수정을 위한 prompt를 사용해 새로운 instruction을 샘플링함으로써 수정한다 $x^1 \sim \pi_{T}(\cdot | x^0, y^0)$. end point는 마지막 데이터셋 $\mathcal{C}=\{ (x_1, y_1), ..., (x_J, y_J) \}$ 이며 Distillation은 SFT로 수행된다.

 

 

 

‣ Step 2 - AIF (AI Feedback through Preferences)

 

Human Feedback은 LLM aligning에 추가적인 시그널을 제공할 수 있다. Human feedback은 기본적으로 LLM 응답 품질에 대한 선호도를 통해 주어진다. distillation을 위해, 이 대신 다른 모델이 생성한 outputs에 대한 teacher model의 AI preferences를 사용했다. 이는 UltraFeedback 데이터의 approach를 따른 것이며 model outputs에 대한 preference를 얻기 위해 teacher를 사용하는 방법이다. 

4개의 모델이 있을 때 각각을 $\pi_1, ... , \pi_4$ (e.g. Claude, Falcon, Llama, ..) 이라 하고 각 모델이 생성하는 outputs를 $y^1 \sim \pi_1 (\cdot |x), ... y^4 \sim \pi_4 (\cdot | x)$ 라 한다. 그리고 GPT-4와 같은 모델을 teacher model $\pi_{T}$로 사용해 각 응답에 대한 scoring을 매긴다 $s^1 \sim \pi_{T}(\cdot | x, y^1), ... , s^4 \sim \pi_{T}(\cdot | x, y^4)$. 각 프롬프트 $x$마다 이러한 scores를 모은 후에 가장 높은 점수의 응답인 $y_w$를 저정하고 $y_l$ 은 나머지 응답 중에 랜덤으로 하나 정한다. 최종적으로 ($x, y_w, y_l$) 으로 구성된 feedback datasets를 구성한다.

 

 

‣ Step 3 - dDPO (distilled Direct Preference Optimization)

 

마지막 단계의 목적은 선호 모델에서 $y_l$보다 $y_w$에게 높은 순위를 매길 likelihood를 최대화하도록 $\pi_{dSFT}$를 수정하는 것이다. DPO이전 방법에서는 선호 모델은 student LLM $\pi_{\theta}$를 reward function $r_{\theta}(x,y)$에 의해 결정되었으며 PPO와 같은 RL 알고리즘으로 $\theta$를 optimize했다.

 

 

 

DPO는 더 간단한 접근법을 사용한다. key observation은 optimal LLM policy $\pi_*$와 original LLM policy $\pi_{dSFT}$의 측면에서 optimal reward function $r^*$을 도출한다는 것이다. 이를 preference model에 plugging함으로써 (1)과 같은 objective를 제안하였다.

 

 

그리고 위 3가지 training procedure를 통해 dSFT를 학습한다.

 

 

✲ Results

 

conversational capabilities

 

 

academic benchmark

 

 

 

 

위의 메인 결과 외 눈여겨 봤던 결과는 위 Figure 3이다. DPO는 쉽게 오버피팅된다는 단점이 있다. (참고)

Zephyr-7b 학습 과정에서 DPO 1에폭 학습 후에 모델이 크게 오버피팅 되는 것을 관찰했다. 하지만 이는 놀랍게도 MT-bench 및 AlpacaEval의 다운스트림 성능에는 해를 끼치지 않았다고 한다. figure 3에서 볼 수 있듯이 가장 강력한 모델은 SFT 1에폭 후 DPO 3에폭 학습하여 얻어졌다. 그러나 SFT 모델이 1 epoch 이상 학습된 경우 DPO 튜닝 단계는 실제로 더 긴 학습이 성능 회귀를 유도한다는 것이 관찰되었다.

 

 


 

 

human feedback을 얻기 어렵고, 대형 모델 튜닝이 어렵다는 점을 직관적인 아이디어로 해결한 논문이다. 대형 모델들의 지식을 작은 모델에 distillation하기 위해 데이터라는 연결고리를 사용했다. 대형 모델로부터 샘플된 피드백 데이터를 teacher model에게 scoring을 맡기고 이 데이터로부터 작은 모델에 SFT+DPO하여 대형 모델의 지식을 전이하겠다는 것이다. 데이터를 샘플할 대형 모델과 teacher model에 따라 성능이 많이 달라지려나? aligning하고 싶은 user intent에 대한 프롬프트가 많고 모델에 최적화할 수 있다면 적용해 볼 수 있을 것 같다.

 

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.