새소식

자연어 NLP

[RLHF] ORPO: Monolithic Preference Optimization without Reference Model

  • -

ORPO: Monolithic Preference Optimization without Reference Model

 

 

link : https://arxiv.org/pdf/2403.07691

 

 

 

 Abstract

 

최근 Preference alignment 알고리즘이 좋은 결과를 보이는 동안, SFT (supervised fine-tuning) 과정이 convergence를 위해 꼭 필요한 과정인가는 논의되지 않았다. 본 논문에서는 preference alignment 맥락에서 SFT의 중요한 역할에 대해 연구하고, 선호하지 않는 생성 스타일을 위한 minor penalty가 preference-aligned SFT에서 충분하다는 점을 강조한다. 이를 위해 간단하면서도 혁신적인 reference model-free monolithic odds ratio preference optimization, ORPO 알고리즘을 제안함으로써 불필요한 추가적인 preference alignment phrase를 제거한다. 또한 경험적으로나 이론적으로 odds ratio가 다양한 사이즈의 SFT에서 선호하는 스타일과 비선호하는 스타일을 대조하기 위한 합리적인 선택임을 보인다.

 

 

 Introduction

 

다량의 데이터로 학습된 사전 모델을 특정 task에 잘 동작하게 하려면 Instruction-tuning과 같은 downstream task tuning을 하지만 이 과정은 harmful or unethical output을 생성할 수 있다. 이를 위해 사람의 선호도를 반영하는 RLHF, DPO와 같은 방법을 사용하며 몇몇 downstream tasks에서 모델이 좋지 않은 문장을 생성하는 것을 방지하는데 성공적인 결과를 보여주었다. 그러나 현재 존재하는 preference alignment methods는 주로 multi-stage 프로세스가 필요하다. 기존 RLHF에서 필요했단 reward model training 프로세스는 DPO에서 제거함으로써 process를 1단계 줄이는데 성공하였지만 여전히 SFT process가 warm-up phrase로써 human-alignment 학습 전에 필요하다. 

해당 논문에서는 pairwise preference 데이터셋으로 SFT를 하는 역할과 그 효과를 연구하고, 간단하면서 혁신적인 monolithic alignment method, odds ratio preference optimization (ORPO)를 소개한다. 이는 undesired generation style에 대해 패널티를 주게 되며 SFT과정이 필요 없다.

 

 

 Related Works

Alignment with Reinforcement Learning

 

RLHF (RL with human feedback)은 일반적으로 두 독립적인 evaluated instance 사이의 pairwise competition의 확률을 추정하기 위해  Bradley-Terry 모델을 사용한다. 그리고 이를 이용해 instances를 scoring하기 위해 Reward model을 학습한다. 학습된 리워드 모델을 통해 Proximal policy optimization (PPO)와 같은 RL tuning과정을 진행하며 사람의 선호 점수를 최대화 하는 답변(리워드 모델의 예측 점수가 큰)을 생성할 수 있도록 LM을 aligning 하게 된다. 

그러나 이 방법은 PPO알고리즘의 불안정성으로 야기되는 extensive hyperparameter searching가 필요하다는 문제가 있고, 리워드 모델의 성능에 크게 영향을 받는다는 문제 등이 있다. 이는 안정적인 preference alignment algorithm의 필요성을 야기한다.

 

‣ Alginment without Reward Model

 

기존 RLHF 방법에서 reward model 피팅 과정을 통합한 DPO 알고리즘과 DPO의 오버피팅 문제를 완화하는 IPO 알고리즘, pairwise preference dataset이 필요없는 KTO, ULMA 등 Reward model 없이 human preference alignment methods이 많이 제안되고 있다. 

 

 Alignment with Supervised Fine-tuning

 

위에서 언급된 Preference alignment 방법들은 SFT과정을 통해 Reference model를 얻고 이 Reference model과 현재 업데이트하고 있는 모델의 policy가 크게 차이 나지 않도록 제한함으로써 desired result로 모델이 converge될 수 있게 한다.

반면 필터링 된 데이터셋을 통해 SFT만을 수행해 human-aligning을 가능하게 하는 몇몇 접근법들도 존재한다. 이 방법들은 소량의 fine-grained filtering, curation을 통해 SFT 과정을 수행함으로써 이 과정만으로 helpful LM assistant를 만들기에 충분하다는 것을 보여주었다. 이러한 연구는 alignment 맥락에서 SFT의 중요성은 잘 보여주지만, SFT에 preference alignment를 통합하기에는 아직 부족해 보인다.

 

 

 The Role of Supervised Fine-tuning



