ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [RLHF] A General Theoretical Paradigm to Understand Learning from Human Preferences, IPO - (2)
    ✨ AI/AI papers 2024. 3. 25. 05:07

    Week Regularisation and Overfitting

     

     

     

     

    y와 y' 두 라벨이 있을 때, p*(y > y') = 1인 경우 즉 항상 y를 선호할 확률이 있다고 생각해보자. 그러면 BT 모델은 r(y)-r(y') 무한대로 가게 될 것이고 policy pi*에 이를 대입하면 pi*(y') / pi(y) = 0이 될 것이다 즉 pi*(y') = 0이 된다. 이렇게 된다면, KL regularisation을 위한 constant tau는 무시될 것이며 더 deterministic 한 preference를 모델링하게 되어 오버피팅이 일어나게 된다. KL 패널티가 약해지는 문제는 우리가 RLHF 모델링할 때 주로 사용하는 finite preference dataset을 사용할 때 더 두드러진다.

     

    따라서 DPO보다 Standard RLHF 알고리즘을 사용할 때 (Reward model + PPO) 이 문제에 더 robust하다고 볼 수 있는데, 그 이유는 아래와 같다. DPO의 장점이 reward function fitting을 피한다는 것이지만, 우리는 실제로 empirical preference prob은 0과 1사이에 있을 때 reward function은 결국 underfit하게된다. (DPO에서는 reward가 pi(y)/pi_ref(y)로 표현됨)

    이전 연구에서 0과 1사이의 preference prob이 있는 경우 optimal reward는 무한 값을 가질 수 있지만 이러한 값은 피하고 실제로는 RLHF에서 reward function의 regularisation이 중요한 것으로 관찰됐다. 즉, reward function underfitting은 reference policy에 대해 충분히 regularisation된 final policy를 얻는데 중요한 역할을 하고 DPO는 reward function 학습을 피하는 대신 underfitting reward function이 주는 policy regularisation의 이점을 잃는다.

     

    DPO의 overfitting 방지를 위해 early-stopping과 같은 regularisation 방법을 쓸 수 있지만 해당 논문에서는 Psi PO objective 수정버전을 소개하면서 이를 해결하고자 한다. 제안하는 Psi PO objective는 optimal empirical policy를 따르며 preference가 deterministic한 경우에도 reference policy와 가까울 수 있다.

     

     

    5. IPO : Psi PO with identity mapping

     

    위에서 살펴본 DPO의 overfitting 문제는 explicit reward function을 학습하지 않는 것과, unbounded Psi function의 combination에 기인한다. 그래서 Psi function이 bounded 되는 함수로 설정하고, 0-1 사이의 선호도 값을 가지는 데이터셋 이라고 하더라도 KL 텀의 효과가 남아있도록 하면 이를 막을 수 있다는 아이디어가 해당 논문이 제안하는 알고리즘이다. Psi를 Identity 매핑으로 취해 total preference의 직접 regularized optimization으로 이어짐으로써 제공된다.

     

    해당 수식을 최적화 하기 위한 일반적인 방법은 reward를 preference prob p*(y>mu)로 설정하고 RLHF 학습을 하는 것이다. 하지만 RL과 reward model r(y) estimating을 동시에 하는 것은 비용이 크다. 저자들은 DPO에서 영감을 받아 preference dataset이 있다면 8 식의 최적화 하기 위한 empirical solution을 고안했다. 

     

     

     

    Derivations and Computationally Efficient Algorithm

     

     

    저자들은 Derivation을 위해 DPO의 derivation으로부터 시작하고 (https://ebbnflow.tistory.com/382)
    optimal policy의 analytic expression을 Root-finding 문제로 조정한다. (Root-finding problem 으로 조정한다는 의미가 무엇인지는 잘 모르겠다..)

     

    DPO Appendix A.1에 있는 derivation에 따라 (9)번식이 성립하고, (10)에서 우항의 첫번째 텀을 왼쪽으로 넘기고 양변에 log를 곱해주면 (11)식이 나온다. 그렇게 되면 policy pi는 log(pi(y) pi_ref(y') / pi(y') pi_ref(y)) 꼴과 같이 생각할 수 있고 (12)식을 푸는 것을 목표로 하게 된다.

     

     

     

    g(y)는 psi(p*(y>y'))의 expectation이며 psi가 identity function일 때 

    h_pi (y, y') = (p*(y> mu) - p*(y' > mu))/ tau 로 쓸 수 있고 이를 root-finding problem을 single optimisation problem으로 푼다면 (13)과 같은 Loss를 도출 할 수 있게 된다. pi*는 L(pi)의 글로벌 미니마이므로 유니크 솔루션을 갖는다, 이에 대한 Thorem과 Proof가 논문에 나와있는데 언뜻 봤을 때 잘 이해가 안되서 혹시 나중에 이해하게 된다면 추가하겠다.

     

     

    Sampled Loss for IPO

     

     

    결론적으로 IPO loss는 위와 같으며 I(y,y')는 p*(y>y')의 mean을 가지는 Bernoulli 분포로 볼 수 있다. (13)식에서 (16)으로 넘어가기 위해 둘의 equality에 대해 증명해야 한다.

     

     

    해당 prob은 conditional expectation이기 때문에 이들의 equivalnce가 trivial하지는 않다. 따라서 y와 y'의 분포 사이의 symmetry를 이용해야 하지만 대신에 h_pi(y,y')는 y와 y'의 additive function으로 분해 할 수 있다는 것을 사용한다. 

     

     

     

    dataset의 선호 비선호 라벨 (y_w, y_l)은 (16)의 empirical approximation의 두 항 (y, y', I) = (y_w, y_l , 1), (y, y', I) = (y_w, y_l, 0)을 제공하게 되는데 이 symmetry를 이용해, 

     

    이와 같은 empirical loss를 도출할 수 있고 이 symmetry는 loss의 variance를 줄이는데 중요하다.

    (선호와 비선호 y, y'이 데이터셋에 동등하게 존재하므로 이 두 라벨의 구분을 동일한 가중치로 접근할 수 있고 특정 라벨에 치우치지 않게 할 수 있다고 이해했다.)

     

    결론적으로 simple form인 (17)로 도출된다. 이는 log(pi(y_w) / pi(y_l))과 log(p_ref(y_w) / p_ref(y_l))의 log-likelihood ratios와 1/2tau의 gap을 regressing함으로써 policy pi를 optimize할 수 있다는 것을 의미한다.

    regularisation이 약해질 수록 log-likelihood ratio는 커지게 된다. 반면 IPO는 log-likelihoos ratio 사이의 gap를 컨트롤 함으로써 항상 pi_ref (reference policy)쪽으로 regularize할 수 있도록 하며 데이터셋에 대한 overfitting을 피한다.

     

     

     

     

    Experiments

     

     

    IPO는 복잡한 데이터셋에서 적절하지 않아 bandit setting (policy가 softmax인) 의 toy case에 대한 실험만 있다.

     

    DPO는 어떤 action y가 있을 때 이 y가 다른 모든 action에 승리할 때 tau(=beta)에 관계 없이 pi(y) -> 1로 가게 되고 반대로 y가 다른 액션 대해 승리하지 못할 때 tau에 관계 없이 pi(y) -> 0으로 만든다. 동일한 시나리오에서 IPO는 y가 0으로 가지 않고 tau의 강도에 따라 pi_ref에 가깝게 유지될 수 있도록 한다.

     

    실험에서 각 데이터셋은 

     

    DPO는 D2와 D3와 같이 엇갈리는 선호 라벨이 있을 경우에도 Greedy policy로 수렴하게 되는 것을 확인할 수 있다.

    그리고 특정 선호 라벨이 데이터셋에 적게 있을 때는 pi_ref에 가깝게 붙어야 하지만 DPO objective는 이를 촉진하기에 적절하지 않다.

     

     

     

    DPO vs IPO

     

    DPO
    IPO

     

    log-likelihood ratio의 gap을 크게 만드는 DPO와 gap을 1/2*tau로 regressing하는 IPO

     

    DPO_loss = -F.logsigmoid(beta * logits)
    IPO_loss = (logits - 1 / (2 * beta))**2

     

     

     

     


     

    IPO에서는 DPO에서 preference prob으로 표현되는 reward가 bound 되어 있지 않아 해당 텀이 커짐으로써 KL regularization 텀이 약해지며 발생하는 overfitting 문제에 대해 anlaytic하게 지적한 논문이다. 그리고 empirical approximation에 대해 제안한 방법 Psi-PO 중 하나의 케이스인 Psi 함수가 Identity function인 경우 closed form로 풀어 IPO에 대해 제안했다. IPO loss는 간단한 form으로 도출되지만 논문에서도 언급되었듯 복잡한 세팅의 경우 적절하지 않다. 우리는 선호 데이터셋에서 p*가 아닌 sampled preference만 얻을 수 있으므로. 실제로 DPO는 한 prompt에 대해 다양한 선호 결과가 있는 경우 잘 generalize하지 못한다. 비록 IPO 성능이 좋지 못하더라도 DPO가 reward model을 없애므로써 잃는 regularisation 문제를 analytic하게 잘 설명해준 논문이다.

    그리고 선호 셋은 현재 top score를 가진 선호 라벨이 1개정도인 경우가 많기 때문에 (유저의 select를 받은 답변) 토이 케이스에서 보인 D2, D3같은 경우가 얼마나 있을지 모르겠지만 그런 경우가 아니라도, 나머지 답변에서도 유의미한 정보를 학습해야 하는데 DPO는 top score라벨만 쫓게되는 현상이 있을 것 같다. 

     

    댓글

Designed by Tistory.