LSTM结构
输入输出结构
在n时刻
LSTM的输入有三个:
1.当前时刻网络的输入值Xt;
2.上一时刻LSTM的输出值ht-1;
3.上一时刻的单元状态Ct-1。
LSTM的输出有两个:
1.当前时刻LSTM输出值ht;
2.当前时刻的单元状态Ct。
门结构
门有三类:忘记门,输入门,输出门。
1.忘记门:以一定的概率控制是否遗忘上一层的隐藏细胞状态;
2.输入门:负责处理当前序列位置的输入,更新细胞状态;
3.输出门:决定输出什么。
参数计算
假设输入为:
TIME_STEPS = 28 # 时间步
INPUT_SIZE = 28 # 每个时间步的特征长度m
CELL_SIZE = 100 # 隐藏神经元个数n
OUTPUT_SIZE = 10 # 输出长度
inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])
LSTM:
x = LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE), return_sequences=False)(inputs)
输出:
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)
网络参数:
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 28, 28) 0
_________________________________________________________________
lstm_1 (LSTM) (None, 100) 51600
_________________________________________________________________
dense_1 (Dense) (None, 10) 1010
_________________________________________________________________
activation_1 (Activation) (None, 10) 0
=================================================================
Total params: 52,610
Trainable params: 52,610
Non-trainable params: 0
其中:
总参数 = LSTM层 + Dense层 = 51600 + 1010 = 52610
LSTM层:51600 = 4 x ((input_size+cell_size) x cell_size+cell_size) = 4x((28+100)x100+100)
Dense层:1010 = 100 x 10 + 10
版权声明:本文为bblingbbling原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。