저자들은 SFT의 loss function에 대해 분석하고, SFT모델의 선호 이해 능력에 관한 실험을 통해 prefererence alginment의 initial stage로써의 SFT의 동작에 관해 연구하였다.

 

SFT는 적절한 토큰의 log prob을 증가시킴으로써 사전 학습 모델을 desired domain에 맞추기 위해 중요한 역할을 한다. 하지만 figure 3에서 볼 수 있듯 의도치 않게 원치 않은 스타일의 토큰의 확률도 증가시켜버릴 수 있다. 따라서 원치 않는 스타일을 생성하는 것을 방지하고 식별하는 동시에 domain adaptation의 역할도 보존할 수 있도록 하는 방법이 필요하다.

 

 Absence of Penalty in Cross-Entropy Loss

 

 

Cross entropy loss의 목적은 모델이 reference answer에 대한 예측 로짓이 낮은 경우 패널라이즈 하는 것이다. y_i는 i번째 토큰이 레이블 토큰인지 나타내는 boolean 값이고 p_i는 i번째 토큰의 확률이다. cross-entropy를 단독으로 사용하게되면 y_i는 0으로 설정되므로 non-answer token에 대한 로짓을 위한 패널티나 보상을 직접적으로 줄 수 있는 방법은 없다.

이 cross entropy는 domain adaptation에 효과적이지만, 선호된 응답에 대해 보상할 때 선호되지 않은 답변을 패널라이즈 하는 매커니즘은 없다. RLHF의 목적은 단순히 사용자에게 선호되도록 답변을 생성하는 것 뿐 아니라 모델이 유해하거나 옳지 않은 답변을 생성하지 못하도록 하는 것도 중요하다. 따라서 선호되지 않은 응답에 대한 토큰의 log prob도 선호 응답에 따라 증가하게 되며 이는 preference alignment 관점에서 적합하지 않다.

 

 Generalization over Both Response Styles

 

저자들은 SFT단독으로는 선호, 비선호 응답에 대한 잘못 조정할 수 있다는 것을 실험적으로 보였다. OPT-350M 모델에 HH-RLHF 데이터셋에서 선호된 답변만을 가지고 fine-tuning하였을 때 figure 3과 같이 선호 비선호 확률 모두 증가되는 것을 확인할 수 있었다고 한다.

이는 두 가지 다른 관점으로 해석될 수 있다. 먼저 cross-entropy loss는 모델을 효과적으로 intended domain으로 가이드한다고 볼 수 있다. 그러나 원하지 않은 생성 결과에 대한 패널티 없이는 비선호 응답이 선호 응답보다 더 높게 되는 결과를 가끔 초래할 수 있다.

 

 Penalizing Undesired Generations

 

이전 연구들에서 unlikelihood 패널티가 추가된 loss로 unwanted degenerative traits 를 줄이는데 성공했다. 이에 영감을 얻어 해당 연구에서는 동적으로 원하지 않은 응답에 대해 패널라이즈 할 수 있으면서 각 쿼리에서 rejected token 셋을 만들 필요 없는 monolithic preference alignment method를 제안한다.

 

 

 

 

 Odds Ratio Preference Optimization (ORPO)



ORPO는 선호, 비선호 응답을 구분하기 위해 negative log-likelihood (NLL)에 odds ratio-based penalty를 통합한 것이다. 

 

 

입력 시퀀스 x에 대해 길이가 m인 출력 시퀀스 y를 생성하는 평균 log likelihood가 있을때,

 

 

입력 시퀀스 x가 주어질 때 y 시퀀스를 생성할 odds는 위와 같이 정의된다. 직관적으로 odds = k 라는 것은 모델이 y를 생성하지 않는 것 보다 생성할 가능성이 k배 높다는 것을 의미한다.

 

 

그리고 chosen 응답 y_w와 rejected 응답 y_l에 대한 odds의 비율은 OR로 정의하며 위와 같다. 이는 모델이 y_l보다 y_w를 생성할 가능성이 있는지를 나타낸다.

 

 

