# 注意:计算相似度时必须保证两个矩阵维度相同,否则报错
import torch
from transformers import BertConfig, BertModel, BertTokenizer
def bert_output(texts, name):
    """Encode a batch of sentences with BERT and return per-token vectors.

    Every sentence is tokenized with the module-level ``tokenizer``,
    right-padded with zeros to the length of the longest sentence in the
    batch, and run through the module-level ``model``.

    Args:
        texts: Non-empty iterable of sentences. No special tokens are added
            here; callers that want ``[CLS]``/``[SEP]`` must write them into
            the text themselves.
        name: Path of a dump file. Currently unused; kept so that existing
            callers keep working.

    Returns:
        The first element of the model output (the sequence output): one
        matrix of token vectors per sentence, padded positions included.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    texts = list(texts)
    if not texts:
        raise ValueError("texts must contain at least one sentence")

    token_ids = [
        tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
        for text in texts
    ]
    max_len = max(len(ids) for ids in token_ids)  # longest sentence in the batch

    input_ids, attention_mask, token_type_ids = [], [], []
    for ids in token_ids:
        padding = [0] * (max_len - len(ids))
        # Build new lists instead of extending in place so nothing is shared
        # between the three inputs.
        input_ids.append(ids + padding)
        attention_mask.append([1] * len(ids) + padding)  # 1 = real token, 0 = padding
        token_type_ids.append([0] * max_len)  # single-segment input: all zeros

    # Pass every input by keyword. In ``transformers`` the positional order of
    # BertModel.forward is (input_ids, attention_mask, token_type_ids), so
    # passing (tokens, segments, masks) positionally would use the all-zero
    # segment ids as the attention mask and the mask as segment ids.
    output = model(
        input_ids=torch.tensor(input_ids, dtype=torch.long),
        attention_mask=torch.tensor(attention_mask, dtype=torch.long),
        token_type_ids=torch.tensor(token_type_ids, dtype=torch.long),
    )

    # Global print setting: show up to 30 items at each edge of a dimension
    # before eliding, so printed results are less truncated.
    torch.set_printoptions(edgeitems=30)

    # output[0] is the per-token sequence output; output[1] (the pooled [CLS]
    # vector) is not needed here.
    return output[0]
if __name__ == '__main__':
    # bert_output() reads ``tokenizer`` and ``model`` as module-level names,
    # so both must be bound under exactly these names before it is called.
    tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased')
    model_config = BertConfig.from_pretrained('./bert-base-uncased')
    model = BertModel.from_pretrained('./bert-base-uncased', config=model_config)

    # Alternative multi-sentence ATIS batch:
    # texts_atis = ["[CLS] i want to fly from baltimore to dallas round trip [SEP]",
    #               "[CLS] how can i find that out [SEP]",
    #               "[CLS] how many flights does twa have in business class [SEP]"]
    texts_atis = ["[CLS] i want to fly from baltimore to dallas round trip [SEP]"]
    texts_snips = [
        "[CLS] what the weather in my current spot the [SEP]",
        "[CLS] what the weather like in the city frewen [SEP]",
        "[CLS] what the weather supposed to be like today [SEP]",
    ]  # stands in for the whole file

    atis = 'atis.txt'
    snips = 'snips.txt'
    atis_out = bert_output(texts_atis, atis)  # BERT token vectors for ATIS
    snips_out = bert_output(texts_snips, snips)  # BERT token vectors for SNIPS

    # The reference is the token matrix of the last ATIS sentence.
    atis_2 = atis_out[-1]

    # Compare the reference with each SNIPS sentence position by position.
    # Both matrices must have compatible shapes (in practice the same padded
    # token length), otherwise cosine_similarity raises.
    for snips_matrix in snips_out:
        output = torch.cosine_similarity(atis_2, snips_matrix, dim=1)
        print(output)
# 版权声明:本文为tailonh原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。