贝叶斯分类器(Python实现+详细完整源码和原理)

  • Post author:
  • Post category:python


在概率和统计学领域,贝叶斯理论基于对某一事件证据的认识来预测该事件的发生概率,

由结果推测原因的概率大小

首先,理解这个公式的前提是理解条件概率,因此先复习条件概率。

P(A|B)=P(AB)/P(B)

贝叶斯公式:

在机器学习领域,贝叶斯分类器是基于贝叶斯理论并假设各特征相互独立的分类方法,

基本方法是:使用特征向量来表征某个实体,并在该实体上绑定一个标签来代表其所属的类别。

优点:只需要极少数的训练数据,就可以建立起分类所需要的所有参数

抽象而言就是:贝叶斯分类器就是条件概率:给定一个实体,求解这个实体属于某一类的概率,这个实体用

一个长度为n的向量来表示,向量中的每一个元素表示相互独立的特征值的量。

=====================================================================================

以下是对水果分类的python代码实现:

类别 较长 不长 不甜 黄色 不是黄色 总数
香蕉 400 100 350 150 450 50 500
橘子 0 300 150 150 300 0 300
其他水果 100 100 150 50 50 150 200
总数 500 500 650 350 800 200 1000

python文件结构:都在一个包下(Bayes)

bayes_classfier.py


  
  
  1. #!/usr/bin/env python
  2. # encoding: utf-8
  3. """
  4. @Company:华中科技大学电气学院聚变与等离子研究所
  5. @version: V1.0
  6. @author: YEXIN
  7. @contact: 1650996069@qq.com 2018--2020
  8. @software: PyCharm
  9. @file: bayes_classfier.py
  10. @time: 2018/8/16 16:49
  11. @Desc:贝叶斯分类器
  12. """
  13. ###贝叶斯分类器源码
  14. ####训练数据集---->合适参数
  15. datasets = { 'banala':{ 'long': 400, 'not_long': 100, 'sweet': 350, 'not_sweet': 150, 'yellow': 450, 'not_yellow': 50},
  16. 'orange':{ 'long': 0, 'not_long': 300, 'sweet': 150, 'not_sweet': 150, 'yellow': 300, 'not_yellow': 0},
  17. 'other_fruit':{ 'long': 100, 'not_long': 100, 'sweet': 150, 'not_sweet': 50, 'yellow': 50, 'not_yellow': 150}
  18. }
  19. def count_total(data):
  20. '''计算各种水果的总数
  21. return {‘banala’:500 ...}'''
  22. count = {}
  23. total = 0
  24. for fruit in data:
  25. '''因为水果要么甜要么不甜,可以用 这两种特征来统计总数'''
  26. count[fruit] = data[fruit][ 'sweet'] + data[fruit][ 'not_sweet']
  27. total += count[fruit]
  28. return count,total
  29. #categories,simpleTotal = count_total(datasets)
  30. #print(categories,simpleTotal)
  31. ###########################################################
  32. def cal_base_rates(data):
  33. '''计算各种水果的先验概率
  34. return {‘banala’:0.5 ...}'''
  35. categories,total = count_total(data)
  36. cal_base_rates = {}
  37. for label in categories:
  38. priori_prob = categories[label]/total
  39. cal_base_rates[label] = priori_prob
  40. return cal_base_rates
  41. #Prio = cal_base_rates(datasets)
  42. #print(Prio)
  43. ############################################################
  44. def likelihold_prob(data):
  45. '''计算各个特征值在已知水果下的概率(likelihood probabilities)
  46. {'banala':{'long':0.8}...}'''
  47. count,_ = count_total(data)
  48. likelihold = {}
  49. for fruit in data:
  50. '''创建一个临时的字典,临时存储各个特征值的概率'''
  51. attr_prob = {}
  52. for attr in data[fruit]:
  53. #计算各个特征值在已知水果下的概率
  54. attr_prob[attr] = data[fruit][attr]/count[fruit]
  55. likelihold[fruit] = attr_prob
  56. return likelihold
  57. #LikeHold = likelihold_prob(datasets)
  58. #print(LikeHold)
  59. ############################################################
  60. def evidence_prob(data):
  61. '''计算特征的概率对分类结果的影响
  62. return {'long':50%...}'''
  63. #水果的所有特征
  64. attrs = list(data[ 'banala'].keys())
  65. count,total = count_total(data)
  66. evidence_prob = {}
  67. #计算各种特征的概率
  68. for attr in attrs:
  69. attr_total = 0
  70. for fruit in data:
  71. attr_total += data[fruit][attr]
  72. evidence_prob[attr] = attr_total/total
  73. return evidence_prob
  74. #Evidence_prob = evidence_prob(datasets)
  75. #print(Evidence_prob)
  76. ##########################################################
  77. #以上是训练数据用到的函数,即将数据转化为代码计算概率
  78. ##########################################################
  79. class navie_bayes_classifier:
  80. '''初始化贝叶斯分类器,实例化时会调用__init__函数'''
  81. def __init__(self,data=datasets):
  82. self._data = datasets
  83. self._labels = [key for key in self._data.keys()]
  84. self._priori_prob = cal_base_rates(self._data)
  85. self._likelihold_prob = likelihold_prob(self._data)
  86. self._evidence_prob = evidence_prob(self._data)
  87. #下面的函数可以直接调用上面类中定义的变量
  88. def get_label(self,length,sweetness,color):
  89. '''获取某一组特征值的类别'''
  90. self._attrs = [length,sweetness,color]
  91. res = {}
  92. for label in self._labels:
  93. prob = self._priori_prob[label] #取某水果占比率
  94. #print("各个水果的占比率:",prob)
  95. for attr in self._attrs:
  96. #单个水果的某个特征概率除以总的某个特征概率 再乘以某水果占比率
  97. prob*=self._likelihold_prob[label][attr]/self._evidence_prob[attr]
  98. #print(prob)
  99. res[label] = prob
  100. #print(res)
  101. return res
  102. ============================================================================================

generate_attires.py


  
  
  1. #!/usr/bin/env python
  2. # encoding: utf-8
  3. """
  4. @Company:华中科技大学电气学院聚变与等离子研究所
  5. @version: V1.0
  6. @author: YEXIN
  7. @contact: 1650996069@qq.com 2018--2020
  8. @software: PyCharm
  9. @file: generate_attires.py
  10. @time: 2018/8/17 13:43
  11. @Desc:产生测试数据集来测试贝叶斯分类器的预测能力
  12. """
  13. import random
  14. def random_attr(pair):
  15. #生成0-1之间的随机数
  16. return pair[random.randint( 0, 1)]
  17. def gen_attrs():
  18. #特征值的取值集合
  19. sets = [( 'long', 'not_long'),( 'sweet', 'not_sweet'),( 'yellow', 'not_yellow')]
  20. test_datasets = []
  21. for i in range( 20):
  22. #使用map函数来生成一组特征值
  23. test_datasets.append(list(map(random_attr,sets)))
  24. return test_datasets
  25. #print(gen_attrs())

======================================================================================

classfication.py


  
  
  1. #!/usr/bin/env python
  2. # encoding: utf-8
  3. """
  4. @Company:华中科技大学电气学院聚变与等离子研究所
  5. @version: V1.0
  6. @author: YEXIN
  7. @contact: 1650996069@qq.com 2018--2020
  8. @software: PyCharm
  9. @file: classfication.py
  10. @time: 2018/8/17 13:55
  11. @Desc:使用贝叶斯分类器对测试结果进行分类
  12. """
  13. import operator
  14. import bayes_classfier
  15. import generate_attires
  16. def main():
  17. test_datasets = generate_attires.gen_attrs()
  18. classfier = bayes_classfier.navie_bayes_classifier()
  19. for data in test_datasets:
  20. print( "特征值:",end= '\t')
  21. print(data)
  22. print( "预测结果:", end= '\t')
  23. res=classfier.get_label(*data) #表示多参传入
  24. print(res) #预测属于哪种水果的概率
  25. print( '水果类别:',end= '\t')
  26. #对后验概率排序,输出概率最大的标签
  27. print(sorted(res.items(),key=operator.itemgetter( 1),reverse= True)[ 0][ 0])
  28. if __name__ == '__main__':
  29. #表示模块既可以被导入(到 Python shell 或者其他模块中),也可以作为脚本来执行。
  30. #当模块被导入时,模块名称是文件名;而当模块作为脚本独立运行时,名称为 __main__。
  31. #让模块既可以导入又可以执行
  32. main()

=====================================================================================

结果展示:

特征值:    [‘not_long’, ‘not_sweet’, ‘not_yellow’]

预测结果:    {‘banala’: 0.08571428571428573, ‘orange’: 0.0, ‘other_fruit’: 0.5357142857142858}

水果类别:    other_fruit

特征值:    [‘not_long’, ‘sweet’, ‘not_yellow’]

预测结果:    {‘banala’: 0.1076923076923077, ‘orange’: 0.0, ‘other_fruit’: 0.8653846153846153}

水果类别:    other_fruit

特征值:    [‘not_long’, ‘sweet’, ‘yellow’]

预测结果:    {‘banala’: 0.24230769230769234, ‘orange’: 0.5769230769230769, ‘other_fruit’: 0.07211538461538461}

水果类别:    orange

特征值:    [‘not_long’, ‘sweet’, ‘yellow’]

预测结果:    {‘banala’: 0.24230769230769234, ‘orange’: 0.5769230769230769, ‘other_fruit’: 0.07211538461538461}

水果类别:    orange

特征值:    [‘not_long’, ‘not_sweet’, ‘not_yellow’]

预测结果:    {‘banala’: 0.08571428571428573, ‘orange’: 0.0, ‘other_fruit’: 0.5357142857142858}

水果类别:    other_fruit

特征值:    [‘long’, ‘not_sweet’, ‘not_yellow’]

预测结果:    {‘banala’: 0.3428571428571429, ‘orange’: 0.0, ‘other_fruit’: 0.5357142857142858}

水果类别:    other_fruit

特征值:    [‘long’, ‘not_sweet’, ‘not_yellow’]

预测结果:    {‘banala’: 0.3428571428571429, ‘orange’: 0.0, ‘other_fruit’: 0.5357142857142858}

水果类别:    other_fruit

特征值:    [‘long’, ‘not_sweet’, ‘not_yellow’]

预测结果:    {‘banala’: 0.3428571428571429, ‘orange’: 0.0, ‘other_fruit’: 0.5357142857142858}

水果类别:    other_fruit

特征值:    [‘long’, ‘not_sweet’, ‘not_yellow’]

预测结果:    {‘banala’: 0.3428571428571429, ‘orange’: 0.0, ‘other_fruit’: 0.5357142857142858}

水果类别:    other_fruit

特征值:    [‘long’, ‘not_sweet’, ‘yellow’]

预测结果:    {‘banala’: 0.7714285714285716, ‘orange’: 0.0, ‘other_fruit’: 0.04464285714285715}

水果类别:    banala

特征值:    [‘not_long’, ‘not_sweet’, ‘yellow’]

预测结果:    {‘banala’: 0.1928571428571429, ‘orange’: 1.0714285714285714, ‘other_fruit’: 0.04464285714285715}

水果类别:    orange

特征值:    [‘not_long’, ‘not_sweet’, ‘yellow’]

预测结果:    {‘banala’: 0.1928571428571429, ‘orange’: 1.0714285714285714, ‘other_fruit’: 0.04464285714285715}

水果类别:    orange

特征值:    [‘long’, ‘not_sweet’, ‘not_yellow’]

预测结果:    {‘banala’: 0.3428571428571429, ‘orange’: 0.0, ‘other_fruit’: 0.5357142857142858}

水果类别:    other_fruit

特征值:    [‘not_long’, ‘not_sweet’, ‘yellow’]

预测结果:    {‘banala’: 0.1928571428571429, ‘orange’: 1.0714285714285714, ‘other_fruit’: 0.04464285714285715}

水果类别:    orange

特征值:    [‘not_long’, ‘sweet’, ‘not_yellow’]

预测结果:    {‘banala’: 0.1076923076923077, ‘orange’: 0.0, ‘other_fruit’: 0.8653846153846153}

水果类别:    other_fruit

特征值:    [‘long’, ‘not_sweet’, ‘yellow’]

预测结果:    {‘banala’: 0.7714285714285716, ‘orange’: 0.0, ‘other_fruit’: 0.04464285714285715}

水果类别:    banala

特征值:    [‘not_long’, ‘sweet’, ‘yellow’]

预测结果:    {‘banala’: 0.24230769230769234, ‘orange’: 0.5769230769230769, ‘other_fruit’: 0.07211538461538461}

水果类别:    orange

特征值:    [‘long’, ‘not_sweet’, ‘not_yellow’]

预测结果:    {‘banala’: 0.3428571428571429, ‘orange’: 0.0, ‘other_fruit’: 0.5357142857142858}

水果类别:    other_fruit

特征值:    [‘not_long’, ‘not_sweet’, ‘yellow’]

预测结果:    {‘banala’: 0.1928571428571429, ‘orange’: 1.0714285714285714, ‘other_fruit’: 0.04464285714285715}

水果类别:    orange

特征值:    [‘long’, ‘not_sweet’, ‘yellow’]

预测结果:    {‘banala’: 0.7714285714285716, ‘orange’: 0.0, ‘other_fruit’: 0.04464285714285715}

水果类别:    banala