ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [RLHF] The wisdom of hindsight makes language models better instruction followers(HIR)
    ✨ AI/AI papers 2023. 3. 26. 19:23

    The wisdom of hindsight makes language models better instruction followers

    https://arxiv.org/abs/2302.05206

     

    The Wisdom of Hindsight Makes Language Models Better Instruction Followers

    Reinforcement learning has seen wide success in finetuning large language models to better align with instructions via human feedback. The so-called algorithm, Reinforcement Learning with Human Feedback (RLHF) demonstrates impressive performance on the GPT

    arxiv.org

     

    Abstract

    RL을 사용하면 HF instruction을 통해 더 좋은 Alignment를 가진 LM로 finetuning 할 수 있다는 것을 그동안 많은 연구들에서 보여줬다.

    RLHF(Reinforcement Learning with Human Feedback)
    ‘good’ text라는 것은 주관적이고 문맥에 의존하기 때문에 정의하기 어렵다. creativity한 문장이 좋은 문장일 수도, 정보를 많이 포함한 문장이 좋은 걸 수도, truthful한 (chatgpt는 없는 사실도 지어낸다.) 문장이 좋을 수도, 그리고 LLM이 만들어낸 code sinppet이 정말로 실행되면(chatgpt 코드가 안돌아갈 수도 있다) 좋은 것일 수도 있다. 이러한 속성들을 loss function에 담기는 불가능하며 대부분의 LM은 단지 다음 토큰의 cross entropy loss로 training된다. inference시 loss가 아닌 reference sentnece를 가지고 사람의 preference를 반영한 bleu, rouge 스코어로 평가를 하긴 하지만 이도 한계가 있다. 그래서 LM이 generate한 text를 가지고 사람이 feedback을 주고, 이 feedback을 성능평가로 사용하는 것에서 더 나아가 loss로 사용해 model을 optimize하면 어떨까? ⇒ RLHF
    RLHF : human feedback(reward)로 LM을 RL을 사용해 직접적으로 Optimize하는 방법

     

    RLHF가 인상적인 성능을 보여주긴 했지만, RL 알고리즘은 reward, value network를 training하기 위한 복잡하고 추가적인 Training pipeline이 필요하다. 그래서 HIR에서는 original feedback을 Relabeling함으로써 instruction으로 바꾸고, 모델을 더 나은 alignment로 supervised manner로 tuning하는 방법을 제안한다. 원래의 LM과 pretraining 파이프라인을 재사용할 뿐 추가적인 파라미터가 필요없어 RL보다 더 간단하다는 장점이 있다. 이는 ‘Instruction alignment’문제를 RL의 ‘goal-reaching’ 문제처럼 보고 풀었다고 볼 수 있고 Hindsightly instruction을 relabel하는 reward-free 접근법의 한 종류라고 볼 수 있다. 이 방법은 supervised finetuning처럼 간단하지만 supervised finetuning을 훨씬 능가하는 성능을 보여준다.

     

     

    Introduction

    GPT3가 prompt learning을 통해 finetuning 없이도 여러 task가 가능하다는 것을 보여주긴 했지만, 최근 연구들에는 LLM을 instruction을 통해 내놓는 behavior가 의도한 대로 나오지 않을 수 있다는 것을 보여주고 있다. LLM이 사실이 아니거나 toxic text를 generate하거나 instruction을 따르지 않는 output을 generate할 수 있기 때문에 unintened behavior는 바람직 하지 않다. 이를 위해, 즉 LM의 output이 사람의 지시에 따르도록 하는 많은 finetuning algorithm들이 연구되었다. 가장 널리 사용되는 접근 법은 RL을 이용하는 것이다. learned 또는 manually define된 alignment score를 optimize하는데 RL이 사용된다. 크게는 3가지 접근 법으로 나뉜다.

     

    1. Optimize for trained alignment score module by using PPO algorithm
      1. rather complex, sesitve to hyperparameters
      2. require additional training in the reward model, value network
    2. Imitation learning to a Final-anser or Reward-Model filterd dataset
      1. less data-effective as it only makes use of the success instruction-output pairs, align 되지 않은 데이터는 버려진다.

     

    그래서 이 논문에서는 Simple한 Finetuning algorithm이면서 동시에 오직 성공한 instruction pair만 아닌 나머지 데이터셋도 사용할 수 있는 알고리즘을 제안한다. 이 논문에서는 먼저 Instruction alignment와 augmented goal 을 만드는 goal-reaching RL과의 Connection에 대해 이야기 한다. 일단 Instruction을 Goal로 볼 수 있고 language model을 goal-conditoned policy로 보자는 것이다. 그렇게 되면 HER과 같은 goal-conditioned RL을 alignment problem에 적용할 수 있다는 것이다.

     

    그렇게 고안된 이 알고리즘은 Hindsight Instruction Relabeling(HIR)이라고 하고 hindsight fashion을 통해 instruction 을 relabeling한다. 이 알고리즘은 Instruction-ouput pair dataset을 생성하는 Online sampling과 instruction을 relabeling하여 supervised training하는 Offline learning 이렇게 두개의 루프가 존재한다. relabeling을 위해 failure data를 사용하는 HER의 아이디어를 빌려오고 contrastive instruction labeling을 통해 성능을 향상시켰다.

     

    Key attribution

    • hindsight instruction relabeling을 통한 feedback을 learning하는 새로운 관점의 알고리즘을 제안, LM의 Alignment problem과 Goal-conditioned RL의 Connection을 찾음
    • Data-effective하면서 동시에 RLHF처럼 추가적인 RL training pipeline이 필요하지 않은 novel algorithm

     

    Related Work

    RLHF(e.g. Instruct GPT(ChatGPT), WebGPT)

    RLHF가 처음 제안된거는 2011년, 과거에는 주로 Human preference를 추론하고 모델링 하기 위해 Inverse RL을 사용한 방법들이 제안되었었는데, 최근 ChatGPT에 사용된 InstructGPT를 시작으로, RL알고리즘을 사용해 Human Preference 쪽으로 LM의 alignment를 개선한 방법이 화두가 되었다. 사람이 직접 작성한 Instruction-ouput pair를 ground truth로 두고, GPT(LM)이 내는 ouput이 이를 따라가게끔 finetuning한 것이다. 하지만 이는 엄청나게 많은 양의 사람 feedback데이터가 필요하고 이게 성공의 요인이라고 볼 수 있다. InstructGPT(ChatGPT), WebGPT 등은 General purpose Chatbot을 개발하기 위한 방법이고 HIR에서는 pretrained model을 finetuning 하는 process에 초점을 두었고 lighter-weight approach이다.

     

    Two-stage RL

    그동안 오프라인 강화학습을 다루는 많은 범주의 연구들이 있었다. Trajectory Transformer, Decision Transformer, TAP(Trajectory Autoencoder Planner) 등. 온라인 exploration을 위해 transformer 사용하는 OnlineDT까지. 최근에는 online exploration과 offline training을 번갈아 하는 것과 유사한 Algorithm Distillation(AD)도 제안되었다. HIR과 AD는 완전히 다른 문제를 해결하고 있다. HIR은 LM alignment를 RL을 사용해 improve하는 것이고 AD는 classical control problem을 다룬다.

     

    LM with Reasoning Task

    reasoning task는 explicit하게 reasoning step을 요구한다. (e.g. math solving.) 최근에는 Finetuning 또는 Prompt 를 통해 LM이 multi-step reasoning을 하는 것에 많은 연구들이 집중되고 있다.

     

    일단 논문에서 언급된 것은 이정도인데, 해당 논문에서 강조하는 contribution이 무엇인지 이해하려면 Prompt learning, Instruction tuning, RLHF, Algorithm Distillation, Final-Answer RL 등에 대해 조금은 알고 있어야 한다.

     

     

    In-context learning, few-short (prompting) learning, prompt engineering

    Brown, Tom, et al. "Language models are few-shot learners." Advances in neural information processing systems 33 (2020): 1877-1901.

    LMpretrain하는 것은 단지 연산량, 데이터셋 사이즈, 파라미터수가 bottleneck 없이 커지기만 한다면 성능이 향상됨이 Scaling laws for LMs라는 논문에서 실험으로 입증한바가 있다. GPT3 parameter175Bilion의 파라미터를 가지고 있는 언어 모델이다. 이 파라미터를 Finetuning 하는 것은 사실상 불가능하며 LM performance 위해 모델을 크게 만들되 Fewshort learning으로 finetuning안해도되게 만든 것이 GPT3이다.

     

    LM 자체에는 Task 수행할 수 있는 능력이 있으니 Finetuning말고 활용하자는 것이 키 아이디어이다. 기존에는 대규모 언어 데이터셋으로, 이전 시퀀스(이전 토큰들)이 주어지면 다음 토큰을 예측 할 수 있도록 maximum loglikelihood traning한 LM이 있으면, 각 classification, sentiment prediction, translation 등의 downstream task로 Fine tuning하는(task dataset으로 parameter reupdate) 방식을 사용하였다. 하지만 GPT는 기존에 task 데이터 셋을 LM Pretraining에서 몇 개를 주고 Downstream task에 대해서는 task label 같은 것만 앞에 붙여 Fine tuning 없이 바로 Down stream task를 수행할 수 있도록 만들었다. few shot task에 대한 query를 prompt라고 한다.

     

    Prompting method

    Kojima, Takeshi, et al. "Large language models are zero-shot reasoners."  arXiv preprint arXiv:2205.11916 (2022).

    그리고 task에 대한 few shot example(Prompt)을 줄 때  Model이 Prediction을 더 잘 할 수 있도록 Prompt를 바꾸거나 Prompt에 추론에 대한 context를 주는 등 다양한 prompting methods들도 연구되었다. CoT(Chain-Of-Thought)라는 방법은 Prompt에 추론한 내용도 함께 줌으로써 추론까지 대답하게끔 LM을 학습하는 방법이다. CoT를 적용하는 것 만으로도 System-2 task같은 초등학생 수준의 계산 문제도 잘 해결되었다. Few-shot learning의 성과보다 promptengineering하면 성능이 더 향상되거나 어려운 task푸는 것이 가능하다는 것을 보여준 것이다. Prompt“let’s think step by step”만 추가해도 zero-shot learning이 가능하다고 한다.


    구체적인 문제 풀이 과정의 예시를 제공하는 것을 few-shot prompt, 문제와 해답 템플릿만을 제공하는 것을 zero-shot prompt라고 하며,  Few shot CoT 는 한개의 답을 단계별로 나눠서 예시를 주고(human engineering 왼쪽 그림) zero shot CoT(오른쪽 그림)은 문제에 대한 구체적 예시를 주지 않아도 되고 LLM이 답변한 추론 내용을 다시 Prompt로 주는 방법이다.

    https://arxiv.org/pdf/2205.12548.pdf

    그리고 RLPrompt라는 것도 있는데 Discrete Prompt를 optimize하는데 RL이 적용되었고 Prompt generator가 RL의 Policy가 된다.

     

     

    Instruction Tuning

    그리고 Instruction Tuning은 Instruction template 정해 두고 기존 task데이터를 instruction에 맞게 수정한 뒤 tuning 에 포함되지 않은 task에서 evaluete하는 것이다. FLAN, FLAN-T5 등의 모델은 instruction tuningLMunseen task에 대한 zero-shot 성능을 향상 시킬 수 있음을 보여줬다. 해당 논문(HIR)에서 Base model로 사용한 모델이 FLAN-T5이다. FLAN-T5는 Instruction tuning에다가 CoT를 적용한 모델이다.

     

    Instruct GPT

    Ouyang, Long, et al. "Training language models to follow instructions with human feedback." arXiv preprint arXiv:2203.02155 (2022).

    ChatGPT에 사용되었다는 InstructGPT는 RLHF가 사용된 대표적인 모델이다. 일단 prompt dataset에서 prompt를 샘플하고 사람이 desired output을 라벨링한다. 그리고 이 데이터는 GPT3에 supervised로 Finetune된 후, 이렇게 학습된 GPT가 내는 output 중 사람이 Ranking을 매겨 Reward model을 학습시킨다.

    그리고 Reward model을 사용해서 PPO로 Policy 모델이 Human preference(Instruction)을 따르도록 최적화 된다. 이 논문을 보면서 신기했던 점은 SFT는 1epoch만에 overfitting했지만 Reward model, policy model을 만드는데도 overfitting해도 epoch를 더 늘린 SFT가 더 좋았다고 한다.

    SFT GPT에서 내논 output 중 사람이 가장 선호하는 y와 나머지 k를 모두 이용하는 것이 아니라(overfitting 방지) single batch에 2개씩 뽑아 사용했다고 한다. Y_w에 사람이 선호하는 prob을 policy 모델에 할당하고 SFT model의 KL panelty를 줘서 reward model이 over optimization하는 것을 방지했고 PPO gradient와 Pretraining gradient를 섞어서 기존의 NLP task도 잘 수행할 수 있도록 했다고 한다.

     

     

    In-context RL with Algorithm Distillation

    Laskin, Michael, et al. "In-context reinforcement learning with algorithm distillation." arXiv preprint arXiv:2210.14215 (2022).

    AD는 먼저 Any gradient-based RL algorithm으로 Task에 대한 history-conditoned policy를 학습한 후 이 Policy로 각 task에 대해 dataset D를 생성한다. 그런다음 source algorithm의 behavior를 Sequence model에 distillation한다. 이 Sequence model은 log history에 action prob을 mapping하며 NLL(negative log likelihood) loss를 사용한다.

     

    직관적으로 fixed parameter를 가진 Sequence model은 Source RL algorithm을 armotise해야하며 exploration, temporal credit assignment 등의 복잡한 behaviour를 나타낼 수 있어야 한다. RL policy는 source algorithm의 learning history를 통해 향상되고 정확한 action prediction을 위해서는 Sequence model이 이전 context에서 현재 policy를 추론할 뿐 아니라 개선된 policy도 추론해야 한다. 

     

    Final-Anser RL(FARL)

    Uesato, Jonathan, et al. "Solving math word problems with process-and outcome-based feedback." arXiv preprint arXiv:2211.14275 (2022).

    few shot prompting, supvervised FIne tuning, RL등으로 Policy가 미리 학습된 모델이 있을 때, model ouput score로 finetuning한 ORM, 마지막 대답만으로 Correctioness로 finetuning한 Final-Answer RL, 문제를 풀어가는 과정에서 각 step에서의 score가 맞았는지 가지고 fine tuning한 PRM-RL 위 세개를 비교했을 때 Final-Answer RL이 가장 좋았다고함. 이 논문은 자세히 읽어보지 않아서 모르지만 지금 소개하는 HIR에서는 Imitation Learning을 통해 alignment problem을 해결한 방법은 주로 정답이 맞은 데이터셋 혹은 높은 점수를 받는 데이터셋 만을 사용하는 경우가 많고 이는 Data inefficient하다고 말하고 있다. FARL는 실패한 케이스를 sampling하긴 하지만 성공한 케이스만 사용하게 된다.

     

     

    HIR(Hindsight Instruction Relabeling)

    Instruction Follwing as Goal-conditioned RL

    Goal-conditoned RL에서 MDP는 State, Transition prob, Action, Reward, Neat state로 이루어진 MDP에 goal G가 추가된다. 그리고 Reward function은 State S, Action A와 함께 Goal G도 Input으로 받는다. 그리고 policy는 State S, Goal g가 Conditioned된다. 

     

    LM에서 Prompt based finetuning을 할 때 이 prompt를 goal로 보면 HER(HindSight Experience Replay)와 같은 goal-conditioned RL을 Alignment problem 적용할 수 있게 된다는 것이다. LM M이 instruction prompt p와 initial query token sequence q={q_0, .. q_i}를 input으로 받고 autoregressively 다음 토큰을 예측한다. e_{i+1}=M(p,q,{e_0, .. e_i}), e is embedding vector of q

     

     

    NLP task를 생각해보면 Policy가 내는 다음 토큰이 다음 상태가 된다. 때문에 여기서 LM을 goal-conditioned policy로 보는 동시에 Transition model(world model)로도 볼 수 있다.

     

     

     

    Hindsight Instruction Relabeling

     

    이 논문에서는 SimpleFinetuning algorithm이면서 동시에 오직 성공한 instruction pair만 아닌 나머지 데이터셋도 사용할 수 있는 알고리즘을 제안한다. hindsight fashion을 통해 instruction relabeling하고, Instruction-ouput pair dataset을 생성하는 Online samplinginstructionrelabeling하여 supervised training하는 Offline learning 이렇게 두개의 루프가 존재한다. relabeling을 위해 failure data 사용하는 HER의 아이디어를 빌려오고 contrastive instruction labeling을 통해 성능을 향상시켰다. 전체적인 동작 과정은 online sampling으로 생성된 데이터셋으로 Offline instruction relabeling을 하고 policy를 improve하는 것을 반복한다.

     

    Online Sampling

    prompt(instruction) p와 query q가 주어졌을 때 output sequence o를 얻는다. q 는 training dataset으로부터 샘플해오고 p는 pre-defined sentence로 초기화된다. 이 output sequence를 낸 prompt p는 offline relabeling 단계에서 올바르게 다시 alignment된다. 

     

     

    Offline Relabeling

    Offlien Relabeling은 sample된 o에 대해 맞는 prompt p*으로 다시 라벨링을 해준다. 이는 HER에서 sparse reward 환경에서 마지막 받은 reward로 중간 과정에 Pseudo reward로 바꿔주는 것과 비슷하다. HIR에서는 Feedback data로 Score network나 Reward model을 따로 만들지 않고 정답이 맞았으면 "generate a correct answer to this problem", 틀렸으면 "generate a wrong anser to this problem"과 같이 Scriptable한 Relabeling방법을 사용한다. 아이디어는 굉장히 간단하다.

     

    Contrastive Instruction following

    Contrastive instruction loss 텀을 통해 다른 instruction에 대해 같은 output을 매핑하지 않도록 하고

     

    Entropy Regularization

    특정 p가 주어졌을 때 ouput에 대해 Negative Entropy term을 추가해서 sampling phrase가 better exploration을 더 하여 너무 빨리 Converge 하지 않도록 하였다.

     

    최종 Loss는 Cross entropy(SFT) + Contrastive loss + Entropy Regularizer가 된다.

     

     

    Algorithm

     

     

    그리고 이 논문에서는 AD, RLHF, FARL과 HIR을 비교한 내용도 있다.

     

     

    AD와 비교했을 때는 Two-stage RL이라는 점(Online sampling & Offline Training)은 같지만 AD는 control task를 위한 모델이고 HIR은 LM의 Alignment를 수정하는 task에 초점을 맞추고 있다. 그리고 HIR은 explicit한 returnm(reward) model이 필요하지 않다.

     

    RLHF와 비교했을 때는 feedback을 통해 Instruction Alignment problem을 해결하고자하는 것은 같으나 RLHF는 추가적인 RL training이 요구되지만 HIR은 LM의 파라미터만 수정하게 된다.

     

    FARL과 비교했을 때는 Failure cases도 Learning에 사용되는 점은 같지만 오직 맞은 output만 필터링해서 사용한다.

     

     

    Experiment

     

    12 BigBench라는 LM reasoning tasks evaluate하였다. task는 매우 다양한 task 담고 있는데 logical understanding이 필요한 logical deduction과 수학 계산, geometric shape을 물어보는 문제 등의 object counting 문제 등이 있다. InstructGPT, AD에서 사용된 human feedback 데이터셋을 얻을 수 없어서 위 실험에서는 PPO라고 명칭했다고 한다. 

    실험 결과는 PPO, Final Anser RL에 비교해서 11.2%, 32.6% 성능이 올랐다.

     

    HIR은 Hindsight Instruction relabeling을 통해 feedback을 learning하는 새로운 관점의 알고리즘을 제안하였고 LM의 alignment problem과 goal-condition RL의 connection을 찾았다. Data effective하면서 동시에 RLHF처럼 추가적인 RL training pipeline이 필요하지 않은 novel algorithm이라고 주장하고 있다.

     

    논문을 다 읽은 후기로는 조금 허무한 느낌이 들었다. RLHF traning이 현실적으로 엄청나게 많은 양의 human feedback data를 가져야 하므로 해당 알고리즘은 이를 흉내내면서 쉽게 training해볼 수 있다는 장점은 있지만 HIR이 RLHF와 비교 대상이 되는지는 조금 의문이다.

     

    댓글

Designed by Tistory.