深入理解LSTM:从结构到PyTorch实践

张开发
2026/4/18 22:02:36 15 分钟阅读

分享文章

深入理解LSTM:从结构到PyTorch实践
1. 引言为什么需要LSTM循环神经网络RNN因其天然的时序结构被广泛应用于自然语言处理、时间序列预测等任务。然而传统RNN在处理长序列时容易遭遇梯度消失或梯度爆炸问题导致模型难以捕捉远距离的语义依赖。例如在“我出生在法国……我会说法语”中“法语”依赖于远在前面的“法国”传统RNN往往难以建立这种长距离关联。为了解决这一问题Sepp Hochreiter和Jürgen Schmidhuber于1997年提出了长短期记忆网络Long Short-Term Memory, LSTM。LSTM通过精巧的门控机制和细胞状态选择性地记忆或遗忘信息从而有效缓解了长序列训练中的梯度消失问题。后来2000年左右Gers等人又引入了遗忘门进一步完善了LSTM结构。2. LSTM的核心思想LSTM与传统RNN最大的区别在于它引入了一条细胞状态Cell State的“传送带”信息可以在时间步上几乎无损地流动。同时LSTM使用三个门控单元遗忘门、输入门、输出门来控制信息的遗忘、写入和读出。细胞状态Ct负责长期记忆贯穿整个序列。隐状态ht负责短期记忆也是每个时间步的输出。门使用sigmoid函数输出0~1之间的值表示信息“通过”的比例0表示完全阻断1表示完全通过。3. LSTM内部结构详解含公式下图示意了单个LSTM单元的内部结构图中省略了偏置项但在实际实现中存在。3.1 遗忘门Forget Gate遗忘门决定上一时刻的细胞状态 Ct−1 中有多少信息需要被丢弃。它读取当前输入 xt和上一时刻隐状态 ht−1输出一个0~1的向量 ft​。σ 为sigmoid函数。[ht−1​,xt​] 表示将两个向量拼接。Wf​ 和 bf​ 为可学习参数。直观理解如果 ft中的某个分量接近0则对应的历史信息将被遗忘接近1则保留。3.2 输入门Input Gate输入门决定将多少新信息写入细胞状态。它由两部分组成门控部分iti决定哪些位置要更新。候选细胞状态C~t利用tanh层生成新的候选值向量。tanh 将输出值压缩到-1到1之间起到调节作用。3.3 细胞状态更新旧细胞状态 Ct−1经过遗忘门进行选择性遗忘再与输入门筛选后的候选状态相加得到新的细胞状态 Ct。表示逐元素相乘Hadamard积。意义这一步完美融合了“忘记过去不重要的”和“记住当前新的重要信息”。3.4 输出门Output Gate输出门决定当前时刻的隐状态 htht​同时也是该时刻的输出。它基于更新后的细胞状态 CtCt​并经过一个门控筛选。先用tanh将 Ct 的值缩放至-1~1再通过输出门 ot​ 决定哪些信息最终输出。总结LSTM通过上述四个步骤实现了对长序列信息的选择性存储和读取。其中遗忘门和输入门配合完成细胞状态的更新输出门控制隐状态的表达。4. PyTorch中的LSTM实现PyTorch提供了便捷的torch.nn.LSTM模块我们可以直接调用。4.1 参数说明nn.LSTM(input_size, hidden_size, num_layers1, biasTrue, batch_firstFalse, dropout0, bidirectionalFalse)input_size输入特征维度例如词向量的长度。hidden_size隐状态 htht​ 的维度。num_layersLSTM堆叠的层数大于1时为多层LSTM。batch_first若为True输入形状为(batch, seq_len, input_size)否则为(seq_len, batch, input_size)。注意bidirectional参数在这里应保持False本文不涉及Bi-LSTM。4.2 输入与输出形状输入input形状(seq_len, batch, input_size)h0可选初始隐状态形状(num_layers, batch, hidden_size)c0可选初始细胞状态形状(num_layers, batch, hidden_size)输出output所有时间步的隐状态形状(seq_len, batch, hidden_size)(hn, cn)最后一个时间步的隐状态和细胞状态形状均为(num_layers, batch, hidden_size)4.3 完整示例# 定义LSTM的参数含义: (input_size, hidden_size, num_layers) # 定义输入张量的参数含义: (sequence_length, batch_size, input_size) # 定义隐藏层初始张量和细胞初始状态张量的参数含义: # (num_layers * num_directions, batch_size, hidden_size) import torch.nn as nn import torch rnn nn.LSTM(5, 6, 2) input torch.randn(1, 3, 5) h0 torch.randn(2, 3, 6) c0 torch.randn(2, 3, 6) output, (hn, cn) rnn(input, (h0, c0)) output tensor([[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416], [ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548], [-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]], grad_fnStackBackward) hn tensor([[[ 0.4647, -0.2364, 0.0645, -0.3996, -0.0500, -0.0152], [ 0.3852, 0.0704, 0.2103, -0.2524, 0.0243, 0.0477], [ 0.2571, 0.0608, 0.2322, 0.1815, -0.0513, -0.0291]], [[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416], [ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548], [-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]], grad_fnStackBackward) cn tensor([[[ 0.8083, -0.5500, 0.1009, -0.5806, -0.0668, -0.1161], [ 0.7438, 0.0957, 0.5509, -0.7725, 0.0824, 0.0626], [ 0.3131, 0.0920, 0.8359, 0.9187, -0.4826, -0.0717]], [[ 0.1240, -0.0526, 0.3035, 0.1099, 0.5915, 0.0828], [ 0.0203, 0.8367, 0.9832, -0.4454, 0.3917, -0.1983], [-0.2976, 0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]], grad_fnStackBackward)在实际任务中如情感分析我们通常取output[:, -1, :]作为最后一个时间步的隐状态再接入全连接层进行分类。5. LSTM的优缺点✅ 优势长距离依赖建模能力强相比传统RNNLSTM通过门控机制有效缓解了梯度消失/爆炸可以处理长达数百步的序列。灵活性高可以堆叠多层也可以与其他网络如CNN、Attention结合。工程成熟各种深度学习框架均有高效实现且有很多预训练变体。❌ 缺点计算复杂度高每个时间步需要计算4个全连接层遗忘门、输入门、输出门、候选状态参数量约为传统RNN的4倍训练和推理较慢。难以并行LSTM本质是递归结构后一个时间步依赖前一步的输出无法像Transformer那样进行大规模并行计算。并非万能在超长序列数千步上仍有信息衰减且对随机打乱的序列不敏感。6. 总结LSTM是RNN家族中最经典、最成功的变体之一。它通过遗忘门、输入门、输出门和细胞状态实现了对长期记忆的精细控制解决了原始RNN的梯度问题。虽然近年来Transformer等模型在多数NLP任务上取得了更好效果但LSTM在时间序列预测、语音识别、小规模序列建模等场景中依然具有重要价值。掌握LSTM的内部原理和PyTorch实现是深入理解序列模型的关键一步。参考文献Hochreiter, S., Schmidhuber, J. (1997). Long short-term memory.Neural computation, 9(8), 1735-1780.Gers, F. A., Schmidhuber, J., Cummins, F. (2000). Learning to forget: Continual prediction with LSTM.Neural computation, 12(10), 2451-2471.希望本文能帮助你彻底搞懂LSTM如果有任何疑问欢迎在评论区留言讨论。

更多文章