Luong Attention
2022-05-07

 

Attention 架构

Luong Attention分为Local和Global两种,本文主要分析Global Attention。下图为Global Attention的架构图:

global_attention

符号解释:

  1. h¯s:encoder_output
  2. ht:decoder_output
  3. at(s):attn_weights
  4. ct:Context vector
  5. h~t:attentional hidden state

 

以中英文翻译场景为例,根据该架构图,分为如下计算步骤:

  1. attn_weights计算

    通过encoder_output和decoder_output计算得到attn_weights,即at

  2. Context vector计算

    通过encoder_output和at计算得到加权encoder_output,即ct

  3. attentional hidden state计算

    ct和decoder_output进行cat合并,经过tanhlinear变换处理得到attentional hidden state,即h~t

  4. 预测

    根据h~t进行预测,得到最终预测结果。

 

Attention计算流程

Global Attention计算流程如下图所示:

Luong_Attention

计算步骤如下:

  1. Encoder

    • step1:对原始输入,通过RNNLSTM处理,得到encoder_output:

      h¯s:[batch_size,input_len,enc_hidden_size]

    • step2:为了能够使得encoder_output和decoder_output做bmm运算,需要进行linear处理

      h¯s:[batch_size,input_len,dec_hidden_size]

    • step3:transpose(1,2)

      h¯s:[batch_size,dec_hidden_size,input_len]

  2. Decoder

    • RNN

      • step4:对Decoder端的输入,通过RNNLSTM处理,得到decoder_output:

        ht:[batch_size,output_len,dec_hidden_size]

    • Attention

      • step5:基于hth¯s,进行打分计算:

        score(ht,h¯s)=bmm(ht,h¯s):[batch_size,output_len,input_len]

      • step6: 对打分结果通过softmax计算,得到attn_weights

        at(s)=align(ht,h¯s)=softmax(score(ht,h¯s)):[batch_size,output_len,input_len]

      • step7: 基于at(s),对encoder_output求加权平均

        ct=ath¯s:[batch_size,output_len,enc_hidden_size]

    • Attentional hidden state

      • step8:将加权encoder_output通过cat操作”融入”到原始的decoder_output

        [ct;ht][batch_size,output_len,enc_hidden_size+dec_hidden_size]

        为了方面后续的linear变换,需要对其shape进行调整,结果如下:

        [ct;ht][batch_size×output_len,enc_hidden_size+dec_hidden_size]

      • step9:对[ct;ht]进行linear变换

        linear([ct;ht])[batch_size×output_len,dec_hidden_size]

      • step10:tanh变换

        tanh(linear([ct;ht]))[batch_size×output_len,dec_hidden_size]

      • step11:将二维展开到三维,得到最终的attentional hidden state

        h~t[batch_size,output_len,dec_hidden_size]

    • Predict

      • step12:对h~t进行linear变换

        linear(h~t)[batch_size,output_len,vocab_size]

      • step13:通过softmax运算,得到最终预测概率结果

        softmax(linear(h~t))[batch_size,output_len,vocab_size]

 

参考资料

  1. 第七课 Seq2Seq与Attention (julyedu.com)