LSTM
2021-12-26

 

LSTM简介

LSTM_arch

 

LSTM单元结构

LSTM单元结构

  1. 输入

    xt:输入数据

    ht1:hidden state

    ct1:cell state

  2. 候选状态

    c~:从输入数据和hidden state中提取的信息

    c~t=tanh(Wcxt+Ucht1+bc)

  3. 遗忘门:对ct1,遗忘其多少信息

    输入门:对c~,保存其多少信息

    输出门:对ct,输出其多少信息到ht

    三个门的计算方式为:

    (1)it=σ(Wixt+Uiht1+bi)ft=σ(Wfxt+Ufht1+bf)ot=σ(Woxt+Uoht1+bo)σ()Logisticxtht1hiddenstate
  4. 输出

(2)ct=ftct1+itc~tht=ottanh(ct)

 

Pytorch LSTM

模型参数

  1. input_size: 输入数据的特征个数
  2. hidden_size: hidden state的个数
  3. num_layers: RNN网络的层数,以下默认为1
  4. batch_first: 如果为true,则输入数据的shape为(batch_size, seq_len, feature_num)。以下默认为true。

输入

  1. input

    batch_size * seq_len * feature_num

    X_shape

  2. h_0

    num_layers * batch_size * hidden_size

  3. c_0

    num_layers * batch_size * hidden_size

输出

  1. output

    batch_size * seq_len * hidden_size

  2. h_n:每个batch最后一个时刻的hidden state

    num_layers * batch_size * hidden_size

  3. c_n:每个batch最后一个时刻的cell state

    num_layers * batch_size * hidden_size

     

参考文档

  1. LSTM — PyTorch 1.10.1 documentation
  2. Understanding LSTM Networks -- colah's blog