Thanks to its recurrent network structure (see Figure 1), an RNN (Recurrent Neural Network) has a natural advantage for sequence-modeling tasks and is therefore widely used in many fields, such as natural language processing and speech recognition.
1. BPTT for RNNs
From the RNN's network structure, its basic equations can be written as:
$$S_{t} = \sigma(W S_{t-1} + U X_{t}) \qquad (1)$$
$$O_{t} = \sigma(V S_{t}) \qquad (2)$$

where $S_{t}$ is the hidden state at time step $t$, $X_{t}$ is the input, $O_{t}$ is the output, and $\sigma$ is the activation function.
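As a concrete illustration of equations (1) and (2), here is a minimal NumPy sketch of the forward pass; the function names, shapes, and the choice of the sigmoid as the activation σ are assumptions made for the example, not part of the original derivation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def rnn_forward(X, W, U, V, s0):
    """Forward pass of eqs. (1) and (2) over a whole sequence.

    X  : (T, input_dim)      input sequence X_1..X_T
    W  : (hidden, hidden)    recurrent weights
    U  : (hidden, input_dim) input weights
    V  : (out_dim, hidden)   output weights
    s0 : (hidden,)           initial state S_0
    """
    S, O = [s0], []
    for t in range(X.shape[0]):
        s_t = sigmoid(W @ S[-1] + U @ X[t])  # eq. (1)
        o_t = sigmoid(V @ s_t)               # eq. (2)
        S.append(s_t)
        O.append(o_t)
    return np.array(S[1:]), np.array(O)
```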
Assume the cross-entropy is used as the loss function:
$$L = -\sum_{t=1}^{n} O_{t}\log \hat{O}_{t} \qquad (3)$$
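A tiny sketch of equation (3), treating $O_t$ as the target and $\hat{O}_t$ as the prediction at step $t$; the small epsilon is added only for numerical safety and is not part of the formula.

```python
import numpy as np

def cross_entropy(O, O_hat, eps=1e-12):
    """Eq. (3): L = -sum over t of O_t * log(O_hat_t)."""
    return -np.sum(O * np.log(O_hat + eps))
```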
Next, take the partial derivatives of L with respect to W, U, and V.
Start with the partial derivative with respect to V, since it is the simplest:
$$\frac{\partial L}{\partial V}=\frac{\partial L}{\partial O_{t}}\cdot \frac{\partial O_{t}}{\partial V} \qquad (4)$$
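One practical way to sanity-check a hand-derived gradient such as (4) is a central finite-difference estimate. The sketch below is self-contained and assumes the sigmoid output of equation (2) and the loss of equation (3); all names and shapes are illustrative.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def loss_given_V(V, S, targets, eps=1e-12):
    """Loss of eq. (3) with predictions O_hat_t = sigmoid(V S_t) from eq. (2)."""
    O_hat = sigmoid(S @ V.T)                 # (T, out_dim)
    return -np.sum(targets * np.log(O_hat + eps))

def numeric_grad_V(V, S, targets, h=1e-6):
    """Central finite-difference estimate of dL/dV, entry by entry (cf. eq. (4))."""
    g = np.zeros_like(V)
    for idx in np.ndindex(*V.shape):
        Vp, Vm = V.copy(), V.copy()
        Vp[idx] += h
        Vm[idx] -= h
        g[idx] = (loss_given_V(Vp, S, targets) - loss_given_V(Vm, S, targets)) / (2 * h)
    return g
```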
Now take the partial derivatives with respect to W and U.
From equation (1), the state at the current time step depends not only on the current input but also on the state at the previous time step.
Applying the chain rule to differentiate with respect to W (and likewise U) gives:
$$
\begin{aligned}
\frac{\partial L}{\partial W}&=\sum_{k=1}^{t}\frac{\partial L}{\partial O_{t}}\cdot \frac{\partial O_{t}}{\partial S_{t}}\cdot \frac{\partial S_{t}}{\partial S_{t-1}}\cdot \frac{\partial S_{t-1}}{\partial S_{t-2}}\cdots \frac{\partial S_{k+1}}{\partial S_{k}}\cdot \frac{\partial S_{k}}{\partial W}\\
&=\sum_{k=1}^{t}\frac{\partial L}{\partial O_{t}}\cdot \frac{\partial O_{t}}{\partial S_{t}}\cdot \left(\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}\right)\cdot \frac{\partial S_{k}}{\partial W} \qquad (5)
\end{aligned}
$$
Similarly, for U:
$$
\begin{aligned}
\frac{\partial L}{\partial U}&=\sum_{k=1}^{t}\frac{\partial L}{\partial O_{t}}\cdot \frac{\partial O_{t}}{\partial S_{t}}\cdot \frac{\partial S_{t}}{\partial S_{t-1}}\cdot \frac{\partial S_{t-1}}{\partial S_{t-2}}\cdots \frac{\partial S_{k+1}}{\partial S_{k}}\cdot \frac{\partial S_{k}}{\partial U}\\
&=\sum_{k=1}^{t}\frac{\partial L}{\partial O_{t}}\cdot \frac{\partial O_{t}}{\partial S_{t}}\cdot \left(\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}\right)\cdot \frac{\partial S_{k}}{\partial U} \qquad (6)
\end{aligned}
$$
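In practice, the sums in (5) and (6) are evaluated with a single backward recursion rather than by forming each Jacobian product explicitly. The following NumPy sketch assumes the loss gradient with respect to the final state, g_t = (∂L/∂O_t)(∂O_t/∂S_t), is already available and uses the sigmoid activation from (1); all names and shapes are illustrative.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bptt_dW_dU(X, S, A, W, g_t):
    """Backward recursion implementing the sums in eqs. (5) and (6).

    X   : (T, input_dim)   inputs X_1..X_T
    S   : (T+1, hidden)    states S_0..S_T from the forward pass
    A   : (T, hidden)      pre-activations a_k = W S_{k-1} + U X_k
    W   : (hidden, hidden) recurrent weights
    g_t : (hidden,)        dL/dS_T, i.e. (dL/dO_t)(dO_t/dS_t)
    """
    dW = np.zeros_like(W)
    dU = np.zeros((W.shape[0], X.shape[1]))
    delta = g_t                      # gradient w.r.t. S_k, starting at k = T
    for k in range(X.shape[0] - 1, -1, -1):
        s = sigmoid(A[k])
        da = delta * s * (1 - s)     # through sigma'(a_k)
        dW += np.outer(da, S[k])     # immediate dS_k/dW term (S[k] is the previous state)
        dU += np.outer(da, X[k])     # immediate dS_k/dU term
        delta = W.T @ da             # one factor dS_k/dS_{k-1}: push gradient back one step
    return dW, dU
```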
2. Vanishing and exploding gradients in RNNs
From equation (1), with σ' denoting the derivative of the activation function, we have
$$\frac{\partial S_{t}}{\partial S_{t-1}}=W\cdot \sigma' \qquad (7)$$
(Figure: the sigmoid function.)
When the factors in equation (7) are smaller than 1, their repeated product drives the long-range terms of equations (5) and (6) toward 0, i.e., the gradient vanishes;
When the factors in equation (7) are larger than 1, the repeated product blows up and equations (5) and (6) tend toward infinity, i.e., the gradient explodes.
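To see this numerically, the small sketch below repeatedly multiplies factors of the form diag(σ'(a))·W, as in equation (7), and prints the norm of the product; the sizes, weight scales, and random pre-activations are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 16

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def jacobian_product_norm(weight_scale, steps=50):
    """Norm of the product of factors dS_j/dS_{j-1} = diag(sigma'(a_j)) @ W (cf. eq. (7))."""
    W = rng.normal(scale=weight_scale, size=(hidden, hidden))
    P = np.eye(hidden)
    for _ in range(steps):
        a = rng.normal(size=hidden)          # stand-in pre-activation a_j
        s = sigmoid(a)
        P = np.diag(s * (1 - s)) @ W @ P     # multiply in one more Jacobian factor
    return np.linalg.norm(P)

print(jacobian_product_norm(0.1))   # small recurrent weights: the product collapses toward 0
print(jacobian_product_norm(5.0))   # large recurrent weights: the product blows up
```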
3. How LSTM solves the RNN gradient problem
PS: The figures are taken from http://colah.github.io/posts/2015-08-Understanding-LSTMs/; the equation numbers below follow the labels used in the figures at that link.
$$
\begin{aligned}
i_{t}&=\sigma (W_{i}[h_{t-1}; x_{t}]+b_{i}) \qquad (8) \\
f_{t}&=\sigma (W_{f}[h_{t-1}; x_{t}]+b_{f}) \qquad (9) \\
\tilde{C}_{t}&=\tanh (W_{c}[h_{t-1}; x_{t}]+b_{c}) \qquad (10) \\
C_{t}&=i_{t}*\tilde{C}_{t}+f_{t}*C_{t-1} \qquad (11) \\
o_{t}&=\sigma (W_{o}[h_{t-1}; x_{t}]+b_{o}) \qquad (12) \\
h_{t}&=o_{t}*\tanh(C_{t}) \qquad (13)
\end{aligned}
$$
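A minimal NumPy sketch of one LSTM step following equations (8)-(13); the concatenation [h_{t-1}; x_t], the weight shapes, and the function names are spelled out as assumptions for the example.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, C_prev, Wi, Wf, Wc, Wo, bi, bf, bc, bo):
    """One LSTM step, eqs. (8)-(13). Each W_* has shape (hidden, hidden + input_dim)."""
    z = np.concatenate([h_prev, x_t])      # [h_{t-1}; x_t]
    i_t = sigmoid(Wi @ z + bi)             # input gate,  eq. (8)
    f_t = sigmoid(Wf @ z + bf)             # forget gate, eq. (9)
    C_tilde = np.tanh(Wc @ z + bc)         # candidate cell state, eq. (10)
    C_t = i_t * C_tilde + f_t * C_prev     # new cell state, eq. (11)
    o_t = sigmoid(Wo @ z + bo)             # output gate, eq. (12)
    h_t = o_t * np.tanh(C_t)               # new hidden state, eq. (13)
    return h_t, C_t
```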
By analogy with the repeated product of partial derivatives in the RNN, the corresponding factor in the LSTM's product is
$$\frac{\partial C_{t}}{\partial C_{t-1}}=f_{t}=\sigma(\cdot) \qquad (14)$$
Comparing equations (7) and (14): in the LSTM the repeated factor becomes the forget-gate value σ(·). During training, the network can drive this value close to 1, so even after many multiplications the gradient does not vanish; and since a sigmoid output never exceeds 1, this product does not explode either.
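As a rough numerical check of this argument, the scalar sketch below contrasts a product of RNN-style factors from equation (7) with a product of forget-gate factors from equation (14); the specific values (0.9, 0.25, 0.99) are synthetic stand-ins chosen only to illustrate the contrast.

```python
steps = 100

# RNN-style factor from eq. (7): W * sigma'(a). The sigmoid's derivative is at most 0.25,
# so with a moderate recurrent weight the factor sits well below 1.
rnn_factor = 0.9 * 0.25
print("RNN product over 100 steps: ", rnn_factor ** steps)    # essentially 0 (vanishes)

# LSTM-style factor from eq. (14): a forget gate the network can keep close to 1.
forget_gate = 0.99
print("LSTM product over 100 steps:", forget_gate ** steps)   # about 0.37 (gradient survives)
```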