BM25算法,python实现(源代码)

  • Post author:
  • Post category:python




BM25算法,python实现

下面直接给出代码;BM25 的计算公式可参考维基百科的 "Okapi BM25" 词条。如果对你有帮助,欢迎点赞、收藏;有问题可以直接在评论区留言,我会回复。

import json
import math
import os
import pickle
import sys
from typing import Dict, List


class BM25:
    """BM25 relevance scorer over a keyed collection of documents.

    An instance is built either from a corpus (and then persisted with
    pickle) or restored from a previously pickled state file; exactly one
    of the two sources must be supplied to the constructor.
    """

    EPSILON = 0.25  # multiplied by the average IDF to get the value that replaces negative IDFs
    PARAM_K1 = 1.5  # BM25 hyper-parameter k1
    PARAM_B = 0.6   # BM25 hyper-parameter b

    def __init__(self, config: str, corpus: Dict):
        """Build the model from ``corpus`` or restore it from ``config``.

        :param config: path of a pickled state file written by
            ``_save_config``; pass a falsy value when building from a corpus
        :param corpus: mapping of document id -> document. Each document is
            iterated element by element and every element is counted as one
            word, so it should already be tokenized (a plain ``str`` would be
            counted per character). Pass a falsy value when loading a config.
        :raises ValueError: if both or neither of ``config`` / ``corpus``
            are provided
        """
        if (config and corpus) or (not config and not corpus):
            raise ValueError("config 和 corpus 不能同时为空 或 同时不为空")

        self.corpus_size = 0        # number of documents
        self.wordNumsOfAllDoc = 0   # total word count; avgdl = wordNumsOfAllDoc / corpus_size
        self.doc_freqs = {}         # document id -> {word: occurrences in that document}
        self.idf = {}               # word -> IDF value
        self.doc_len = {}           # document id -> number of words
        self.docContainedWord = {}  # word -> set of ids of documents containing it

        if config:
            self._load_config(config)
        else:
            self._initialize(corpus)
            self._save_config()

    def _load_config(self, config: str):
        """Restore the model state from the pickle file at ``config``.

        Keys missing from the stored mapping leave the matching attribute
        set to ``None``.

        :param config: path of the pickled state file
        """
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only load state files that come from a trusted source.
        with open(config, 'rb') as f:
            config_data = pickle.load(f)
        self.corpus_size = config_data.get("corpus_size")
        self.wordNumsOfAllDoc = config_data.get("wordNumsOfAllDoc")
        self.doc_freqs = config_data.get("doc_freqs")
        self.idf = config_data.get("idf")
        self.doc_len = config_data.get("doc_len")
        self.docContainedWord = config_data.get("docContainedWord")

    def _initialize(self, corpus: Dict):
        """Build term statistics, the inverted index and IDF values.

        :param corpus: mapping of document id -> iterable of words
        """
        for index, document in corpus.items():
            doc_length = len(document)
            self.corpus_size += 1
            self.doc_len[index] = doc_length
            self.wordNumsOfAllDoc += doc_length

            # Per-document term frequencies.
            frequencies = {}
            for word in document:
                frequencies[word] = frequencies.get(word, 0) + 1
            self.doc_freqs[index] = frequencies

            # Inverted index: word -> ids of the documents that contain it.
            for word in frequencies:
                self.docContainedWord.setdefault(word, set()).add(index)

        # IDF per word; the sum is kept to derive the epsilon floor below.
        idf_sum = 0.0
        negative_idfs = []
        for word, containing_docs in self.docContainedWord.items():
            doc_nums_contained_word = len(containing_docs)
            idf = (math.log(self.corpus_size - doc_nums_contained_word + 0.5)
                   - math.log(doc_nums_contained_word + 0.5))
            self.idf[word] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(word)

        # A corpus made only of empty documents has no vocabulary; without
        # this guard the average below would divide by zero.
        if not self.idf:
            return

        # Words found in more than half of the documents get a negative IDF;
        # replace those values with EPSILON * average IDF.
        average_idf = idf_sum / len(self.idf)
        eps = BM25.EPSILON * average_idf
        for word in negative_idfs:
            self.idf[word] = eps

    @property
    def avgdl(self):
        """Average document length in words (0.0 when there are no documents)."""
        if not self.corpus_size:
            return 0.0
        return float(self.wordNumsOfAllDoc) / self.corpus_size

    def _save_config(self, path=None):
        """Pickle the model state so ``_load_config`` can restore it.

        :param path: destination file; defaults to ``BM25_config.pickle``
            inside ``sys.path[0]``
        """
        if path is None:
            path = os.path.join(sys.path[0], "BM25_config.pickle")
        state = {
            "corpus_size": self.corpus_size,
            "wordNumsOfAllDoc": self.wordNumsOfAllDoc,
            "doc_freqs": self.doc_freqs,
            "idf": self.idf,
            "doc_len": self.doc_len,
            "docContainedWord": self.docContainedWord,
        }
        # pickle, not json: the file is opened in binary mode, the state
        # holds sets, and _load_config reads it back with pickle.load.
        with open(path, 'wb') as f:
            pickle.dump(state, f)

    def get_score(self, query: List, doc_index):
        """Compute the BM25 relevance score of one document for a query.

        :param query: list of query words (a dict is also accepted, in which
            case its values are used as the words)
        :param doc_index: id of a document in the corpus
        :return: ``[doc_index, score]``
        :raises KeyError: if ``doc_index`` is not a known document id
        """
        k1 = BM25.PARAM_K1
        b = BM25.PARAM_B
        words = query.values() if isinstance(query, dict) else query
        doc_freqs = self.doc_freqs[doc_index]

        # The length normalisation is the same for every query word, so it is
        # computed once. avgdl is 0 only when all documents are empty, in
        # which case no word can match and the value is never used.
        avgdl = self.avgdl
        length_ratio = self.doc_len[doc_index] / avgdl if avgdl else 0.0
        norm = k1 * (1 - b + b * length_ratio)

        score = 0
        for word in words:
            if word not in doc_freqs:
                continue
            freq = doc_freqs[word]
            score += self.idf[word] * freq * (k1 + 1) / (freq + norm)
        return [doc_index, score]

    def get_scores(self, query):
        """Score every document against ``query``.

        :param query: list of query words
        :return: list of ``[doc_index, score]`` pairs, one per document
        """
        return [self.get_score(query, index) for index in self.doc_len]
    



版权声明:本文为weixin_42655901原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。