최종적으로 ORPO objective는 이 Odds Ratio (OR)에 관한 텀과 SFT에 관한 텀으로 이루어진다. L_SFT는 conventional causal language modeling negative log-likelihood (NLL) loss와 동일하며 refernece tokens를 생성하는 likelihood를 최대화한다. L_OR은 y_l과 y_w를 생성하는 likelihood 사이의 odds ratio를 최대화하는 효과를 준다. 그리고 y_w과 y_l의 log odds ratio를 증가시킴으로써 L_OR이 감소될 수 있도록 log odds ratio를 log sigmoid로 감싸준다. 그리고 desired domain adaption에 초점을 둘지, disfavor generation을 방지하는 것에 초점을 둘지는 파라미터 gamma로 조정한다.

 

 Gradient of ORPO

 

 

 

L_OR의 gradient는 odds ratio loss를 사용하는 것을 더욱 정당화 할 수 있다. 이는 두 가지 텀으로 구성되는데 하나는 잘못된 예측에 불이익을 주는 텀이고 다른 하나는 선택된 응답과 거부된 응답을 대조하는 텀이다. 여기서 d는 데이터 한 세트를 의미하고 입력 시퀀스 x, chosen response y_l, rejected response y_w 이다.

선호하는 응답의 확률이 비선호 응답보다 상대적으로 높을 경우 eq (9)는 0으로 수렴된다. 이는 L_OR의 gradient의 앞 텀(9)이 패널티 항의 역할을 하고 만약 모델이 rejected response를 생성할 확률이 더 높다면 parameter 업데이트를 가속화하는 것을 나타낸다. 

반면, eq(10)의 h(d)는 선호 비선호 응답간 gradient의 weighted contrast를 나타낸다. 특히 1 - P(y|x)는 대응하는 P(y|x)가 낮을 때 기울기를 증폭한다. chosen responses의 경우 likelihood가 증가함에 따라 chosen 응답의 분포에 대한 모델의 adaptaion이 가속화된다.

 

 

 Comparison to Probability Ratio

 

 

 

단순 확률 ratio(PR)을 사용하지 않고 odds ratio(OR)를 사용한 이유는 odds ratio의 안정성 때문이라고 한다. PR에 비해 OR은 선호하지 않는 응답에 대해 더 극단적으로 차별할 수 있는 효과를 줄 수 있다. OR은 PR에 (1-P(y_l))/(1-P(y_w))이 곱해진 값이므로 y_w과 y_l의 확률 차가 더 크다면 더 큰 값을 가지기 때문이다.

PR과 OR에 log sigmoid가 적용된다는 점에서 각 ratio의 스케일은 선호/비선호 likelihood 사이 예상 마진을 결정한다. 그런 의미에서 log simoid loss를 최소화하기 위해서는 PR을 사용할 때 더 극단적으로 차이가 나야 한다. 따라서 PR을 사용하게 된다면 모델은 더 과도하게 비선호 응답에 대한 로짓을 줄이게 되고, 비선호 응답을 너무 억제한다면 domain adaptation 능력을 잃어버릴 우려가 있다.

 

 


 

 Review

 

DPO의 선호/비선호 응답에 대한 확률 마진이 너무 크게 차이가 나게 되는 문제를 IPO에서는 overfitting 문제라고 보고 ORPO는 domain adaptation이 퇴보되는 관점에서 보고 있다. 그리고 특히 ORPO는 기존 DPO계열 모델들이 SFT로 어느정도 파인튜닝 된 모델에 Preference-alignment하는 과정을 하나로 통합함으로써 RLHF 단계 중 하나의 단계를 더 줄였다.

SFT를 할 때 비선호 응답에 대한 토큰의 로짓도 올라가는 것을 발견하고, 이를 간단하면서 창의적인 아이디어로 해결했다는 점이 놀라웠다. 그리고 특히 실험에서 단순 LLM-based evaluation 결과나 벤치마크 결과 뿐 아니라 어휘의 다양성 관련한 실험 등이 잘 짜여져 있었다. odds ratio를 도입하여 log simoid가 적용되기 전 선호/비선호 차를 크게 반영함으로써 두 확률차가 극단적으로 멀어지는 것을 방지할 수 있다. 또한 OR 텀을 통해 기존 SFT에서 비선호 라벨의 로짓을 줄이는 메커니즘이 없다는 점도 해결하면서, 기존 RLHF 프로세스를 우회해 단 1단계만으로 언어 모델의 선호 정렬을 가능하게 한다. 최근 DPO 계열 논문들의 baseline에 ORPO가 많이 추가되고 있고 여기저기에 언급이 많은 이유를 알 수 있었다.

 

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.