[NLP in Practice] Text Classification with PyTorch: BiLSTM + Attention





Network Architecture

[Figure: BiLSTM+Attention network architecture]
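The architecture figure itself is not reproduced here. As a textual substitute, the forward pass implemented by the code below can be summarized as follows. The symbols are our own notation, not the author's: $\overrightarrow{h}_t$ and $\overleftarrow{h}_t$ are the top BiLSTM layer's forward and backward outputs at step $t$, $h_n^{(l,d)}$ is the final hidden state of layer $l$ in direction $d$, $L$ is `rnn_layers`, and $T$ is the length of the character sequence ($T_c$ steps) followed by the pinyin sequence ($T_p$ steps).

```latex
\begin{aligned}
E &= \big[\mathrm{Emb}_c(x^c_1),\dots,\mathrm{Emb}_c(x^c_{T_c}),\ \mathrm{Emb}_p(x^p_1),\dots,\mathrm{Emb}_p(x^p_{T_p})\big], \quad T = T_c + T_p \\
\big(\overrightarrow{h}_t, \overleftarrow{h}_t\big)_{t=1}^{T},\ \big\{h_n^{(l,d)}\big\} &= \mathrm{BiLSTM}(E) \\
h_t &= \overrightarrow{h}_t + \overleftarrow{h}_t \\
s &= \textstyle\sum_{l=1}^{L}\sum_{d\in\{\rightarrow,\leftarrow\}} h_n^{(l,d)} \\
q &= \mathrm{ReLU}(W_a s + b_a) \\
\alpha_t &= \frac{\exp\!\big(q^\top \tanh(h_t)\big)}{\sum_{k=1}^{T} \exp\!\big(q^\top \tanh(h_k)\big)} \\
c &= \textstyle\sum_{t=1}^{T} \alpha_t\, h_t \\
\hat{y} &= W_2\, \mathrm{Dropout}\big(\mathrm{ReLU}(W_1\, \mathrm{Dropout}(c) + b_1)\big) + b_2
\end{aligned}
```

Here $W_a, b_a$ are the parameters of `attention_layer`, $W_1, b_1, W_2, b_2$ those of the two `nn.Linear` layers in `fc_out`, and $\hat{y}$ are the class logits.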



Code Implementation

The original post does not show `TRNNConfig`; the dataclass at the top of the listing is a hypothetical stand-in whose fields are inferred from how the model uses them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class TRNNConfig:
    # Hypothetical stand-in: the author's TRNNConfig is not shown in the post.
    # Field names are inferred from how TextBILSTM reads them; no defaults are assumed.
    num_classes: int
    learning_rate: float
    keep_dropout: float          # used as a dropout probability, not a keep probability
    char_embedding_size: int
    pinyin_embedding_size: int   # must equal char_embedding_size (see forward)
    l2_reg_lambda: float
    hidden_dims: int
    rnn_layers: int


class TextBILSTM(nn.Module):

    def __init__(self,
                 config: TRNNConfig,
                 char_size=5000,
                 pinyin_size=5000):
        super(TextBILSTM, self).__init__()
        self.num_classes = config.num_classes
        self.learning_rate = config.learning_rate
        self.keep_dropout = config.keep_dropout
        self.char_embedding_size = config.char_embedding_size
        self.pinyin_embedding_size = config.pinyin_embedding_size
        self.l2_reg_lambda = config.l2_reg_lambda
        self.hidden_dims = config.hidden_dims
        self.char_size = char_size
        self.pinyin_size = pinyin_size
        self.rnn_layers = config.rnn_layers
        # learning_rate and l2_reg_lambda are stored but not used inside the module itself

        self.build_model()

    def build_model(self):
        # Character embeddings
        self.char_embeddings = nn.Embedding(self.char_size, self.char_embedding_size)
        # Character embeddings are trainable (already the default; kept explicit)
        self.char_embeddings.weight.requires_grad = True
        # Pinyin embeddings, also trainable
        self.pinyin_embeddings = nn.Embedding(self.pinyin_size, self.pinyin_embedding_size)
        self.pinyin_embeddings.weight.requires_grad = True
        # Attention layer: maps the summed final hidden state to a query vector
        self.attention_layer = nn.Sequential(
            nn.Linear(self.hidden_dims, self.hidden_dims),
            nn.ReLU(inplace=True)
        )

        # Stacked bidirectional LSTM; its dropout is applied only between layers
        self.lstm_net = nn.LSTM(self.char_embedding_size, self.hidden_dims,
                                num_layers=self.rnn_layers, dropout=self.keep_dropout,
                                bidirectional=True)
        # Fully connected classification head
        self.fc_out = nn.Sequential(
            nn.Dropout(self.keep_dropout),
            nn.Linear(self.hidden_dims, self.hidden_dims),
            nn.ReLU(inplace=True),
            nn.Dropout(self.keep_dropout),
            nn.Linear(self.hidden_dims, self.num_classes)
        )

    def attention_net_with_w(self, lstm_out, lstm_hidden):
        '''
        :param lstm_out:    [batch_size, len_seq, n_hidden * 2]
        :param lstm_hidden: [batch_size, num_layers * num_directions, n_hidden]
        :return: [batch_size, n_hidden]
        '''
        # Split the BiLSTM output into its forward and backward halves
        lstm_tmp_out = torch.chunk(lstm_out, 2, -1)
        # h: [batch_size, time_step, hidden_dims]
        h = lstm_tmp_out[0] + lstm_tmp_out[1]
        # lstm_hidden: [batch_size, n_hidden] (summed over layers and directions)
        lstm_hidden = torch.sum(lstm_hidden, dim=1)
        # lstm_hidden: [batch_size, 1, n_hidden]
        lstm_hidden = lstm_hidden.unsqueeze(1)
        # atten_w: [batch_size, 1, hidden_dims]
        atten_w = self.attention_layer(lstm_hidden)
        # m: [batch_size, time_step, hidden_dims]
        m = torch.tanh(h)
        # atten_context: [batch_size, 1, time_step]
        atten_context = torch.bmm(atten_w, m.transpose(1, 2))
        # softmax_w: [batch_size, 1, time_step]
        softmax_w = F.softmax(atten_context, dim=-1)
        # context: [batch_size, 1, hidden_dims]
        context = torch.bmm(softmax_w, h)
        result = context.squeeze(1)
        return result

    def forward(self, char_id, pinyin_id):
        # char_id, pinyin_id: LongTensors of shape [batch_size, seq_len]
        sen_char_input = self.char_embeddings(char_id)
        sen_pinyin_input = self.pinyin_embeddings(pinyin_id)

        # Concatenate along the sequence axis (dim=1): the LSTM reads the character
        # sequence followed by the pinyin sequence. This requires
        # char_embedding_size == pinyin_embedding_size.
        sen_input = torch.cat((sen_char_input, sen_pinyin_input), dim=1)
        # input: [len_seq, batch_size, embedding_dim] (nn.LSTM defaults to batch_first=False)
        sen_input = sen_input.permute(1, 0, 2)
        output, (final_hidden_state, final_cell_state) = self.lstm_net(sen_input)
        # output: [batch_size, len_seq, n_hidden * 2]
        output = output.permute(1, 0, 2)
        # final_hidden_state: [batch_size, num_layers * num_directions, n_hidden]
        final_hidden_state = final_hidden_state.permute(1, 0, 2)
        atten_out = self.attention_net_with_w(output, final_hidden_state)
        return self.fc_out(atten_out)
```
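Because `forward` concatenates the character and pinyin embeddings with `dim=1`, the two sequences are joined end to end along the time axis rather than side by side at each position. That is why the LSTM's input size is `char_embedding_size` alone and why the pinyin embedding width must match it. The shape trace below uses illustrative sizes that are assumptions, not the author's settings, and rebuilds only the embedding and LSTM layers so it runs on its own:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Illustrative sizes (assumed, not taken from the original post)
batch_size, seq_len, emb_dim, hidden_dims, rnn_layers = 4, 20, 48, 32, 2
char_size = pinyin_size = 5000

char_embeddings = nn.Embedding(char_size, emb_dim)
pinyin_embeddings = nn.Embedding(pinyin_size, emb_dim)  # same width as the char embeddings
lstm_net = nn.LSTM(emb_dim, hidden_dims, num_layers=rnn_layers,
                   dropout=0.1, bidirectional=True)

char_id = torch.randint(0, char_size, (batch_size, seq_len))
pinyin_id = torch.randint(0, pinyin_size, (batch_size, seq_len))

# Join along the sequence axis: [batch, 2 * seq_len, emb_dim]
sen_input = torch.cat((char_embeddings(char_id), pinyin_embeddings(pinyin_id)), dim=1)
# nn.LSTM expects [seq, batch, feature] when batch_first=False
output, (h_n, c_n) = lstm_net(sen_input.permute(1, 0, 2))

print(sen_input.shape)                # torch.Size([4, 40, 48])
print(output.permute(1, 0, 2).shape)  # torch.Size([4, 40, 64]): hidden_dims * 2 directions
print(h_n.permute(1, 0, 2).shape)     # torch.Size([4, 4, 32]): num_layers * 2 directions
```

If you would rather feed both features at every time step, concatenating along `dim=2` is the usual alternative, but then the LSTM's `input_size` must be `char_embedding_size + pinyin_embedding_size`, which makes it a different model from the one above.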
        



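Attention Computation

  1. Split the BiLSTM output (shape: [batch_size, time_step, hidden_dims * num_directions], with num_directions = 2) into two tensors of shape [batch_size, time_step, hidden_dims], one for the forward direction and one for the backward direction;
  2. Add the two tensors from step 1 element-wise to get `h` (shape: [batch_size, time_step, hidden_dims]);
  3. Sum the BiLSTM's final hidden states (shape: [batch_size, num_layers * num_directions, hidden_dims]) over the second dimension to get a new `lstm_hidden` (shape: [batch_size, hidden_dims]);
  4. Expand `lstm_hidden` from [batch_size, hidden_dims] to [batch_size, 1, hidden_dims];
  5. Apply `self.attention_layer(lstm_hidden)` (a Linear layer followed by ReLU) to get `atten_w` (shape: [batch_size, 1, hidden_dims]), the query vector used to compute the attention weights;
  6. Apply `tanh` to `h` to get `m` (shape: [batch_size, time_step, hidden_dims]);
  7. Compute `torch.bmm(atten_w, m.transpose(1, 2))` to get `atten_context` (shape: [batch_size, 1, time_step]), one score per time step;
  8. Normalize `atten_context` with `F.softmax(atten_context, dim=-1)` to get the attention weights `softmax_w` (shape: [batch_size, 1, time_step]);
  9. Compute `torch.bmm(softmax_w, h)` to get `context`, the attention-weighted sum of the BiLSTM outputs (shape: [batch_size, 1, hidden_dims]);
  10. Squeeze out the second dimension of `context` to get `result` (shape: [batch_size, hidden_dims]);
  11. Return `result`.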
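The eleven steps can be checked in isolation with random tensors. The sketch below is a free-function version of `attention_net_with_w`; the function name `attention_with_w`, the extra `softmax_w` return value, and the sizes are our own illustrative choices, and `nn.ReLU()` is used without `inplace=True`, which gives the same result.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def attention_with_w(lstm_out, lstm_hidden, attention_layer):
    """Hypothetical free-function version of attention_net_with_w.

    lstm_out:    [batch_size, time_step, hidden_dims * 2]
    lstm_hidden: [batch_size, num_layers * num_directions, hidden_dims]
    Returns (result, softmax_w): [batch_size, hidden_dims], [batch_size, 1, time_step]
    """
    fwd, bwd = torch.chunk(lstm_out, 2, dim=-1)             # step 1
    h = fwd + bwd                                           # step 2
    query = lstm_hidden.sum(dim=1).unsqueeze(1)             # steps 3-4
    atten_w = attention_layer(query)                        # step 5
    m = torch.tanh(h)                                       # step 6
    atten_context = torch.bmm(atten_w, m.transpose(1, 2))   # step 7
    softmax_w = F.softmax(atten_context, dim=-1)            # step 8
    context = torch.bmm(softmax_w, h)                       # step 9
    return context.squeeze(1), softmax_w                    # steps 10-11


torch.manual_seed(0)
batch_size, time_step, hidden_dims, num_layers = 4, 10, 16, 2  # illustrative sizes
attention_layer = nn.Sequential(nn.Linear(hidden_dims, hidden_dims), nn.ReLU())
lstm_out = torch.randn(batch_size, time_step, hidden_dims * 2)
lstm_hidden = torch.randn(batch_size, num_layers * 2, hidden_dims)

result, softmax_w = attention_with_w(lstm_out, lstm_hidden, attention_layer)
print(result.shape)            # torch.Size([4, 16])
print(softmax_w.shape)         # torch.Size([4, 1, 10])
print(softmax_w.sum(dim=-1))   # each row sums to 1 (up to rounding)
```

Passing the attention layer in as an argument keeps the function easy to test on random tensors; inside the model it is simply `self.attention_layer`. Returning `softmax_w` as well makes it easy to inspect which time steps the model attends to; the original method returns only `result`.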



Results

  • 1-layer BiLSTM: training accuracy 99.8%, test accuracy 96.5%
  • 2-layer BiLSTM: training accuracy 99.9%, test accuracy 97.3%



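Hyperparameter Tuning

  • Keep dropout at or below 0.1. This is an empirical observation: in the author's experiments, dropout = 0.1 gave about 0.5% higher test accuracy than dropout = 0.3.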
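Dropout enters this model in two places: the `dropout` argument of `nn.LSTM`, which PyTorch applies only between stacked recurrent layers, and the two `nn.Dropout` layers in `fc_out`. The sketch below (sizes are illustrative assumptions) shows that for a 1-layer BiLSTM the LSTM's dropout has no effect and PyTorch warns about it, so in the 1-layer experiment only the `fc_out` dropout is active. It also shows that `nn.Dropout` is the identity in eval mode, so the setting only influences training. It does not explain the 0.5% difference; it only shows where the setting acts.

```python
import warnings

import torch
import torch.nn as nn


def lstm_dropout_warned(num_layers):
    """Build an nn.LSTM with dropout=0.1 and report whether PyTorch warned about dropout."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        nn.LSTM(64, 32, num_layers=num_layers, dropout=0.1, bidirectional=True)
    return any("dropout" in str(w.message) for w in caught)


one_layer_warned = lstm_dropout_warned(1)   # dropout ignored: no layer to apply it between
two_layer_warned = lstm_dropout_warned(2)   # dropout applied between layer 1 and layer 2

# nn.Dropout (as used in fc_out) only acts in training mode
torch.manual_seed(0)
drop = nn.Dropout(0.1)
x = torch.ones(2, 8)
drop.eval()
eval_out = drop(x)    # identity in eval mode
drop.train()
train_out = drop(x)   # each entry is either zeroed or scaled by 1 / (1 - 0.1)

print("1-layer LSTM warned:", one_layer_warned)
print("2-layer LSTM warned:", two_layer_warned)
print("eval output unchanged:", torch.equal(eval_out, x))
```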



Related Articles

  • Text classification with a TextCNN model (link);
  • Text classification with a Transformer model (link)


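Copyright notice: This is an original article by dendi_hust, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.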