새소식

인공지능 AI

[CS25 5강] Mixture of Experts (MoE) paradigm and the Switch Transformer

  • -

link : https://www.youtube.com/watch?v=U8J32Z3qV8s&list=PLoROMvodv4rNiJRchCzutFw5ItR_Z27CM&index=5

 

paper : https://arxiv.org/pdf/2101.03961.pdf

Mixture of Experts (MoE) paradigm and the Switch Transformer

 

이번 세미나 주제는 "Sacling Transformers through Sparsity"이다.

 

먼저 뉴럴 언어 모델의 Scaling에 대해 얘기하면서 세미나는 시작한다. 트랜스포머는 강력한 성능으로 nlp, vision 분야를 섭렵하고 있지만, 과거에는 데이터 셋이 적거나 Sparsity를 포함하는 경우에 트랜스포머를 많이 사용하지 않는다. 트랜스포머가 아주 복잡하고 큰 데이터셋의 correlation을 잘 모델링할 수 있지만 그렇지 않은 경우에는 굳이 모델링이 무거운 트랜스포머를 적용할 필요가 없기 때문이라고 생각했기 때문이다.

2020년에 OpenAI가 발표한 "Scaling Laws for Neural Language Model"에 대해 소개하면서, 딥러닝 분야에서 강력한 성능을 얻기 위해서 집중해야할 방식 중 하나가 'Scaling'이라고 이야기 한다. 해당 논문에서는 모델 성능이 컴퓨팅 또는 파라미터 측면에서 모델 크기와 평행하게 올라가는 것을 실험적으로 입증한 논문이며 큰 파장을 불러 일으킨 논문이기도 하며, 이 논문으로 대형 연구실에서는 모델의 사이즈를 늘리기 위한 경쟁이 이뤄지고 있다.

 

해당 논문에서 흥미로운 부분은, 더 큰 모델이 더 sample-efficient하다는 것을 발견했다는 점이다. 따라서 고정된 컴퓨팅 자원이 있을 때, 우리는 최적의 모델 사이즈를 알 수 있다고 한다. 해당 실험 결과에 따르면, 작은 모델로 많은 훈련 steps을 거치는 것보다 큰 모델을 학습 시키는 것이 더 낫다고 한다. 하지만 해당 논문에서는 dense model에 대해 초점을 맞추고 있고(즉 모델 차원을 늘리는 것), 해당 세미나에서는 sparsity에 대해 초점을 맞춘다.

 

Consider a New Axis: Sparsity

Hypothesis: Scaling the sparsely used parameters -- with a fixed computation per example -- is independently useful.

해당 세미나에서 초점을 두는 것은 희소하게 사용되는 파라미터를 확장하는 것은 독립적으로 유용하다는 것이다. 즉, dense layer만으로 모델을 스케일링 하는 것은 컴퓨팅 자원이나 복잡도가 기하급수적으로 증가하므로, 파라미터 수가 적은 sparse model을 스케일링 하자는 것이다.

Mixture of Experts model(MoE Model)

기본적으로 우리가 네트워크를 학습할 때, 입력에 따라 sparsely activated weight가 있을 것이다. 모든 입력은 대략 비슷한 양으로 계산되지만 다른 가중치가 적용된다. 이는 1991년에 제안된 Mixtures of local experts라는 논문에서 착안한 것이고 최근 구글 브레인에서 LSTM에 여러개의 피드포워드 네트워크로 구성하고 각 네트워크를 일종의 expert로 보고 그 네트워크를 모아 mixture of experts로 구성하였다. gating network에서 각 토큰의 expert probabilty(E_i(x))를 출력하고 이 분포는 softmax결과((G(x)_i), i expert의 가중치)가 곱해져 계산된다. 여기서 몇 명(몇개)의 expert(network)를 선택하게 되는데 선택할 때 다양한 전략이 있고(이후 언급할 예정) 출력은 단순히 선택한 모든 expert의 가중치 혼합물이다. 즉 위 그림의 맨 위 수식에서 볼 수 있듯 최종 결과는 expert network에서 생성되는 모든 결과의 weighted sum이 된다.

 

 

