BERT参数计算
2022-10-03

 

BERT Model

 

BERT参数计算

Embedding

  1. Token Embedding

    token个数 * 词向量维度 = 30522 * 768

  2. Segment Embedding

    只有0和1区分上下句子 = 2 * 768

  3. Position Embedding

    文本输入最长大小 * 词向量维度 = 512 * 768

     

所以,embedding参数 = (30522 + 2 + 512)* 768 = 23835648

 

Multi-Heads Attention

(1)Qi=XWiQKi=XWiKVi=XWiVheadi=Attention(Qi,Ki,Vi)=softmax(QiKiTdk)ViMultiHead(Q,K,V)=Concat(head1,head2,,headh)WO

X维度:[30522,768]

dmodel大小:768

dk=dv=dmodel/h大小:768/12=64

权重矩阵WQWKWV的维度:[768, 64]

QKV的维度:[30522, 64]

headi的维度:[30522, 64]

WO的维度:[768, 768]

BERT中head个数:12

BERT中Encoder(或看认为一次Multi Head Attention)个数:12

 

  1. 对1个head,权重矩阵(WQWKWV)的参数维度为:768 * 64 * 3
  2. 对1个head,线性变换矩阵(WO)参数为:768 * 768
  3. 对1个Encoder(即包含12个head),权重矩阵参数维度:768 * 64 * 3 * 12+768*768 = 2359296
  4. 对12个Encoder,参数为:2,359,296 * 12 = 28,311,552 = 27MB

FeedForward

(2)FFX(x)=max(0,xW1+b1)W2+b2

该步骤使用了W1W2两个参数,Bert沿用了惯用的全连接层大小设置,即4*dmodel=4*768=3072。

因此,W1W2的维度为: (768, 3072) 和 (3072, 768)

12层全连接层的参数为:12 * (768 * 3072 * 2) = 54MB (未考虑bias)

LayerNormalization

layer normalization有两个参数,分别是gamma和beta。有三个地方用到了layer normalization,分别是embedding层后、multi-head attention后、feed forward后,这三部分的参数为768*2+12*(768*2+768*2)=38400

NSP和MLM

参数量较小,忽略。

 

参数合计

Embedding: (30522 + 2 + 512)* 768 = 23835648

Multi-head Attention:(768*64 * 12 * 3 + 768 * 768)* 12 = 28311552

FeedForward:12 * (2 * 768 * 3072) = 56623104

LayerNorm:768 * 2 + (768 * 2)*12 + (768 * 2)*12 = 38400

 

Total = 23835648 + 28311552 + 56623104 + 38400 = 108808704

 

参考文档

  1. BERT参数量如何计算 - 知乎 (zhihu.com)

  2. 【NLP】BERT原理 - 知乎 (zhihu.com)