ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [NLP] Hidden Markov Transformer for Simultaneous Machine Translation
    ✨ AI/AI papers 2023. 4. 20. 23:04

    Hidden Markov Transformer for Simultaneous Machine Translation

     

    Link : https://arxiv.org/pdf/2303.00257.pdf

     

     

    Abstract

    Simultaneous machine translation(SiMT) task에서, 언제 translation을 시작할지에 대한 많은 가능한 moments 사이에서 optimal moment를 learning하는 것은 non-trivial하다. 왜냐하면, 번역 시작점은 항상 모델 안에 숨겨져 있고, 오직 관찰된 target sequence에 대해서만 supervised learning이 가능하기 때문이다. 따라서 해당 논문에서는, Hidden Markov Transformer(HMT)를 제안하며, HMT는 번역 시작점을 hidden event로써, 그리고  타겟 시퀀스를 관찰된 이벤트를 따르는 것으로 다뤄서 이들을 organizing하는 모델이다. HMT는 여러개의 번역 시작점을 후보 Hidden events로 explicitly하게 모델링하고, 이 중 타겟 토큰을 생성하기 위해 하나를 select한다. training하는 동안, multiple 번역 시작점에서의 target sequence의 marginal likelihood를 maximizing함으로써 HMT는 타겟토큰이 더 정확하게 생성되는 지점에서 번역을 시작하는 것을 learning하게 된다.

     

    Introduction

    Simultaneous machine translation은 소스 토큰을 하나 씩 받으면서 target token을 simultaneously 생성하고, 언제 translation을 하면 좋을지에 대한 wise decision이 필요하다. 그러나 언제 번역을 시작할지를 learning하는 것은 그 시작점을 항상 모델에 숨겨져 있고 우리는 오직 관찰된 sequence를 supervised할 수 있기 때문에, SiMT model을 위해서는 non-trivial한 문제이다. 지금까지 존재하는 SiMT method로는 언제 번역을 시작할지 결정하는 것을 fixed하게 나누거나, adaptive하게 나누는 방법이 있다. Fixed methods는 이 시점을 learning하지 않고 미리 정의된 룰에 맞춰 바로 언제 시작할지 결정하는 방법이다. 이는 context를 무시하고, 또한 가끔 소스가 충분하지 않을 때도 모델이 번역을 시작하도록 강제하게 된다. Adaptive methods는 READ/WRITE action을 나타내기 위한 variable을 prediction하는 것과 같이 READ/WRITE action을 동적으로 결정한다. 그러나 READ/WRITE action 과 관찰된 target sequence 사이에 명백한 관련성이 부족하다는 것 때문에 정확한 action을 오직 관찰된 target sequence의 supervision으로 learning하는 것은 매우 어려운 일이다. 따라서 모델 안에 숨겨진 각 타겟 토큰의 번역을 시작하는 optimal monent를 찾기 위한 ideal solution은, 관찰된 타겟과 번역 시작점의 명백한 연관성을 찾는 것이며 더 나아가서는 정확한 토큰이 생성될 수 있는 지점을 시작점으로 learning하는 것이다. 

    따라서 해당 논문에서는 Hidden Markov Transformer(HMT)를 제안하고 이는 번역 시작점을 hidden event로, translated result를 관찰된 event로 다루며 이들을 organizing하는 모델이다. 

    위 Figure 1에서 볼 수 있듯이, HTM는 explicitly 각 타겟 토큰을 위한 states set을 만들어내고, 이 states set은 서로 다른 각각의 시점에서 타겟 토큰 번역 시작점을 표현하는 multiple states이다. 즉, 서로 다른 길이의 소스 토큰을 받은 후, 번역한 시점을 나타내며 그리고 HMT는 낮은 latency 부터 high latency까지 오는 각 state를 선택할지 말지 판단하게 된다. 그리고 일단 state가 선택되면 선택된 state에 근거해 target token을 생성하게 된다. 예를 들어서, 소스 토큰 갯수를 1, 2, 3을 받은 3개의 state가 있다고 할 때(각 state의 hidden은 누적), 첫번째 state가 선택되지 않고, 두번째 state가 선택되어 output "I"를 생성했으면 3번째 state는 더이상 고려되지 않는다. training동안 HTM는 모든 가능한 selected results(i.e. hidden events) 타겟 시퀀스의 marginal likelihoode를 maximizing하는 방향으로 optimizing된다. 이러한 방법으로 target token을 생성하는 states(moments)들은 더욱 정확하게 selected되고, HMT는 효과적으로 관찰된  target sequence의 supervision을 통해 언제 번역을 시작할지 learning하게 된다.

     

     

    Hidden Markov Transformer

    Architecture

    HMT는 encoder와 hidden Markov decoder로 구성되어 있으며, encoder는 unidirectional encoder를 사용하였다. encoder는 받은 source의 hidden state h를 생성하고, Markov decoder는 여러 시점에서 만들어진 target token y를 생성하기 위한 set of states를 만들어내고, 어떤 state가 선택될 것인지 결정하는 역할을 한다. HMT는 3개의 파트로 구성되어있고 각각 state production, translation, selection이다.

     

    State Production

    target token y_i를 번역할 때, HMT는 K states set s_i를 만들어낸다. translating moments 를 t_i={t_{i, 1}, .. t_{i, K}}로 정의된다. 적절한 t state set을 위해서, 선호되지 않는 translating moment를 pre-prune했다. 예를 들어, x_j를 받은 이후 y_1을 translating한다거나 x_1을 받은 후 y_I를 translating한다거나.

    위그림처럼 wait-L path를 lower boundary로 정의하고 wait-(L+K-1)을 적절한 path의 upper boundary로 정의했으며 L과 K는 hyperparameter이다. Figure3에서 보면 L=1, K=4라면, y_1을 번역하려고 할 때 가능한 state set은 src token이 1부터 4인경우까지 총 4가지 set이 있는 것이다. lower boundary path는 소스 토큰이 1개씩 추가될 때마다 번역하는 것이고 upper boundary는 소스토큰 4개를 기다렸다가 1개씩 추가될 때마다 번역하는 것이다. 

     

    Tranaslation

    K states의 representation은 타겟 인풋 K times를 upsampling하여 초기화되며, N개의 Transformer decoder layer에 의해 계산된다. 

    그리고 State간 Self Attention은, unidirectional encoder이기 때문에, 각 state는 이전 state와의 attention만 계산되고,

    state s_{i, k}와 source hidden state h_j의 Cross Attention은 위와 같이 source hidden state step j가 source token t_{i,k}보다 작은 경우에 대한 attention을 계산한다.

    그리고 N개의 decoder layer를 통해  state s_{i, k}의 최종 representation을 얻고 이 final representation으로 부터 계산된 y_i의 prob을 얻는다. W^O는 W^Q, W^K와 마찬가지로 learnable parameter이고 y_{<i}는 이전에 얻어진 target tokens이다.

     

    Selection

    위 단계까지에서 얻은 state의 최종 representation을 얻은 후, state s_{i, k}를 선택(y_i 번역)할지 말지를 위해 HMT는 state s_{i, k}를 선택하는 confidence c_{i, k}를 예측한다.

    c_{i, k}는 위 수식처럼 s_{i, k}의 최종 representation과 지금까지 받은 source contents를 받아 learnable parameter W^S에 의해 예측된다. 

    그리고 \bar_h는 받은 소스 토큰의 hidden states를 average polling한 것이고 [:]는 concatenating 연산을 의미한다. confidence c가 0.5 이상이면 HMT는 해당 state를 선택하고 그렇지 않으면 다음 state로 넘어가며 이를 반복한다. 

     

     

    Training

    선택된 state를 z라고 하면 Transition probability는 이전 selection z_{i-1}이 condition된 selection z_i의 probability로 포현된다.(p(z_i | z_{i-1}))  s_{i, k+1}은 오직 이전 state s_{i, k}가 선택도지 않았을 때만 선택될 수 있으므로 p(z_i| z_{i-1})을 계산하는 것은 2가지 파트로 나누어진다.

    (1) s_{i, z_i}는 선택된 confident이고

    (2) t_{i-1, z_{i-1}}과 t_{i, z_i} 사이의 states들은 선택될 confident가 아니다.

     

    그리고 emission probability는 state s_{i, z_i}에 의한 관찰된 y_i의 확률로 표현된다.

     

    HMM loss

    언제 번역을 시작할지 learning하기 위해 HMT는 모든 가능한 선택된 결과에서의 target sequence의 marginal likelihood를 maximizing하는 방향으로 training된다. 그러므로 HMT는 타겟 토큰을 더 정확하게 생성하는 state를 선택하는 쪽에 높은 confidence를 주게 된다.

    따라서 Transition probability와 Emission probability가 주어졌을 때, 관찰된 target sequence y의 marginal liklihoode는 위와 같은 수식으로 계산된다. 모든 가능한 selection results의 marginalizing을 계산하기 위해 dynamic programming을 사용해 computational complexity를 줄였다.

    그리고 구한 marginal likelihood p(y|x)는 데이터셋의 target sequence에 대해 negative log-likelihoode loss로 HMT를 optimize하게 된다.

     

    Latency loss

    그리고 translation quality와 latency의 tradeoff 를 위해 latency loss L_latency를 도입했다.

    C(z)는 z에서 계산되는 latency 함수이고, 여기서는 average lagging을 사용했다. 

     

    State loss

    그리고 어떤 State가 선택되어도 정확한 target token을 생성하여 모델의 robustness를 향상할 수 있는 state loss L_state를 추가로 사용하였다. K개의 states가 병렬적으로 hidden Markov Decoder로 feed되기 때문에 L_state는 모든 states에 대한 cross-entropy로 계산된다.

     

    그래서 최종 loss는 hmm loss, latency loss, state loss를 모두 합한 것으로 정의되며 lambda는 hyperparameter로 여기서는 모두 1을 사용했다고 한다.

     

     

    Inference

     

     

    Experiments

    데이터셋은 IWSLT15 En->Vi와 De-> En를 사용했으며 성능 비교를 위한 Baseline model은 Full-sentence MT, Wait-k, Multipath Wait-k, Adaptive Wait-k, MoE Wait-k, MMA, GMA, GSiMT가 사용되었다. 

     

    Main Results

    실험 결과는 SOTA모델인 MMA, GSiMT를 outperforms하였고 두개의 advantages를 갖는다고 저자들은 주장하고 있다. 첫번째로는, HMT는 READ/WRITE action과 비교해서 관찰된 target sequence와 번역 시작점이 더 강력한 correlations를 갖는다는 점과 두번째로는, MMA, GSiMT 모델 모두 오직 READ/WRITE acton의 combination을 고려하지만 HMT는 multiple moments를 고려하고 서로 다른 translating multiple states를 만들어낼 수 있다는 점이다.

    그리고 특히 MMA, GSiMT는 특정 latency에서 성능이 낮은 것을 볼 수 있는데 이는 너무 많은 READ/WRITE action의 combination은 mutual interference를 야기시키기 때문이라고 한다. 반면 HMT는 unfavorable moments를 pre-prune함으로써 모든 latency regime에 대해서 좋은 성능을 보여준다.

     

    Training and Inference Speed

    그리고 이는 Online model이기 때문에 Training, Inference 스피드가 중요한데, pre-prune 덕에 기존 wait-k 모델보다는 아주 살짝 느려도 MMA, GSiMT에 비해서는 매우 빠른 training and infernece time을 보여주고 있다.

     

     

    Analysis

    Ablation Study

    Table 2 HMM loss에서 z의 summation을 Max_z로 바꿔서 가장 선호하는 selection으로 optmizing하도록 바꾸었더니 모델이 local optimum에 빶지는 경향이 있었고 wait-3와 비슷한 성능을 주었다고 한다. 따라서 저자들은 HMT가 multiple possible moments로 부터 번역 시작점을 효과적으로 learning하기 위해서는 가장 선호하는 selection보다 가능한 selection의 marginalizing이 성능이 더 좋았음을 보여주었다.

    Table 3,4 그리고 L_latency와 L_state의 constant 상수를 바꿔가면서 실험을 진행했을 때 둘 다 람다를 1로 두었을 때 성능이 좋았음을 보였다. 특히 lambda_latency를 너무 크거나 작게 주면 모델이 너무 lower bound결과와 가까워지거나 upper bound결과에 가까워지는 결과를 보였다.

     

    How Many States Are Better?

    K가 작을 수록 낮은 latency에서 좋았으며 K가 커질 수록 latency가 커지는 경향을 보였다. 이는 translating moments의 큰 gap은 각기 다른 training동안 간섭을 일으키기 때문이며 이 gap은 low latency일 수록 더 두드러지게 관찰 되었다. 

     

     

     

     

     

    댓글

Designed by Tistory.