Network Structure

The model stacks character and pinyin embedding layers, a stacked bidirectional LSTM, a self-attention layer over the BiLSTM outputs, and a fully connected classification head.
Code Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextBILSTM(nn.Module):

    def __init__(self,
                 config: 'TRNNConfig',   # hyper-parameter container defined elsewhere in the project
                 char_size=5000,
                 pinyin_size=5000):
        super(TextBILSTM, self).__init__()
        self.num_classes = config.num_classes
        self.learning_rate = config.learning_rate
        self.keep_dropout = config.keep_dropout
        self.char_embedding_size = config.char_embedding_size
        self.pinyin_embedding_size = config.pinyin_embedding_size
        self.l2_reg_lambda = config.l2_reg_lambda
        self.hidden_dims = config.hidden_dims
        self.char_size = char_size
        self.pinyin_size = pinyin_size
        self.rnn_layers = config.rnn_layers

        self.build_model()

    def build_model(self):
        # initialize character embeddings
        self.char_embeddings = nn.Embedding(self.char_size, self.char_embedding_size)
        # character embeddings are updated during training
        self.char_embeddings.weight.requires_grad = True

        # initialize pinyin embeddings
        self.pinyin_embeddings = nn.Embedding(self.pinyin_size, self.pinyin_embedding_size)
        self.pinyin_embeddings.weight.requires_grad = True

        # attention layer
        self.attention_layer = nn.Sequential(
            nn.Linear(self.hidden_dims, self.hidden_dims),
            nn.ReLU(inplace=True)
        )
        # self.attention_weights = self.attention_weights.view(self.hidden_dims, 1)

        # stacked bidirectional LSTM (rnn_layers layers)
        # note: despite the name, keep_dropout is used as the drop probability,
        # applied between stacked LSTM layers
        self.lstm_net = nn.LSTM(self.char_embedding_size, self.hidden_dims,
                                num_layers=self.rnn_layers, dropout=self.keep_dropout,
                                bidirectional=True)

        # fully connected output layers
        # self.fc_out = nn.Linear(self.hidden_dims, self.num_classes)
        self.fc_out = nn.Sequential(
            nn.Dropout(self.keep_dropout),
            nn.Linear(self.hidden_dims, self.hidden_dims),
            nn.ReLU(inplace=True),
            nn.Dropout(self.keep_dropout),
            nn.Linear(self.hidden_dims, self.num_classes)
        )

    def attention_net_with_w(self, lstm_out, lstm_hidden):
        '''
        :param lstm_out:    [batch_size, len_seq, n_hidden * 2]
        :param lstm_hidden: [batch_size, num_layers * num_directions, n_hidden]
        :return:            [batch_size, n_hidden]
        '''
        # split the forward and backward halves of the BiLSTM output
        lstm_tmp_out = torch.chunk(lstm_out, 2, -1)
        # h [batch_size, time_step, hidden_dims]
        h = lstm_tmp_out[0] + lstm_tmp_out[1]
        # sum over layers and directions: [batch_size, n_hidden]
        lstm_hidden = torch.sum(lstm_hidden, dim=1)
        # [batch_size, 1, n_hidden]
        lstm_hidden = lstm_hidden.unsqueeze(1)
        # atten_w [batch_size, 1, hidden_dims]
        atten_w = self.attention_layer(lstm_hidden)
        # m [batch_size, time_step, hidden_dims]
        m = nn.Tanh()(h)
        # atten_context [batch_size, 1, time_step]
        atten_context = torch.bmm(atten_w, m.transpose(1, 2))
        # softmax_w [batch_size, 1, time_step]
        softmax_w = F.softmax(atten_context, dim=-1)
        # context [batch_size, 1, hidden_dims]
        context = torch.bmm(softmax_w, h)
        result = context.squeeze(1)
        return result

    def forward(self, char_id, pinyin_id):
        # char_id = torch.from_numpy(np.array(input[0])).long()
        # pinyin_id = torch.from_numpy(np.array(input[1])).long()
        sen_char_input = self.char_embeddings(char_id)
        sen_pinyin_input = self.pinyin_embeddings(pinyin_id)
        # concatenate the char and pinyin sequences along the time dimension
        sen_input = torch.cat((sen_char_input, sen_pinyin_input), dim=1)
        # input : [len_seq, batch_size, embedding_dim]
        sen_input = sen_input.permute(1, 0, 2)
        output, (final_hidden_state, final_cell_state) = self.lstm_net(sen_input)
        # output : [batch_size, len_seq, n_hidden * 2]
        output = output.permute(1, 0, 2)
        # final_hidden_state : [batch_size, num_layers * num_directions, n_hidden]
        final_hidden_state = final_hidden_state.permute(1, 0, 2)
        # final_hidden_state = torch.mean(final_hidden_state, dim=0, keepdim=True)
        # atten_out = self.attention_net(output, final_hidden_state)
        atten_out = self.attention_net_with_w(output, final_hidden_state)
        return self.fc_out(atten_out)
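TRNNConfig is referenced above but not shown in the post. The sketch below is a minimal stand-in with only the fields TextBILSTM reads, followed by a smoke-test forward pass on random ids; the field values are illustrative assumptions, not the author's original hyperparameters. Note that char_embedding_size and pinyin_embedding_size must be equal, because the two embedded sequences are concatenated along the time dimension and fed to a single LSTM.

import torch

# Minimal stand-in for the author's TRNNConfig; all values are illustrative.
class TRNNConfig:
    num_classes = 2
    learning_rate = 1e-3
    keep_dropout = 0.1            # drop probability for nn.LSTM and nn.Dropout
    char_embedding_size = 64      # must equal pinyin_embedding_size, since the two
    pinyin_embedding_size = 64    # embedded sequences are concatenated along dim=1
    l2_reg_lambda = 0.0
    hidden_dims = 128
    rnn_layers = 2

model = TextBILSTM(TRNNConfig(), char_size=5000, pinyin_size=5000)

# Dummy batch: 4 sentences, 50 character ids and 50 pinyin ids each.
char_id = torch.randint(0, 5000, (4, 50))
pinyin_id = torch.randint(0, 5000, (4, 50))

logits = model(char_id, pinyin_id)
print(logits.shape)  # torch.Size([4, 2])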
Attention Computation
- Split the BiLSTM output (shape: [batch_size, time_step, hidden_dims * num_directions(=2)]) into two tensors of shape [batch_size, time_step, hidden_dims];
- Add the two tensors element-wise to obtain h (shape: [batch_size, time_step, hidden_dims]);
- Sum the BiLSTM final hidden state (shape: [batch_size, num_layers * num_directions, hidden_dims]) over its second dimension to obtain a new lstm_hidden (shape: [batch_size, hidden_dims]);
- Unsqueeze lstm_hidden from [batch_size, hidden_dims] to [batch_size, 1, hidden_dims];
- Pass lstm_hidden through self.attention_layer to obtain the vector used for the weight computation, atten_w (shape: [batch_size, 1, hidden_dims]);
- Apply tanh to h to obtain m (shape: [batch_size, time_step, hidden_dims]);
- Compute torch.bmm(atten_w, m.transpose(1, 2)) to obtain atten_context (shape: [batch_size, 1, time_step]);
- Normalize atten_context with F.softmax(atten_context, dim=-1) to obtain the context-based weights softmax_w (shape: [batch_size, 1, time_step]);
- Compute torch.bmm(softmax_w, h) to obtain the attention-weighted BiLSTM output context (shape: [batch_size, 1, hidden_dims]);
- Squeeze out the second dimension of context to obtain result (shape: [batch_size, hidden_dims]);
- Return result (a standalone shape trace of these steps follows below).
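To make the shape bookkeeping above concrete, here is a standalone trace of the same steps on dummy tensors (batch_size=2, time_step=5, hidden_dims=8, rnn_layers=2). The Linear+ReLU layer is an untrained stand-in for self.attention_layer, used only to show shapes.

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, time_step, hidden_dims, num_layers = 2, 5, 8, 2

lstm_out = torch.randn(batch_size, time_step, hidden_dims * 2)        # BiLSTM outputs
lstm_hidden = torch.randn(batch_size, num_layers * 2, hidden_dims)    # final hidden states
attention_layer = nn.Sequential(nn.Linear(hidden_dims, hidden_dims),  # untrained stand-in
                                nn.ReLU())

fwd, bwd = torch.chunk(lstm_out, 2, -1)
h = fwd + bwd                                             # [2, 5, 8]
lstm_hidden = torch.sum(lstm_hidden, dim=1).unsqueeze(1)  # [2, 1, 8]
atten_w = attention_layer(lstm_hidden)                    # [2, 1, 8]
m = torch.tanh(h)                                         # [2, 5, 8]
atten_context = torch.bmm(atten_w, m.transpose(1, 2))     # [2, 1, 5]
softmax_w = F.softmax(atten_context, dim=-1)              # [2, 1, 5], each row sums to 1
context = torch.bmm(softmax_w, h)                         # [2, 1, 8]
result = context.squeeze(1)                               # [2, 8]
print(result.shape)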
Model Performance
- 1-layer BiLSTM: 99.8% accuracy on the training set, 96.5% on the test set;
- 2-layer BiLSTM: 99.9% accuracy on the training set, 97.3% on the test set.
Hyperparameter Tuning
- Keep the dropout value below 0.1 (a rule of thumb: in practice the author found that a dropout of 0.1 gave about 0.5% higher test-set accuracy than a dropout of 0.3). See the inspection snippet below for where this value ends up in the model.
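For context, the same config.keep_dropout value is passed both to nn.LSTM, which applies dropout only between stacked layers (so it is a no-op when rnn_layers=1), and to the two nn.Dropout layers inside fc_out. A quick inspection, assuming the model instance from the earlier TRNNConfig sketch:

import torch.nn as nn

# `model` is the TextBILSTM instance built in the TRNNConfig sketch above.
print(model.lstm_net.dropout)                                      # inter-layer LSTM dropout
print([m.p for m in model.fc_out if isinstance(m, nn.Dropout)])    # the two classifier dropouts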
Copyright notice: This is an original article by dendi_hust, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.