How LSTM Solves the Vanishing and Exploding Gradient Problems of RNN



Because of its recurrent network structure (shown in Figure 1), the RNN (Recurrent Neural Network) has a distinct advantage in sequence modeling tasks and is therefore widely used in many fields, such as natural language processing and speech recognition.



1. BPTT in RNN

Figure 1: RNN network structure

From the RNN's network structure we can write down its basic equations, where δ denotes the activation function (e.g., sigmoid):





$$S_{t} = \delta(WS_{t-1} + UX_{t}) \qquad (1)$$

$$O_{t} = \delta(VS_{t}) \qquad (2)$$
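To make the recurrence concrete, here is a minimal NumPy sketch of Eqs. (1) and (2). The dimensions, random weights, and the choice of sigmoid for δ are illustrative assumptions, not something the post specifies:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 3))   # recurrent weights (hidden -> hidden), assumed sizes
U = rng.normal(size=(3, 2))   # input weights     (input  -> hidden)
V = rng.normal(size=(2, 3))   # output weights    (hidden -> output)

def rnn_forward(X):
    """X: input sequence of shape (T, 2); returns all states and outputs."""
    S = np.zeros(3)                     # initial state S_0
    states, outputs = [], []
    for x_t in X:
        S = sigmoid(W @ S + U @ x_t)    # Eq. (1): S_t = delta(W S_{t-1} + U X_t)
        outputs.append(sigmoid(V @ S))  # Eq. (2): O_t = delta(V S_t)
        states.append(S)
    return np.array(states), np.array(outputs)

states, outputs = rnn_forward(rng.normal(size=(5, 2)))
print(states.shape, outputs.shape)  # (5, 3) (5, 2)
```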







Assume cross-entropy as the loss function:





$$L = -\sum_{t=1}^{n} O_{t} \log \hat{O}_{t} \qquad (3)$$







We then take the partial derivatives of the loss with respect to W, U, and V.

Start with V, whose derivative is the simplest:





$$\frac{\partial L}{\partial V} = \frac{\partial L}{\partial O_{t}} \cdot \frac{\partial O_{t}}{\partial V} \qquad (4)$$
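As a hypothetical sanity check of the chain rule in Eq. (4), the sketch below compares the analytic gradient with a finite-difference estimate for a single time step. It uses a squared-error loss on a one-dimensional output for simplicity (an assumption; the post itself uses cross-entropy):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
S_t = rng.normal(size=3)      # state at time t, held fixed here
V = rng.normal(size=(1, 3))
target = np.array([0.7])

def loss(V):
    O_t = sigmoid(V @ S_t)                 # Eq. (2)
    return 0.5 * np.sum((O_t - target) ** 2)

# Chain rule, Eq. (4): dL/dV = (dL/dO_t) * (dO_t/dV)
O_t = sigmoid(V @ S_t)
dL_dO = O_t - target                       # dL/dO_t for squared error
dO_dV = (O_t * (1 - O_t))[:, None] * S_t   # dO_t/dV = sigma'(z) * S_t, shape (1, 3)
grad_chain = dL_dO[:, None] * dO_dV

# Central finite differences over each entry of V
eps = 1e-6
grad_fd = np.zeros_like(V)
for i in range(V.shape[0]):
    for j in range(V.shape[1]):
        Vp = V.copy(); Vp[i, j] += eps
        Vm = V.copy(); Vm[i, j] -= eps
        grad_fd[i, j] = (loss(Vp) - loss(Vm)) / (2 * eps)

print(np.allclose(grad_chain, grad_fd, atol=1e-8))  # True
```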







Next, differentiate with respect to W and U.

From Eq. (1), the state at the current time step depends not only on the current input but also on the state at the previous time step.

Applying the chain rule with respect to W and U:





$$\begin{aligned}
\frac{\partial L}{\partial W} &= \frac{\partial L}{\partial O_{t}} \cdot \frac{\partial O_{t}}{\partial S_{t}} \cdot \frac{\partial S_{t}}{\partial S_{t-1}} \cdot \frac{\partial S_{t-1}}{\partial S_{t-2}} \cdots \frac{\partial S_{1}}{\partial S_{0}} \cdot \frac{\partial S_{0}}{\partial W} \\
&= \frac{\partial L}{\partial O_{t}} \cdot \frac{\partial O_{t}}{\partial S_{t}} \cdot \prod_{k=1}^{t} \frac{\partial S_{k}}{\partial S_{k-1}} \cdot \frac{\partial S_{k-1}}{\partial W} \qquad (5)
\end{aligned}$$
























Similarly, for U:





$$\begin{aligned}
\frac{\partial L}{\partial U} &= \frac{\partial L}{\partial O_{t}} \cdot \frac{\partial O_{t}}{\partial S_{t}} \cdot \frac{\partial S_{t}}{\partial S_{t-1}} \cdot \frac{\partial S_{t-1}}{\partial S_{t-2}} \cdots \frac{\partial S_{1}}{\partial S_{0}} \cdot \frac{\partial S_{0}}{\partial U} \\
&= \frac{\partial L}{\partial O_{t}} \cdot \frac{\partial O_{t}}{\partial S_{t}} \cdot \prod_{k=1}^{t} \frac{\partial S_{k}}{\partial S_{k-1}} \cdot \frac{\partial S_{k-1}}{\partial U} \qquad (6)
\end{aligned}$$
























2. Vanishing and Exploding Gradients in RNN

From Eq. (1), we have





$$\frac{\partial S_{t}}{\partial S_{t-1}} = W \cdot \sigma' \qquad (7)$$







Figure 2: The sigmoid function

When the factor in Eq. (7) is less than 1, the repeated product in Eqs. (5) and (6) approaches 0 as the sequence length grows, i.e., the gradient vanishes.

When the factor in Eq. (7) is greater than 1, the repeated product in Eqs. (5) and (6) grows without bound, i.e., the gradient explodes. Note that for the sigmoid, σ′(z) = σ(z)(1 − σ(z)) ≤ 0.25 (see Figure 2), so unless the weights are large the factor in Eq. (7) is typically well below 1, which is why vanishing gradients are the more common failure mode.
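To see this numerically, here is a small scalar illustration of the repeated factor from Eq. (7). The scalar weight w and the randomly drawn pre-activations are assumptions made purely for clarity:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Per-step factor from Eq. (7): w * sigma'(z_t).
# Since sigma'(z) = sigma(z)(1 - sigma(z)) <= 0.25, a modest w drives the
# 50-step product toward 0, while a large w drives it toward infinity.
rng = np.random.default_rng(0)
z = rng.normal(size=50)                  # pre-activations at each step
dsig = sigmoid(z) * (1.0 - sigmoid(z))   # sigma'(z_t), each at most 0.25

for w in (1.0, 10.0):
    print(w, np.prod(w * dsig))
# w = 1.0  -> vanishingly small product (vanishing gradient)
# w = 10.0 -> astronomically large product (exploding gradient)
```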



3. How LSTM Addresses the RNN Gradient Problem

(Figure: LSTM cell diagram; image from http://colah.github.io/posts/2015-08-Understanding-LSTMs/)

The equation numbers below follow the labels in the figures of that post.





$$\begin{aligned}
i_{t} &= \sigma(W_{i}[h_{t-1}; x_{t}] + b_{i}) \qquad (8) \\
f_{t} &= \sigma(W_{f}[h_{t-1}; x_{t}] + b_{f}) \qquad (9) \\
\tilde{C}_{t} &= \tanh(W_{c}[h_{t-1}; x_{t}] + b_{c}) \qquad (10) \\
C_{t} &= i_{t} * \tilde{C}_{t} + f_{t} * C_{t-1} \qquad (11) \\
o_{t} &= \sigma(W_{o}[h_{t-1}; x_{t}] + b_{o}) \qquad (12) \\
h_{t} &= o_{t} * \tanh(C_{t}) \qquad (13)
\end{aligned}$$
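Below is a minimal NumPy sketch of one LSTM step implementing Eqs. (8)–(13). The dimensions, random weights, and zero biases are illustrative assumptions; [h_{t-1}; x_t] denotes concatenation and * is elementwise multiplication:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

H, X = 4, 3                   # hidden size, input size (assumed for illustration)
rng = np.random.default_rng(0)
Wi, Wf, Wc, Wo = (rng.normal(size=(H, H + X)) for _ in range(4))
bi, bf, bc, bo = (np.zeros(H) for _ in range(4))

def lstm_step(h_prev, C_prev, x_t):
    hx = np.concatenate([h_prev, x_t])    # [h_{t-1}; x_t]
    i = sigmoid(Wi @ hx + bi)             # Eq. (8)  input gate
    f = sigmoid(Wf @ hx + bf)             # Eq. (9)  forget gate
    C_tilde = np.tanh(Wc @ hx + bc)       # Eq. (10) candidate cell state
    C = i * C_tilde + f * C_prev          # Eq. (11) new cell state
    o = sigmoid(Wo @ hx + bo)             # Eq. (12) output gate
    h = o * np.tanh(C)                    # Eq. (13) new hidden state
    return h, C

h, C = lstm_step(np.zeros(H), np.zeros(H), rng.normal(size=X))
print(h.shape, C.shape)  # (4,) (4,)
```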
























By analogy with the repeated product of partial derivatives in the RNN, the corresponding factor in the LSTM's backward pass is





$$\frac{\partial C_{t}}{\partial C_{t-1}} = f_{t} = \sigma \qquad (14)$$







Comparing Eq. (7) with Eq. (14): in the LSTM the repeated factor becomes the forget-gate activation σ, i.e., f_t. Because f_t is the output of a learned gate, training can keep its value close to 1, so the gradient along the cell-state path does not vanish even after many multiplications; and since a sigmoid output never exceeds 1, that product cannot blow up, so the gradient does not explode either.
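As a hypothetical numerical contrast (the gate values below are chosen by hand to mimic a trained network, not learned), compare 100-step gradient products along the RNN recurrence of Eq. (7) and the LSTM cell-state recurrence of Eq. (14):

```python
import numpy as np

# Illustrative only: the RNN factor W * sigma' is fixed by the weights
# (and capped by sigma' <= 0.25), while the LSTM factor f_t is a gate
# whose value training can push toward 1.
rng = np.random.default_rng(0)
T = 100

rnn_factors = rng.uniform(0.0, 0.25, size=T)   # w = 1 times sigma'(z) <= 0.25
lstm_factors = rng.uniform(0.95, 1.0, size=T)  # forget gates assumed learned near 1

print(np.prod(rnn_factors))   # effectively 0: the gradient has vanished
print(np.prod(lstm_factors))  # order 0.1: the gradient survives 100 steps
```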



Copyright notice: this is an original article by tianyunzqs, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.