LSTM2021-12-26
LSTM简介

LSTM单元结构

输入
:输入数据
:hidden state
:cell state
候选状态
:从输入数据和hidden state中提取的信息
门
遗忘门:对,遗忘其多少信息
输入门:对,保存其多少信息
输出门:对,输出其多少信息到
三个门的计算方式为:
输出
Pytorch LSTM
模型参数
- input_size: 输入数据的特征个数
- hidden_size: hidden state的个数
- num_layers: RNN网络的层数,以下默认为1。
- batch_first: 如果为true,则输入数据的shape为(batch_size, seq_len, feature_num)。以下默认为true。
输入
input
batch_size * seq_len * feature_num

num_layers * batch_size * hidden_size
num_layers * batch_size * hidden_size
输出
output
batch_size * seq_len * hidden_size
:每个batch最后一个时刻的hidden state
num_layers * batch_size * hidden_size
:每个batch最后一个时刻的cell state
num_layers * batch_size * hidden_size
参考文档
- LSTM — PyTorch 1.10.1 documentation
- Understanding LSTM Networks -- colah's blog