朴素贝叶斯算法(NB)

  • Post author:
  • Post category:其他

算法分析:贝叶斯分类器的分类原理是通过某对象的先验概率,利用贝叶斯公式计算出其后验概率,即该对象属于某一类的概率,选择具有最大后验概率的类作为该对象所属的类。目前研究较多的贝叶斯分类器主要有四种,分别是:Naive Bayes、TAN、BAN和GBN。这次使用NB算法来实现。

       实现步骤:

    1、找到一个已知分类的待分类项集合,这个集合叫做训练样本集。

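      2、统计得到在各类别下各个特征属性的条件概率估计。设待分类项为 $x = \{a_1, a_2, \ldots, a_m\}$,类别集合为 $\{y_1, y_2, \ldots, y_n\}$,即需要估计:

$$P(a_1 \mid y_1), P(a_2 \mid y_1), \ldots, P(a_m \mid y_1);\ \ldots;\ P(a_1 \mid y_n), P(a_2 \mid y_n), \ldots, P(a_m \mid y_n)$$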

      3、如果各个特征属性是条件独立的,则根据贝叶斯定理有如下推导:

$$P(y_i \mid x) = \frac{P(x \mid y_i)\,P(y_i)}{P(x)}$$

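      因为分母对于所有类别都是常数,所以只需将分子最大化即可。又因为各特征属性条件独立,有:

$$P(x \mid y_i)\,P(y_i) = P(a_1 \mid y_i)\,P(a_2 \mid y_i)\cdots P(a_m \mid y_i)\,P(y_i) = P(y_i)\prod_{j=1}^{m} P(a_j \mid y_i)$$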

 

【优化提升】

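   计算各个划分的条件概率 $P(a \mid y)$ 是朴素贝叶斯分类的关键步骤。当特征属性为离散值时,只需统计训练样本中各个划分在每个类别中出现的频率,即可用来估计 $P(a \mid y)$。本次给出的数据属性是连续值而非离散值,因此假设每个属性在每个类别下服从正态分布:先用训练集估计该类别下该属性的均值 $\mu$ 和方差 $\sigma^2$,再用正态分布的概率密度代替原来的计数方式:

$$P(a \mid y) \approx \frac{1}{\sqrt{2\pi\sigma^2}} \exp\!\left(-\frac{(a-\mu)^2}{2\sigma^2}\right)$$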

       第二点:测试集中某个属性的取值可能从未在训练集中出现过,这会导致该项概率为0,进而使整个乘积结果为0。所以这里用一个很小的平滑参数来代替原来为0的概率分项。

       得到的概率可能很小,为了防止数据向下溢出,使用了log函数。

 

【实验结果】

           如果只使用上面的处理方法,正确率一般在0.62左右。

 

【实验思考】

        NB算法的前提是数据的各个属性之间相互独立(没有相关性),因为最终的分类依据是各个属性值在各个类别上的条件概率之积。而本次给出的数据集中存在最大值、最小值和平均值这类彼此高度相关的属性,所以后续的提升空间在于如何处理这些相关属性。不过我之前手动去除了最大值和最小值属性,得到的正确率反而只有0.61,所以还有待继续研究。另外,有同学分享说随机去除固定个数的属性,最终可以达到0.64以上,这方面有待后续继续探讨。

package com.sysu.jerry;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Map;

import javax.management.AttributeList;

/**
 * Command-line driver for the Gaussian naive Bayes classifier in {@link Bayes}.
 *
 * <p>Reads the tab-separated files {@code train_.txt} and {@code test_.txt} from the working
 * directory. Each file starts with a header row. In the training file the last column of every row
 * is the class label. The driver then trains the model and writes one prediction per test row.
 */
public class TestBayes {

	/** Training data: header row, then one tab-separated sample per line with the label last. */
	private static final String TRAIN_FILE = "train_.txt";

	/** Test data: header row, then one tab-separated sample per line. */
	private static final String TEST_FILE = "test_.txt";

	/**
	 * No longer used by this class. It is kept only so that code in the same package that
	 * referenced this package-visible field still compiles. Each reader now owns and closes its own
	 * stream.
	 */
	FileInputStream file;

	/**
	 * Reads the attribute names from the header row of the training file.
	 *
	 * @return the header fields except the last one (the last column is the class label); empty if
	 *         the file is missing, empty, or cannot be read (the error is printed)
	 */
	public ArrayList<String> readAttr(){
		ArrayList<String> attrList = new ArrayList<String>();
		try (BufferedReader reader = openReader(TRAIN_FILE))
		{
			String header = reader.readLine();
			if (header != null)
			{
				String[] fields = splitFields(header);
				// The last column holds the class label, not an attribute.
				for (int j = 0; j < fields.length - 1; j++)
				{
					attrList.add(fields[j]);
				}
			}
		}
		catch (IOException e)
		{
			e.printStackTrace();
		}
		return attrList;
	}

	/**
	 * Reads every training sample (header row skipped).
	 *
	 * @return one list of string fields per non-blank data line; the class label is the last field
	 */
	public ArrayList<ArrayList<String>> readData()
	{
		return readRows(TRAIN_FILE);
	}

	/**
	 * Reads every test sample (header row skipped).
	 *
	 * @return one list of string fields per non-blank data line
	 */
	public ArrayList<ArrayList<String>> readTest()
	{
		return readRows(TEST_FILE);
	}

	/**
	 * Reads a tab-separated file, skips its header row, and returns the remaining rows.
	 *
	 * <p>Blank lines (for example a trailing newline at the end of the file) are skipped. Otherwise
	 * they would become a row containing a single empty string, which later fails numeric parsing.
	 * I/O errors are printed and whatever was read so far is returned.
	 */
	private ArrayList<ArrayList<String>> readRows(String fileName)
	{
		ArrayList<ArrayList<String>> rows = new ArrayList<ArrayList<String>>();
		try (BufferedReader reader = openReader(fileName))
		{
			reader.readLine(); // header row
			String line;
			while ((line = reader.readLine()) != null)
			{
				if (line.trim().isEmpty())
				{
					continue;
				}
				String[] fields = splitFields(line);
				ArrayList<String> row = new ArrayList<String>(fields.length);
				for (String field : fields)
				{
					row.add(field);
				}
				rows.add(row);
			}
		}
		catch (IOException e)
		{
			e.printStackTrace();
		}
		return rows;
	}

	/** Opens {@code fileName} for buffered reading with the platform default charset (as before). */
	private static BufferedReader openReader(String fileName) throws IOException
	{
		return new BufferedReader(new InputStreamReader(new FileInputStream(fileName)));
	}

	/** Trims surrounding whitespace and splits a line on tab characters. */
	private static String[] splitFields(String line)
	{
		return line.trim().split("\t");
	}

	/** Loads both data sets, fits per-class Gaussian parameters and writes the predictions. */
	public static void main(String args[])
	{
		TestBayes testBayes = new TestBayes();
		ArrayList<String> attrList = testBayes.readAttr();
		ArrayList<ArrayList<String>> datas = testBayes.readData();
		ArrayList<ArrayList<String>> tests = testBayes.readTest();

		Bayes bayes = new Bayes(datas,attrList,tests);

		Map<String,Integer> classes = bayes.classOfDatas();
		ArrayList<Double> averageYes = bayes.getAverageYes(classes);
		ArrayList<Double> varianceYes = bayes.getVarianceYes(averageYes,classes);

		ArrayList<Double> averageNo = bayes.getAverageNo(classes);
		ArrayList<Double> varianceNo = bayes.getVarianceNo(averageNo,classes);

		bayes.printResult(averageYes, varianceYes, averageNo, varianceNo, tests, classes);
	}

}

package com.sysu.jerry;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;


/**
 * Gaussian naive Bayes classifier for a two-class problem whose class labels are the strings
 * {@code "1"} and {@code "0"}.
 *
 * <p>Each training row is a list of numeric attribute values followed by the class label in the
 * last column. For every class, each attribute is modelled by a normal distribution. Its mean and
 * population variance are estimated from that class's training rows.
 */
public class Bayes {

	private ArrayList<ArrayList<String>> datas = null;
	// Stored by the constructor; not read by any method in this class.
	private ArrayList<ArrayList<String>> tests = null;
	// Stored by the constructor; not read by any method in this class.
	private ArrayList<String> attrList = null;

	/** The constant pi. It was previously mistyped as 3.1415916; now it is exactly {@link Math#PI}. */
	public static final double PI = Math.PI;

	/** Label of the positive class in the last column of each row. */
	private static final String YES = "1";

	/** Label of the negative class in the last column of each row. */
	private static final String NO = "0";

	/**
	 * Likelihood used when an attribute had zero variance in a class but the test value differs
	 * from that class's (single) observed value. It smooths what would otherwise be a zero
	 * probability.
	 */
	private static final double SMOOTHING_LIKELIHOOD = 0.000000001;

	/**
	 * @param datas    training rows; the last field of each row is the class label
	 * @param attrList attribute names (stored, not used by the computations here)
	 * @param tests    test rows (stored; {@link #printResult} takes its own test list)
	 */
	public Bayes(ArrayList<ArrayList<String>> datas,
			ArrayList<String> attrList,
			ArrayList<ArrayList<String>> tests)
	{
		this.datas = datas;
		this.attrList = attrList;
		this.tests = tests;
	}

	/**
	 * Counts the training rows per class label.
	 *
	 * @return map from class label (last field of each row) to number of rows with that label
	 */
	public Map<String,Integer> classOfDatas()
	{
		Map<String,Integer> classes = new HashMap<String,Integer>();
		for (ArrayList<String> tuple : datas)
		{
			String label = tuple.get(tuple.size() - 1);
			Integer seen = classes.get(label);
			classes.put(label, seen == null ? 1 : seen + 1);
		}
		return classes;
	}

	/** Per-attribute mean over the training rows of class {@code "1"}. */
	public ArrayList<Double> getAverageYes(Map<String,Integer> classes)
	{
		return getAverage(YES, classes);
	}

	/** Per-attribute mean over the training rows of class {@code "0"}. */
	public ArrayList<Double> getAverageNo(Map<String,Integer> classes)
	{
		return getAverage(NO, classes);
	}

	/** Per-attribute population variance over the training rows of class {@code "1"}. */
	public ArrayList<Double> getVarianceYes(ArrayList<Double> averageList,Map<String,Integer> classes)
	{
		return getVariance(YES, averageList, classes);
	}

	/** Per-attribute population variance over the training rows of class {@code "0"}. */
	public ArrayList<Double> getVarianceNo(ArrayList<Double> averageList,Map<String,Integer> classes)
	{
		return getVariance(NO, averageList, classes);
	}

	/**
	 * Classifies each test row and writes one label per line to {@code result_.txt}. Each
	 * prediction is also printed to stdout as {@code "<index>:<label>"}, with the index starting
	 * at 1.
	 *
	 * <p>The score of a class is log P(class) + sum over attributes of log N(x; mean, variance).
	 * Ties are resolved in favour of {@code "1"}. I/O errors are printed, not rethrown.
	 *
	 * @throws IllegalArgumentException if {@code classes} has no count for {@code "1"} or {@code "0"}
	 */
	public void printResult(ArrayList<Double> averageYes,
			ArrayList<Double> varianceYes,
			ArrayList<Double> averageNo,
			ArrayList<Double> varianceNo,
			ArrayList<ArrayList<String>> tests,
			Map<String,Integer> classes)
	{
		int y = classCount(classes, YES);
		int n = classCount(classes, NO);
		// Class priors from training frequencies.
		double logPriorYes = Math.log(y * 1.0 / (y + n));
		double logPriorNo = Math.log(n * 1.0 / (y + n));

		int count = 1;
		try (FileOutputStream out = new FileOutputStream(new File("result_.txt")))
		{
			for (ArrayList<String> test : tests)
			{
				double logLikelihoodYes = 0.0;
				double logLikelihoodNo = 0.0;
				for (int j = 0; j < test.size(); j++)
				{
					double number = Double.parseDouble(test.get(j));
					logLikelihoodYes += logGaussian(number, averageYes.get(j), varianceYes.get(j));
					logLikelihoodNo += logGaussian(number, averageNo.get(j), varianceNo.get(j));
				}
				double pYes = logLikelihoodYes + logPriorYes;
				double pNo = logLikelihoodNo + logPriorNo;

				String label = pYes >= pNo ? YES : NO;
				System.out.println(count++ + ":" + label);
				out.write((label + "\r\n").getBytes());
			}
			out.flush();
		}
		catch (IOException e)
		{
			e.printStackTrace();
		}
	}

	/**
	 * Natural log of the normal density N(x; mean, variance).
	 *
	 * <p>The log is computed directly instead of as log(exp(...)). For values far from the mean,
	 * exp underflows to 0.0 and log(0.0) is -Infinity, which would make the two classes
	 * indistinguishable. If the variance is zero the distribution is degenerate: the result is
	 * log(1) when x equals the mean, otherwise log of the smoothing likelihood.
	 */
	private static double logGaussian(double x, double mean, double variance)
	{
		if (variance == 0)
		{
			return Math.log(x == mean ? 1.0 : SMOOTHING_LIKELIHOOD);
		}
		double diff = x - mean;
		return -0.5 * Math.log(2 * PI * variance) - diff * diff / (2 * variance);
	}

	/** Per-attribute mean over the training rows whose last field equals {@code label}. */
	private ArrayList<Double> getAverage(String label, Map<String,Integer> classes)
	{
		int attrCount = datas.get(0).size() - 1;
		double[] sums = new double[attrCount];
		for (ArrayList<String> row : datas)
		{
			if (label.equals(row.get(row.size() - 1)))
			{
				for (int i = 0; i < attrCount; i++)
				{
					sums[i] += Double.parseDouble(row.get(i));
				}
			}
		}
		int size = classCount(classes, label);
		ArrayList<Double> averageList = new ArrayList<Double>(attrCount);
		for (double sum : sums)
		{
			averageList.add(sum / size);
		}
		return averageList;
	}

	/**
	 * Per-attribute population variance (divides by the class size, not size - 1) over the training
	 * rows whose last field equals {@code label}.
	 */
	private ArrayList<Double> getVariance(String label, ArrayList<Double> averageList,
			Map<String,Integer> classes)
	{
		int attrCount = datas.get(0).size() - 1;
		double[] squareSums = new double[attrCount];
		for (ArrayList<String> row : datas)
		{
			if (label.equals(row.get(row.size() - 1)))
			{
				for (int i = 0; i < attrCount; i++)
				{
					double diff = Double.parseDouble(row.get(i)) - averageList.get(i);
					squareSums[i] += diff * diff;
				}
			}
		}
		int size = classCount(classes, label);
		ArrayList<Double> varianceList = new ArrayList<Double>(attrCount);
		for (double squareSum : squareSums)
		{
			varianceList.add(squareSum / size);
		}
		return varianceList;
	}

	/** Returns the count for {@code label}, failing with a clear message instead of an unboxing NPE. */
	private static int classCount(Map<String,Integer> classes, String label)
	{
		Integer count = classes.get(label);
		if (count == null)
		{
			throw new IllegalArgumentException("no training rows with class label \"" + label + "\"");
		}
		return count;
	}
}
	
	
	
	
	
	
	
	
	
	
	
	
	/*public ArrayList<Map<String,Integer>> classOfAttrYes()
	{
		ArrayList<Map<String,Integer>> classList = new ArrayList<Map<String,Integer>>();
		
		for(int i=0;i<datas.get(0).size() - 1;i++)
		{
			String c = "";
			Map<String,Integer> classes = new HashMap<String,Integer>();
			for(int j=0;j<datas.size();j++)
			{
				String s = datas.get(j).get(datas.get(0).size()-1);
				if(s.equals("1"))
				{
					c = datas.get(j).get(i);
					if(classes.containsKey(c))
					{
						classes.put(c, classes.get(c) + 1);
					}else{
						classes.put(c, 1);
					}
				}
			}
			classList.add(classes);
		}
		return classList;
		
	}
	
	public ArrayList<Map<String,Integer>> classOfAttrNo()
	{
		ArrayList<Map<String,Integer>> classList = new ArrayList<Map<String,Integer>>();
		
		for(int i=0;i<datas.get(0).size() - 1;i++)
		{
			String c = "";
			Map<String,Integer> classes = new HashMap<String,Integer>();
			for(int j=0;j<datas.size();j++)
			{
				String s = datas.get(j).get(datas.get(0).size()-1);
				if(s.equals("0"))
				{
					c = datas.get(j).get(i);
					if(classes.containsKey(c))
					{
						classes.put(c, classes.get(c) + 1);
					}else{
						classes.put(c, 1);
					}
				}
			}
			classList.add(classes);
		}
		return classList;
		
	}
	
	
	
	
	public ArrayList<ArrayList<Double>> getP(ArrayList<Double> average, ArrayList<Double> variance)
	{
		ArrayList<ArrayList<Double>> pList = new ArrayList<ArrayList<Double>>();
		for(int i=0; i<datas.get(0).size()-1;i++)
		{
			ArrayList<Double> pListCol = new ArrayList<Double>();
			double p = 0.00000;
			for(int j=0;j<datas.size();j++)
			{
				String s = datas.get(j).get(i);
				double number = Double.parseDouble(s);
				double coeffcient = 1 / Math.sqrt(2 * PI * variance.get(i));
				double exponential = - (number - average.get(i)) 
						* (number - average.get(i)) / (2 * variance.get(i));
				p = coeffcient * Math.exp(exponential);		
				pListCol.add(p);
			}
			pList.add(pListCol);
					
		}			
		return pList;		
	}
	
	public void print(ArrayList<ArrayList<Double>> pList)
	{
		//System.out.println("pList.size()" + pList.size());
		//System.out.println("pList.get(0).size()" + pList.get(0).size());
		for(int i=0;i<pList.size();i++)
		{
			
			for(int j=0;j<pList.get(0).size();j++)
			{
				System.out.println("pList.get("+i+ ").get("+j+ ")" + pList.get(i).get(j));
			}
		}
	}
	
	
	public ArrayList<ArrayList<Double>> getPYes(
			ArrayList<ArrayList<Double>> pList,
			Map<String,Integer> classes,
			ArrayList<Map<String,Integer>> classesAttr)
	{
		ArrayList<ArrayList<Double>> ppList = new ArrayList<ArrayList<Double>>();	
		
		int y = classes.get("1");
		int n = classes.get("0");		
		for(int i=0; i < datas.get(0).size()-1;i++)
		{
			ArrayList<Double> ppListCol = new ArrayList<Double>();
			double pp = 0.00000;
			double sumOfpp = 0.00000;
			for(int j=0; j < datas.size();j++)
			{
				double p = pList.get(i).get(j);
				
				String s = datas.get(j).get(i);
				System.out.println("s" + s);
				
				if(classesAttr.get(i).containsKey(s))
				{
					int m = classesAttr.get(i).get(s);
					pp = m / (p * (n + y));
				}else{
					pp = 0.000001 / (p * (n + y));				
				}

				sumOfpp += pp;				
			}
			
			for(int j=0; j < datas.size();j++)
			{
				double p = pList.get(i).get(j);	
				String s = datas.get(j).get(i);
				
				if(classesAttr.get(i).containsKey(s))
				{
					int m = classesAttr.get(i).get(s);
					pp = m / (p * (n + y));
				}else{
					pp = 0.0001 / (p * (n + y));		
				}

				
				pp = pp / sumOfpp;
				if(pp < 0.00001)
					pp = 1;
				ppListCol.add(pp);				
			}
			ppList.add(ppListCol);
		}
		
		return ppList;
		
	}
	
	
	public ArrayList<ArrayList<Double>> getPNo(
			ArrayList<ArrayList<Double>> pList,
			Map<String,Integer> classes,
			ArrayList<Map<String,Integer>> classesAttr)
	{
		ArrayList<ArrayList<Double>> ppList = new ArrayList<ArrayList<Double>>();
		
		int y = classes.get("1");
		int n = classes.get("0");
		
		for(int i=0; i < datas.get(0).size()-1;i++)
		{
			ArrayList<Double> ppListCol = new ArrayList<Double>();
			double pp = 0.00000;
			double sumOfpp = 0.00000;
			for(int j=0; j < datas.size();j++)
			{				
				double p = pList.get(i).get(j);	
				String s = datas.get(j).get(i);
				
				if(classesAttr.get(i).containsKey(s))
				{
					int m = classesAttr.get(i).get(s);
					pp = m / (p * (n + y));
				}else{
					pp = 0.0001 / (p * (n + y));				}
				
				int m = classesAttr.get(i).get(s);
				pp = m / (p * (n + y));
				
				
				sumOfpp += pp;
			}
			
			for(int j=0; j < datas.size();j++)
			{
				double p = pList.get(i).get(j);
				String s = datas.get(j).get(i);
				
				if(classesAttr.get(i).containsKey(s))
				{
					int m = classesAttr.get(i).get(s);
					pp = m / (p * (n + y));
				}else{
					pp = 0.00001 / (p * (n + y));				}
				
				
				int m = classesAttr.get(i).get(s);				
				pp = m / (p * (n + y));
	
				pp = pp / sumOfpp;		
				if(pp < 0.0001)
					pp = 1;
				ppListCol.add(pp);			
			}			
			ppList.add(ppListCol);
		}	
		return ppList;		
	}
	
	
	
	public ArrayList<Map<String,Double>> getAttrListP(ArrayList<ArrayList<Double>> ppList)
	{
		ArrayList<Map<String,Double>> attrListP = new ArrayList<Map<String,Double>>();
		
		for(int i=0;i<datas.get(0).size()-1;i++)
		{
			Map<String,Double> m = new HashMap<String,Double>();
			for(int j=0;j<datas.size();j++)
			{
				
				String s = datas.get(j).get(i);
				m.put(s,ppList.get(i).get(j));
			}
			attrListP.add(m);
		}
		return attrListP ;	
	}
	
	
	
	
	public void getResult(
			ArrayList<Map<String,Double>> attrListPYES,
			ArrayList<Map<String,Double>> attrListPNO,
			ArrayList<ArrayList<String>> tests,
			Map<String,Integer> classes)
	{
		int y = classes.get("1");
		int n = classes.get("0");
		double NP = n / (y+n);
		double YP = y / (y+n);
		
		for(int i=0;i<tests.size();i++)
		{
			String result = "";
			double pYES = 1.0;
			double pNO = 1.0;
			for(int j=0;j<tests.get(0).size();j++)
			{
				String value = tests.get(i).get(j);
				//System.out.println(value);
				
				if(!attrListPYES.get(j).containsKey(value))
				{
					
					int N = 2;
					pYES *= 1 / N;					 
				}else{
					double p = attrListPYES.get(j).get(value);
					//System.out.println("p" + p);
					pYES *= p;
					System.out.println("pYES" + pYES);
				//	pYES += (Math.log(p) / Math.log(0.01));
					
				}
				
				if(!attrListPNO.get(j).containsKey(value))
				{
					int N = 2;
					pNO *= 1 / N;					 
				}else{
					double p = attrListPNO.get(j).get(value);
					pNO *= p;
					System.out.println("pNO" + pNO);
					//pNO += (Math.log(p) / Math.log(0.01));
					//System.out.println("no");
				}
			}
			//System.out.println("pYES * YP        "+ Math.exp(Math.log(YP) + 100));
			if(pYES  > pNO ) 
			{
				System.out.println("1");
			}else{
				System.out.println("0");
			}
		}
	}	*/
	





/*Iterator iter = m.entrySet().iterator();  
for (int k = 0; iter.hasNext(); k++) {  
    Map.Entry entry = (Map.Entry) iter.next();  
    String key = (String) entry.getKey();  
    double val = (Double) entry.getValue();  
    
    System.out.println("key:"+key);
    System.out.println("val:"+val);
   
}  */


版权声明:本文为u013058160原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。