1
RNN 是什么?
RNN 的全称是 Recurrent Neural Network(循环神经网络) 。它是一种专门用于处理 序列数据 的神经网络架构。
核心思想:让网络拥有“记忆”
传统的神经网络(如全连接网络、CNN)假设输入之间是相互独立的——它们一次处理一个输入,下一个输入和上一个输入没有关系。但现实中有大量数据是 前后依赖 的:
RNN 通过引入 循环连接 ,使得网络可以保留之前的信息,并将其应用于当前的计算。这个保留的信息叫做 隐藏状态 (Hidden State),可以理解为网络的“记忆”。
一个形象的比喻
想象你在阅读一本小说:
RNN 的工作方式正是如此:它有一个内部状态,每读一个输入,就更新一次状态,然后用这个状态来帮助理解当前输入,并产生输出。
2
RNN 的核心特点
特点 | 说明 | 意义 |
|---|---|---|
处理变长序列 | 可以接受任意长度的输入序列 | 非常灵活,适用于文本、语音、时间序列等 |
参数共享 | 在每个时间步,RNN 使用相同的权重矩阵 | 大大减少了参数量,也体现了“同一个处理机制适用于所有时间步”的思想 |
隐藏状态 | 维护一个状态向量,传递序列中的信息 | 实现了“记忆”功能,让网络能够利用历史信息 |
时间依赖性 | 当前输出不仅取决于当前输入,还依赖于过去的信息 | 能够建模序列中的时序依赖关系 |
3
RNN 的工作流程
1. 基本结构
一个简单的 RNN 单元可以用以下公式表示:

2. 按时间展开(Unfolding)
为了理解 RNN 的训练过程,通常会将 RNN 按时间步“展开”成一个很深的 前馈网络 。例如,对于一个长度为 3 的序列,展开后的结构如下:
y1 y2 y3
↑ ↑ ↑
h1 ← h2 ← h3
↑ ↑ ↑
x1 x2 x3注意:虽然展开后看起来很“深”,但所有时间步的权重 $W{hh},W{xh},W_{hy}$ 都是 共享 的,所以参数量不会随着序列长度增加。
3. 训练方法:BPTT(通过时间反向传播)
RNN 的训练使用一种称为 BPTT 的算法,它本质上就是在展开后的网络上应用标准的反向传播。大致步骤:
4. 三种常见的使用模式
根据输入和输出的形式,RNN 可以有多种应用模式:
模式 | 图示 | 典型应用 |
|---|---|---|
多对一 | 输入序列 → 一个输出 | 情感分类(将一句话映射为正面/负面) |
一对多 | 一个输入 → 输出序列 | 图像描述(输入一张图,输出一段描述) |
多对多(同步) | 输入序列 → 输出序列(长度相同) | 词性标注、视频帧分类 |
多对多(异步) | 输入序列 → 输出序列(长度可变) | 机器翻译(编码器-解码器架构) |
4
RNN 的变体:LSTM 和 GRU
简单的 RNN 在处理长序列时有一个严重的问题: 梯度消失或梯度爆炸 。由于反向传播穿过多个时间步,梯度会反复相乘,导致长距离的信息难以被学习。为了解决这个问题,研究者提出了两种著名的变体。
1. LSTM(长短期记忆网络)
LSTM 通过引入 门控机制 来控制信息的流动:
LSTM 还有一个 细胞状态 (Cell State),它像一条传送带,可以让信息在长序列中几乎无损地传递。这使得 LSTM 能够捕捉长达几百步的依赖关系。
2. GRU(门控循环单元)
GRU 是 LSTM 的简化版本,将遗忘门和输入门合并为一个 更新门 ,并将细胞状态和隐藏状态合并。它参数更少,训练更快,效果与 LSTM 相当,因此在很多任务中成为首选。
5
如何使用 RNN?(以 PyTorch 为例)
在现代深度学习框架中,使用 RNN 非常简单。以下是一个使用 PyTorch 构建简单 RNN 进行文本分类的示意代码:
import torch
import torch.nn as nn
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(SimpleRNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, x):
# 初始化隐藏状态
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
# 前向传播 RNN
out, _ = self.rnn(x, h0) # out: (batch_size, seq_length, hidden_size)
# 只取最后一个时间步的输出进行分类
out = out[:, -1, :]
out = self.fc(out)
return out6
RNN 的优缺点
优点
缺点
7
RNN 与 CNN、Transformer 的关系
模型 | 擅长处理 | 核心机制 | 并行 | 长距离依赖 |
|---|---|---|---|---|
CNN | 空间结构数据(图像) | 卷积、局部连接 | 强 | 需堆叠层数扩大感受野 |
RNN | 时序序列数据 | 循环、状态传递 | 弱 | 有梯度问题,LSTM 部分缓解 |
Transformer | 任意序列(文本、图像、视频) | 自注意力、位置编码 | 强 | 全局依赖直接建模 |
在实际应用中,RNN 逐渐被 Transformer 取代,但在以下场景中 RNN 仍有价值: