tensor二维矩阵计算相似度

  • Post author:
  • Post category:其他


注意:用 torch.cosine_similarity 按行计算两个矩阵的相似度时,两个矩阵的形状必须相同(或可广播到同一形状),否则会报错。

import torch
from transformers import BertConfig, BertModel, BertTokenizer


def bert_output(texts, name, *, bert_tokenizer=None, bert_model=None):
    """Encode a batch of sentences with BERT and return token-level hidden states.

    Each sentence is tokenized and converted to vocabulary ids. It is then
    right-padded to the length of the longest sentence in the batch, so that
    the whole batch is fed to the model as a single tensor.

    Args:
        texts: Non-empty list of sentence strings. No special tokens are added
            here, so "[CLS]" / "[SEP]" must already be part of each string if
            they are wanted.
        name: Unused. It is kept so existing callers keep working. The original
            only referenced it in commented-out code that appended the output
            to this file path.
        bert_tokenizer: Optional tokenizer. Defaults to the module-level
            ``tokenizer`` global.
        bert_model: Optional model. Defaults to the module-level ``model``
            global.

    Returns:
        ``output[0]`` of the model. For ``transformers.BertModel`` this is the
        last hidden state, one row per token position, including the padded
        positions of shorter sentences.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    if not texts:
        raise ValueError("texts must contain at least one sentence")

    # Only the branch that is chosen gets evaluated. The globals are therefore
    # looked up only when no explicit object was passed in.
    tok = tokenizer if bert_tokenizer is None else bert_tokenizer
    mdl = model if bert_model is None else bert_model

    token_ids = [tok.convert_tokens_to_ids(tok.tokenize(text)) for text in texts]
    max_len = max(len(ids) for ids in token_ids)  # longest sentence in the batch

    # Prefer the tokenizer's own padding id. Fall back to 0, the value the
    # original code hard-coded.
    pad_id = getattr(tok, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0

    input_ids, token_type_ids, attention_mask = [], [], []
    for ids in token_ids:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        token_type_ids.append([0] * max_len)  # single-segment input
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding

    # Pass the inputs by keyword. In `transformers`, BertModel.forward takes
    # (input_ids, attention_mask, token_type_ids, ...) positionally. The
    # original positional call (tokens, segments, masks) therefore fed the
    # all-zero segment ids in as the attention mask and the padding mask in as
    # segment ids.
    with torch.no_grad():  # inference only: skip building an autograd graph
        output = mdl(
            input_ids=torch.tensor(input_ids),
            token_type_ids=torch.tensor(token_type_ids),
            attention_mask=torch.tensor(attention_mask),
        )

    # Process-wide print setting, kept from the original.
    torch.set_printoptions(edgeitems=30)
    return output[0]

if __name__ == '__main__':
    # `tokenizer` and `model` must remain module-level names:
    # bert_output() reads them as globals.
    model_dir = './bert-base-uncased'
    tokenizer = BertTokenizer.from_pretrained(model_dir)
    model_config = BertConfig.from_pretrained(model_dir)
    model = BertModel.from_pretrained(model_dir, config=model_config)

    texts_atis = ["[CLS] i want to fly from baltimore to dallas round trip [SEP]"]
    texts_snips = ["[CLS] what the weather in my current spot the [SEP]",
                   "[CLS] what the weather like in the city frewen [SEP]",
                   "[CLS] what the weather supposed to be like today [SEP]"]

    atis = 'atis.txt'
    snips = 'snips.txt'
    with torch.no_grad():  # inference only; no gradients needed
        atis_out = bert_output(texts_atis, atis)    # BERT hidden states for ATIS
        snips_out = bert_output(texts_snips, snips)  # BERT hidden states for SNIPS

    # Each element of the outputs is a 2-D (positions x hidden) matrix.
    # Compare the last ATIS sentence against every SNIPS sentence.
    atis_2 = atis_out[-1]
    for text in snips_out:
        # cosine_similarity(dim=1) compares row i of one matrix with row i of
        # the other, so both need the same number of rows. Sentences often
        # tokenize to different lengths, which crashed the original loop.
        # Compare only the shared leading positions instead.
        n = min(atis_2.size(0), text.size(0))
        output = torch.cosine_similarity(atis_2[:n], text[:n], dim=1)
        print(output)



版权声明:本文为tailonh原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。