BM25算法,python实现
直接上代码吧,公式参见维基百科 Okapi BM25 词条。有帮助的话就点赞收藏一下吧,有问题直接评论,会进行答复。
import json
import math
import os
import pickle
import sys
from typing import Dict, List
class BM25:
    """Okapi BM25 ranking model over a fixed document corpus.

    The model is either built from a corpus (and persisted to disk via
    pickle) or restored from a previously saved pickle file. Exactly one
    of ``config`` / ``corpus`` must be provided.
    """

    # Fraction of the average IDF used as a replacement value for words
    # whose raw IDF is negative (words occurring in more than half of the
    # documents). Keeps such common words from contributing negative scores.
    EPSILON = 0.25
    PARAM_K1 = 1.5  # BM25 hyperparameter: term-frequency saturation
    PARAM_B = 0.6   # BM25 hyperparameter: document-length normalization

    def __init__(self, config: str = None, corpus: Dict = None):
        """
        Initialize the BM25 model, either by loading saved parameters from
        ``config`` or by building the index from ``corpus``. Exactly one of
        the two arguments must be non-empty.

        :param config: path to a pickle file previously written by
            ``_save_config``; if given, state is restored from it
        :param corpus: document collection as a dict mapping a unique
            document id to its tokenized content (a sequence of words)
        :raises ValueError: if both or neither argument is provided
        """
        if (config and corpus) or (not config and not corpus):
            raise ValueError("config 和 corpus 不能同时为空 或 同时不为空")
        self.corpus_size = 0          # number of documents
        self.wordNumsOfAllDoc = 0     # total word count; avgdl = wordNumsOfAllDoc / corpus_size
        self.doc_freqs = {}           # doc id -> {word: frequency within that doc}
        self.idf = {}                 # word -> IDF value
        self.doc_len = {}             # doc id -> number of words in the doc
        self.docContainedWord = {}    # word -> set of doc ids containing it
        if config:
            self._load_config(config)
        else:
            self._initialize(corpus)
            self._save_config()

    def _load_config(self, config: str):
        """Restore model state from a pickle file written by ``_save_config``."""
        with open(config, 'rb') as f:
            config_data = pickle.load(f)
        self.corpus_size = config_data.get("corpus_size")
        self.wordNumsOfAllDoc = config_data.get("wordNumsOfAllDoc")
        self.doc_freqs = config_data.get("doc_freqs")
        self.idf = config_data.get("idf")
        self.doc_len = config_data.get("doc_len")
        self.docContainedWord = config_data.get("docContainedWord")

    def _initialize(self, corpus: Dict):
        """
        Build the inverted index and per-word IDF values from the corpus.

        :param corpus: dict mapping doc id -> tokenized document content
        """
        for index, document in corpus.items():
            self.corpus_size += 1
            self.doc_len[index] = len(document)  # word count of this doc
            self.wordNumsOfAllDoc += len(document)

            # Per-document term frequencies.
            frequencies = {}
            for word in document:
                if word not in frequencies:
                    frequencies[word] = 0
                frequencies[word] += 1
            self.doc_freqs[index] = frequencies

            # Inverted index: map each word back to the set of documents
            # that contain it.
            for word in frequencies.keys():
                if word not in self.docContainedWord:
                    self.docContainedWord[word] = set()
                self.docContainedWord[word].add(index)

        # Compute IDF using the probabilistic (Robertson) formulation:
        # log((N - n + 0.5) / (n + 0.5)), which can go negative for very
        # common words.
        idf_sum = 0  # accumulated to derive an average IDF for the epsilon floor
        negative_idfs = []
        for word in self.docContainedWord.keys():
            doc_nums_contained_word = len(self.docContainedWord[word])
            idf = math.log(self.corpus_size - doc_nums_contained_word +
                           0.5) - math.log(doc_nums_contained_word + 0.5)
            self.idf[word] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(word)
        average_idf = float(idf_sum) / len(self.idf)
        # Replace negative IDFs with a small fraction of the average IDF.
        eps = BM25.EPSILON * average_idf
        for word in negative_idfs:
            self.idf[word] = eps

    @property
    def avgdl(self):
        """Average document length (in words) across the corpus."""
        return float(self.wordNumsOfAllDoc) / self.corpus_size

    def _save_config(self):
        """Persist model state next to the entry script for later reloading."""
        path = os.path.join(sys.path[0], "BM25_config.pickle")
        # Bug fix: the original used json.dump here, which fails on a file
        # opened in binary mode and cannot serialize the sets stored in
        # docContainedWord — and _load_config reads the file with pickle.
        # Persist with pickle so save and load agree.
        with open(path, 'wb') as f:
            pickle.dump(self.__dict__, f)

    def get_score(self, query: List, doc_index):
        """
        Compute the BM25 relevance score between a query and one document.

        :param query: list of query words
        :param doc_index: id of a document in the corpus
        :return: a two-element list ``[doc_index, score]``
        """
        k1 = BM25.PARAM_K1
        b = BM25.PARAM_B
        score = 0
        doc_freqs = self.doc_freqs[doc_index]
        # Hoist loop-invariant quantities out of the per-word loop.
        norm = k1 * (1 - b + b * self.doc_len[doc_index] / self.avgdl)
        # Bug fix: the original iterated query.values(), but query is a
        # list (see get_scores), which has no .values() method.
        for word in query:
            if word not in doc_freqs:
                continue
            score += self.idf[word] * doc_freqs[word] * (k1 + 1) / (
                doc_freqs[word] + norm)
        return [doc_index, score]

    def get_scores(self, query):
        """Score the query against every document; returns a list of [doc_index, score]."""
        scores = [self.get_score(query, index) for index in self.doc_len.keys()]
        return scores
版权声明:本文为weixin_42655901原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。