LSTM参数计算

  • Post author:
  • Post category:其他




LSTM结构

在这里插入图片描述



输入输出结构

在n时刻


LSTM的输入有三个:


1.当前时刻网络的输入值Xt;

2.上一时刻LSTM的输出值ht-1;

3.上一时刻的单元状态Ct-1。


LSTM的输出有两个:


1.当前时刻LSTM输出值ht;

2.当前时刻的单元状态Ct。



门结构

门有三类:忘记门,输入门,输出门。

1.忘记门:以一定的概率控制是否遗忘上一层的隐藏细胞状态;

在这里插入图片描述

2.输入门:负责处理当前序列位置的输入,更新细胞状态;

在这里插入图片描述

在这里插入图片描述

3.输出门:决定输出什么。

在这里插入图片描述



参数计算

假设输入为:

TIME_STEPS = 28		# 时间步
INPUT_SIZE = 28		# 每个时间步的特征长度m
CELL_SIZE = 100		# 隐藏神经元个数n
OUTPUT_SIZE = 10	# 输出长度
inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])

LSTM:

x = LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE), return_sequences=False)(inputs)

输出:

x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)

网络参数:

Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 28, 28)            0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 100)               51600     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1010      
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0         
=================================================================
Total params: 52,610
Trainable params: 52,610
Non-trainable params: 0

其中:

总参数 = LSTM层 + Dense层 = 51600 + 1010 = 52610

LSTM层:51600 = 4 x ((input_size+cell_size) x cell_size+cell_size) = 4x((28+100)x100+100)

Dense层:1010 = 100 x 10 + 10



版权声明:本文为bblingbbling原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。