Switch Transformer: A Simple and Scalable Approach

Switch Transformer는 트랜스포머와 동일하게 self-attention 모듈과 feed-forward 네트워크로 구성된 구조이고, 다른점은 기존 feed-forward 레이어를 1개마다 혹은 2,4개마다 스위칭 ffn layer로 대체한다는 점이다.

하나의 FFN 레이어는 expert가 되고, 라우터는 모든 expert에 대한 분포를 얻는다. 가장 간단한 아이디어는 가장 높은 확률을 가지는 expert를 보내는 것이고, 따라서 그림에서 보면 왼쪽 토큰은 2번째 FFN, 오른쪽 토큰은 1번째 FFN의 확률을 보내는 것을 볼 수 있다. sparse switch layer는 입력이 동일해도 서로 다른 weight matrix를 가지고 있다. 

 

토큰을 여러 개의 expert에 보내는 것은 communication cost가 증가할 수 있어, 한 expert에게만 전송해 알고리즘을 단순화시킬 수 있다. sparse model을 학습하기 위해 저자들은 위 3가지에 중점을 두었다. 첫 번째는, Selective precision으로 sparse model을 낮은 정밀도 형식으로 훈련할 수 있게 하는 것이다. sparse model에서는 특히 데이터의 스케일이 커지거나 작아지는 간격이 dense model보다 클 것이다. 그럴 경우 loss 계산에서 exponential하게 커지기 때문에 학습이 더 불안정해 질 수 있으며 특히 라우터의 softmax 연산시 상황을 더 악화시킬 수 있다. 따라서 selective-precision을 사용해 비슷한 결과를 더 빠르게 학습할 수 있었다고 한다. 두 번째는, 텐서 계산을 float32 타입을 사용할 경우 속도가 느리기 때문에 몇 가지 initialize trick과 learning rate schedule 등을 통해 모델의 크기가 커져도 안정적으로 학습할 수 있도록 한다. 예를 들어 fine-tuning하는 동안 extra expert dropout을 크게 늘리면 성능이 크게 향상되었다고 한다.

세 번째는, pre-trained model을 fine tuning할 경우 모델이 이미 많은 파라미터를 가지고 있기 때문에 overfitting에 훨씬 취약하다. 이를 방지 하기 위한 differentiable load balancing techinique를 적용하고, 이는 expert가 대략적으로 동일한 양의 token을 얻을 수 있도록 하므로 하드웨어 코스트 비용이 효율적이게 된다. 

 

저자들은 Expert dropout을 위해, 토큰이 drop 되거나 계산이 적용되지 않은 토큰이 있으면 모델 성능이 저하될 수 있다고 생각해 multiple step routing 프로시저를 수행해 먼저 각 토큰을 가장 높은 확률 expert에게 보낸 후 drop된 토큰을 두 번째로 높은 확률을 가진 expert에게 보내는 식으로 프로세스를 반복해 토큰이 drop되지 않도록 보장하도록 실험을 해보았다고 한다. 하지만 흥미롭게도 실제 이 방식은 모델 성능을 empirically 향상시키지 못했다고 한다. 실험을 하는 도중에 놀라웠던 점은 token dropping이 굉장히 견고해서 dropout 비율을 크게 늘려도 성능이 향상되었다고 한다. dropped 토큰은 하나의 expert에게만 좋고 다른 expert로 rerouting할 필요가 없었다고 한다.

 

 

결론,

Mixture of Experts의 Idea를 LSTM에 적용한 것을 Transformer에 적용한 것이 Switch Transformer이고, Switch transformer는 좀더 efficient한 transformer architecture를 고안하기 위해 만들어졌다. 딥러닝을 위한 adaptive computation는 정확하게 밝혀지지는 않았지만 이를 위해 switch transformer와 같은 연구가 시작되었고, sparse model의 한계를 극복하기 위해 selective-precision, initialization & learning rate schedule trick, load balacing loss와 같은 technique이 추가적으로 적용되었다.

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.