[Java][机器学习]用决策树分类算法对Iris花数据集进行处理

  • Post author:
  • Post category:java


Iris Data Set是很经典的一个数据集,在很多地方都能看到,一般用于教学分类算法。这个数据集在UCI Machine Learning Repository里可以找到(还是下载量排第一的数据哟)。这个数据集里面,每个数据都包含4个值(sepal length,sepan width,petal length,petal width)以及其种类。而分类算法的目标,就是根据4个值来把花的种类确定下来。

分类的大概原理就是利用熵的变化来判断哪个属性最适合分类(这个书上都有详细的解释)。写这个算法花了我挺多的时间,有点出乎我的意料。整套代码共850行,这里面还包括将数据从我的数据库取出来以及将数据输出到excel文件的代码。写完之后想想,还是有很多地方写的不够紧凑,以后改进吧。

如前所说,我事先把数据集存储在SQL数据库里面,然后有些结果会输出到桌面上的”text.xls”文件里面。下面上代码。整个工程共8个类,分别为IrisData,IrisInfoGet,IrisNode,Hunt,Estimate,ExcelPrint,DataProperty,test。另外还需要两个外部类,分别为sqljdbc.jar(负责SQL连接)以及jxl.jar(EXCEL连接)


test 程序的入口



IrisData 数据的基本结构



IrisInfoGet 负责从SQL中提取数据并转化成IrisData数组的形式。



IrisNode 负责用训练集生成决策树。



Hunt 负责寻找当前决策树节点最适合分类的属性,是决策树算法的关键组成部分。



Estimate 负责用检测集来检测决策树的性能。



ExcelPrint 将IrisData数组输出到excel文件里面。



DataProperty 是个辅助类,用于计算数组的熵等。


1.IrisData


public class IrisData {
    public double SL,SW,PL,PW ; //Sepal Length/Width, Petal Length/Width
    public int Type;  //Iris-setosa:0 Iris-versicolor:1 Iris-virginica:2
    public int tempType ;
    public int SetNum ;  //IrisSet的编号

    public IrisData(){
        this.SL=-1 ;    //Sepal Length
        this.SW=-1 ;    //Sepal Width
        this.PL=-1 ;    //Petal Length
        this.PW=-1 ;    //Petal Width
        this.Type=-1 ;  //The type of the flower
        this.tempType=0 ;
        this.SetNum=-1 ; 
    }
    public IrisData(double SL,double SW,double PL,double PW,int Type,int SetNum){
        this.SL=SL ;
        this.SW=SW ;
        this.PL=PL ;
        this.PW=PW ;
        this.Type=Type ;
        this.tempType=-1 ; //tempType=-1 means undefined
        this.SetNum=SetNum ;
    }
}

除了原来就有的SL,SW,PL,PW,Type等值外,还多了SetNum以及tempType。其中SetNum指的是当前数据的序号,在实际应用中没什么用,但是tempType很关键。由于不知道如何直接将众多元素分成多类,我采取的方法是先将一种(假定为0号花)分离出来,然后再将1号和2号花分离。那么这是就需要一个临时属性。比如在将0号分离时,0号花的tempType=0,而1,2号花的tempType=1.而在分离1和2时,1号花的tempType=0,2号花的tempType=1.


2.IrisInfoGet


/*import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
*/
import java.sql.* ;
/**
 * @function get data from SQL Server as the type of IrisData
 * @author multiangle from Southeast University
 */
public class IrisInfoGet {

    public IrisData[] dataset ;

    public IrisInfoGet(){
        ResultSet rs=getResultSet() ;
        this.dataset=ResultDeal(rs) ;
    }

    public static void main(String[] args) throws SQLException{
        ResultSet rs=getResultSet();
        IrisData[] data=ResultDeal(rs) ;
    }

    private static ResultSet getResultSet(){
        String JDriver="com.microsoft.sqlserver.jdbc.SQLServerDriver";//SQL数据库引擎
        String connectDB="jdbc:sqlserver://127.0.0.1:1433;DatabaseName=multiangle";//数据源

        try{
            Class.forName(JDriver);//加载数据库引擎,返回给定字符串名的类
            System.out.println("数据库驱动成功");
        }catch(ClassNotFoundException e){ //e.printStackTrace();
            System.out.println("加载数据库引擎失败");
            System.out.println(e);
        }     

        ResultSet rs ;
        try{
            String user="sa" ;
            String password="admin" ;
            Connection con=DriverManager.getConnection(connectDB,user,password);
            System.out.println("数据库连接成功");
            Statement stmt=con.createStatement() ;
            String query="select ROW_NUMBER()over(order by class)as row,* from dbo.[bezdekIris.data]" ;
            rs=stmt.executeQuery(query) ;
            return rs ;
        }catch(SQLException e){
            System.out.println(e) ;
            System.out.println("数据库内容读取失败");
            return null ;
        }
    }

    public static IrisData[] ResultDeal(ResultSet rs){
        IrisData[] dataset=new IrisData[150] ;
        int num=0 ;
        try {
            while((num<150)&&(rs.next())){
                double SL=Double.parseDouble(rs.getString("SepalLength")) ;
                double SW=Double.parseDouble(rs.getString("SepalWidth")) ;
                double PL=Double.parseDouble(rs.getString("PetalLength")) ;
                double PW=Double.parseDouble(rs.getString("PetalWidth")) ;
                int setnum=Integer.parseInt(rs.getString("row")) ;

                String name=rs.getString("Class") ;
                int type ;
                if(name.equals("Iris-setosa")) type=0 ;
                else if(name.equals("Iris-versicolor")) type=1 ;
                else if(name.equals("Iris-virginica")) type=2 ;
                else type=-1 ;

                dataset[num++]=new IrisData(SL,SW,PL,PW,type,setnum) ;
                //System.out.println(setnum+"       "+SL+"      "+SW+"      "+PL+"      "+PW+"      "+type) ;
            }
            System.out.println("ResultSet 解析完毕");
            return dataset ;
        } catch (SQLException e) {
            System.out.println("ResultSet 解析出错");
            System.out.println(e);
            return null ;
        }
    }
}

负责从SQL读取数据并返回IrisData[]形式。其中的getResultSet()返回的是ResultSet格式,然后由ResultDeal()处理以后返回IrisData[]形式。关于如何从SQL Server读取数据,我之前的博客里有写(也是极端新手向)


3.IrisNode

/**
 * @function the node of IrisTree
 * @author multiangle from SoutheastUniversity
 */
public class IrisNode {
    //Elements for Node itself
    public int deep ;               // the deep of the IrisNode tree 
    public double formerEntropy ;   // the entropy of the list belong to the node
    public IrisData[] datalist ;    // the data list belong to the node
    public String tag ;             // in order to research the node tree
    public int nodeType=-1 ;            //nodeType=-1 means it's not leaf node =0 means it belongs to class0(tempType) =1 belongs to class1

    public int divideType=-1 ;         // the attritube selected to divide the IrisData list
    public double valveValue=-1 ;       // the corresponding value of the attribute to divide

    //Elements for the child for the node
    public IrisNode leftChild=null ;        // the left child of the node
    public IrisNode rightChild=null ;   // the right child of the node
    public double laterEntropy=-1 ; // the total entropy of the two node after division
    public double deltaEntropy=0 ;  // the change of entropy between the ahead and after division

    //Methods in Node class
    public IrisNode(IrisData[] input,int deep,String tag){  //Construction Method

        this.tag=tag ;
        this.deep=deep ;
        this.datalist=input ;
        this.formerEntropy=getIrisDataListEntropy(input) ;
        this.nodeType=-1 ;

        if ((this.deep>5)||(this.datalist.length<2)){
            this.leftChild=this.rightChild=null ;
            int temp=decideType(this.datalist) ;
            if ((temp==0)||(temp==1)) this.nodeType=temp ;
            else System.out.println("ERROR:函数decideType输出值不合法") ;
        }else{
            Hunt hunt=new Hunt(input) ;
            this.divideType=hunt.type ;
            this.valveValue=hunt.value_value ;
            this.laterEntropy=hunt.min_entropy ;
            this.deltaEntropy=this.formerEntropy-this.laterEntropy ;

            if ((this.formerEntropy-this.laterEntropy)<0.05){
                this.leftChild=this.rightChild=null ; //if deltaEntropy<0.05 or deep>5 no longer continue
                int temp=decideType(this.datalist) ;
                if ((temp==0)||(temp==1)) this.nodeType=temp ;
                else System.out.println("ERROR:函数decideType输出值不合法") ;
            }else{
                //System.out.println("tag1") ;              //used for debug
                IrisData[] leftList=Divide(input,this.divideType,this.valveValue,0) ;
                IrisData[] rightList=Divide(input,this.divideType,this.valveValue,1) ;

                if ((leftList.length==0)||(rightList.length==0)) {
                    this.leftChild=this.rightChild=null ;
                    int temp=decideType(this.datalist) ;
                    if ((temp==0)||(temp==1)) this.nodeType=temp ;
                    else System.out.println("ERROR:函数decideType输出值不合法") ;
                }
                else{
                    this.leftChild=new IrisNode(leftList,deep+1,tag+'0') ;
                    this.rightChild=new IrisNode(rightList,deep+1,tag+'1') ;
                }
            }
        }
    } 

    public static IrisData[] Divide(IrisData[] input,int attribute,double valve,int methodtype){
        IrisData[] rs=null ;
        //通过attribute value type来将input分成两部分
        if (methodtype==0){ //此处为methodtype=1时的情况,也就是attr value<valve的情况
            int num=0 ;
            for(int i=0;i<input.length;i++){
                double tempvalue=-1 ;    //tempvalue初始值为=-1  在复用时要注意一下
                switch(attribute){
                case 0: {tempvalue=input[i].SL;break;} 
                case 1: {tempvalue=input[i].SW;break;} 
                case 2: {tempvalue=input[i].PL;break;} 
                case 3: {tempvalue=input[i].PW;break;} 
                default: System.out.println("ERROR:The value of attribute value illegal");
                }
                if(tempvalue<=valve) num++ ;
            }
            rs=new IrisData[num] ;
            int index=0 ;
            for(int i=0;i<input.length;i++){
                double tempvalue=-1 ;
                switch(attribute){
                case 0: {tempvalue=input[i].SL;break;} 
                case 1: {tempvalue=input[i].SW;break;} 
                case 2: {tempvalue=input[i].PL;break;} 
                case 3: {tempvalue=input[i].PW;break;} 
                }
                if (tempvalue<=valve) rs[index++]=input[i] ;
            }
            return rs ;
        }else if(methodtype==1){
            int num=0 ;
            for(int i=0;i<input.length;i++){
                double tempvalue=-1 ;    //tempvalue初始值为=-1  在复用时要注意一下
                switch(attribute){
                case 0: {tempvalue=input[i].SL;break;} 
                case 1: {tempvalue=input[i].SW;break;} 
                case 2: {tempvalue=input[i].PL;break;} 
                case 3: {tempvalue=input[i].PW;break;} 
                default: System.out.println("ERROR:The value of attribute value illegal");
                }
                if(tempvalue>valve) num++ ;
            }
            rs=new IrisData[num] ;
            int index=0 ;
            for(int i=0;i<input.length;i++){
                double tempvalue=-1 ;
                switch(attribute){
                case 0: {tempvalue=input[i].SL;break;} 
                case 1: {tempvalue=input[i].SW;break;} 
                case 2: {tempvalue=input[i].PL;break;} 
                case 3: {tempvalue=input[i].PW;break;} 
                }
                if (tempvalue>valve) {rs[index++]=input[i] ;}
            }
            return rs ;
        }else System.out.println("ERROR:methodtype value illegal");
        return rs ;
    }


    //------Private Method-----------------------
    private static int decideType(IrisData[] input){  //decide which class this node belongs to
        int rs=-1 ;
        int num0=0,num1=0 ;
        for(int i=0;i<input.length;i++){
            if (input[i].tempType==0) num0++ ;
            if (input[i].tempType==1) num1++ ;
        }
        if (num0<num1) rs=1 ; //有条件的话可以吧num0=num1时node的归属用随机数来实现
        else rs=0 ;  

        return rs ;
    }

    private static double getIrisDataListEntropy(IrisData[] input){ 
        DataProperty dp=new DataProperty() ;
        double rs_entropy=-1 ;
        //通过tempType的值来计算irisdata数组的熵
        //tempType只有3个值,0表示类1,1表示类2,-1表示其他类 一般用于表示异常
        int num1=0,num2=0 ;
        for(int i=0;i<input.length;i++){
            if(input[i].tempType==0) num1++ ;
            if(input[i].tempType==1) num2++ ;
        }
        rs_entropy=dp.getEntropy(num1, num2) ;
        return rs_entropy ;
    }

}


在IrisNode类的几个元素里,Node.deep表示该节点的深度。根节点的深度为0.


datalist

表示该节点内的IrisData数组。


formerEntropy

表示分类前的数组(dtalist)的熵。


tag

表示从根节点到当前节点的路径。


nodeType

表示该节点的类。nodeType=-1时表示这是个页节点。=0或1时表示这是个叶节 点,=0表示这个节点内大部分数据的分类是tempType=0;=1表示大部分节点的tempType=1 .


divideType

指最佳的分类属性,值从0-3,分别表示4个属性值。


valveValue

指最佳属性的最佳分类值。小于valveValue的分到左子树,大于valvaValue的分到右子树。


leftChild,rightChild

顾名思义,左子树和右子树。


构造函数的结构

if (节点深度>5)且(节点数组长度小于2)     //此时为叶节点
        左右子树为空
        判断叶节点的类型(属于类0还是类1else
        寻找最合适的分类属性以及分类值。(Hunt算法)
        if (分类后熵的变化小于0.05) 则左右字数为空,判断叶节点类型
        else
                根据分类属性以及分类值来划分左右子树。
                if(左子树或右子树长度为0) 则左右子树为空,判断叶节点类型
                else 以左右子树再次调用构造函数,形成递归。
                end if
        end if
end if


4.Hunt

/**
 *@author multiangle from SoutheastUniversity
 *@function Hunt Method is used to get the best attribute and the best value
 *          to divide a node into two parts 
 */
public class Hunt {

    public double min_entropy ;
    public double value_value ;
    public int type ;

    public Hunt(IrisData[] dataset){
        //1. calculate the entropy of initial dataset
        //2. find best attritube from 4 
        double[][] rs=new double[4][2] ;

        int mintype=0 ;
        double minentropy=2 ;
        double valve_value=-1 ;

        for(int i=0;i<4;i++){
            rs[i]=FindBestValve(preDeal(dataset,i)) ;

            if(rs[i][0]<minentropy){
                minentropy=rs[i][0] ;
                valve_value=rs[i][1] ;
                mintype=i ;
            }
        }
        //3. find the best one and output
        this.min_entropy=minentropy ;
        this.value_value=valve_value ;
        this.type=mintype ;
    }

    private static double[][] preDeal(IrisData[] dataset,int type){  //transfer IrisData[] to int[][] to fit the followign processing
        if ((type<4)&&(type>=0)){
            double[][] rs=new double[dataset.length][3] ; //3 attributes:Number,Attribute Value,Type
            for(int i=0;i<dataset.length;i++){
                rs[i][1]=dataset[i].SetNum ;
                rs[i][2]=dataset[i].tempType ;           //ATTENTION the taken value is tempTyoe!
                switch(type){
                case 0:{rs[i][0]=dataset[i].SL ;break ;} //0 means sepal length
                case 1:{rs[i][0]=dataset[i].SW ;break ;} //1 means sepal width
                case 2:{rs[i][0]=dataset[i].PL ;break ;} //2 means petal length
                case 3:{rs[i][0]=dataset[i].PW ;break ;} //3 means petal width
                }
            }
            return rs ;
        }else {System.out.println("ERROR:type输入值不正确");return null ;}
    }

    private static double[] FindBestValve(double[][] input){  
        //要考虑Type的多值性,最好只有两个值
        double[][] sorted=QuickSort(input,0,input.length-1) ; //1st step:sort the input array
        //接下来应该要在不同值区间内循环,挑一个熵值最小的。
        double min_entropy=2 ;
        double valve_value=-1 ;
        for(int i=0;i<sorted.length-1;i++){
            // calculate the entropy of the division whose valve is between i and i+1
            if (sorted[i][0]!=sorted[i+1][0]){      //避免在两个相同值之间分析的情况
                double temp_entropy=CalculateEntropy(sorted,i) ; 
                if (temp_entropy<min_entropy){
                    min_entropy=temp_entropy ;
                    valve_value=(sorted[i][0]+sorted[i+1][0])/2 ;
                }
            }
        }
        double[] rs=new double[2] ;
        rs[0]=min_entropy ;
        rs[1]=valve_value ;
        return rs ;
    }

    private static double CalculateEntropy(double[][] sorted,int i) {  //can only deal with the data which have only two classes
        DataProperty dp=new DataProperty() ; //initialization of dataproperty       
        double rs_entropy=-1 ;      
        int num1=0 ;
        int num2=0 ;
        for(int x=0;x<i+1;x++){
            if(sorted[x][2]==0) num1++ ;
            else if(sorted[x][2]==1) num2++ ;
            else System.out.println("ERROR from CalculateEntropy: the value of tempType of a item is -1");
        }
        double entropy1=dp.getEntropy(num1,num2) ;
        int tnum1=num1+num2 ; //total number of the former sequence

        num1=0 ;
        num2=0 ;
        for(int x=i+1;x<sorted.length;x++){
            if(sorted[x][2]==0) num1++ ;
            else if(sorted[x][2]==1) num2++ ;
            else System.out.println("ERROR from CalculateEntropy: the value of tempType of a item is -1");  
        }
        double entropy2=dp.getEntropy(num1,num2) ;
        int tnum2=num1+num2 ;
        rs_entropy=(entropy1*tnum1+entropy2*tnum2)/(tnum1+tnum2) ;
        return rs_entropy ;
    } 

    private static double[][] QuickSort(double[][] input,int low,int high){
        if(low>=high) return null ;
        int first=low ;
        int last=high ;
        double[] key=input[low] ;
        while(first<last){
            while((first<last)&&(input[last][0]>=key[0])) --last ;
            input[first]=input[last] ;
            while((first<last)&&(input[first][0]<=key[0])) ++first ;
            input[last]=input[first] ;
        }
        input[first]=key ;

        double[][] res1,res2 ;

        if (first-1>low) {res1=QuickSort(input,low,first-1) ;}
        else if(first-1==low) {double[][] temp={input[low]} ;res1=temp ;}
        else{res1=null ;}

        if(high>first+1){res2=QuickSort(input,first+1,high) ;}
        else if(high==first+1){double[][] temp={input[high]} ;res2=temp ;}
        else{res2=null ;}

        double[][] finalres ;
        finalres=Combine(res1,res2,key) ;
        return finalres ;
    }
    private static double[][] Combine(double[][] res1,double[][] res2,double[] key){
        int len1,len2 ;
        if(res1==null) len1=0 ;
        else len1=res1.length ;
        if(res2==null) len2=0 ;
        else len2=res2.length ;

        double[][] res=new double[len1+len2+1][3] ;
        int index=0 ;
        for(int i=0;i<len1;i++) res[index++]=res1[i] ;
        res[index++]=key ;
        for(int i=0;i<len2;i++) res[index++]=res2[i] ;
        return res ;
    }

    //-------调试用函数----------------------
    private static void print(double[][] input){     used for debug
        if(input!=null){
            int len1=input.length ;
            int len2=input[0].length ;
            for(int i=0;i<len1;i++){
                for(int j=0;j<len2;j++){
                    System.out.print(input[i][j]+"\t");
                }
                System.out.print('\n');
            }
        }else System.out.println("ERROR:输入二维数组为空") ;

    }
}   

Hunt类有3个值,type指的是计算得到的最优的分类属性,valve_value指的是相应的最优分类值,min_entropy指的是分类以后的熵。

为了构造的函数,主要就是先用preDeal函数将假设的属性值与类绑定到一个二维数组中(这么做是为了增加通用性)然后使用FindBestValue来寻找该属性中最优值。即首先将这些元素排序,然后依次检验元素之间的差值对应的熵,从中选出最小的熵作为BestValue。这样,就能得到这个属性中最优差值对应的最小的熵。把4个属性对应的4个最小熵对比,就能得到熵最小的那个类,作为最优分类属性


5.Estimate

import java.util.*;

public class Estimate {
    ArrayList<IrisData> list0 ;
    ArrayList<IrisData> list1 ;
    IrisNode examtree ;
    double ErrorRatio ;

    public Estimate(IrisNode rule,IrisData[] examset){
        this.list0=new ArrayList() ;
        this.list1=new ArrayList() ;
        this.examtree=examTree(rule,examset) ;
        this.ErrorRatio=getErrorRatio(this.examtree) ;
    }

    private double getErrorRatio(IrisNode node){
        if (node.datalist.length==0) return 0 ;
        if(node.nodeType==-1){
            double len1=0,len2=0 ;
            double ratio1=1,ratio2=1 ;
            if(node.leftChild==null) len1=0 ;
            else {
                len1=node.leftChild.datalist.length ;
                ratio1=getErrorRatio(node.leftChild) ;
            }
            if(node.rightChild==null) len2=0 ;
            else {
                len2=node.rightChild.datalist.length ;
                ratio2=getErrorRatio(node.rightChild) ;
            }
            double noderatio=(len1*ratio1+len2*ratio2)/(len1+len2) ;
            return noderatio ;
        }else{
            if(node.nodeType==0){
                double len=node.datalist.length ;
                double num=0 ;
                for(int i=0;i<len;i++){
                    this.list0.add(node.datalist[i]) ;
                    if(node.datalist[i].tempType==1) num++ ;
                }
                double noderatio=num/len ;
                return noderatio ;
            }
            if(node.nodeType==1){
                double len=node.datalist.length ;
                double num=0 ;
                for(int i=0;i<len;i++){
                    this.list1.add(node.datalist[i]) ;
                    if(node.datalist[i].tempType==0) num++ ;
                }
                double noderatio=num/len ;
                return noderatio ;
            }
            return -1 ;
        }
    }

    private IrisNode examTree(IrisNode node,IrisData[] data){
        node.datalist=data ;
        node.formerEntropy=getIrisDataListEntropy(data) ;
        if (node.nodeType==-1) { //this node is not a leaf node
            IrisData[] left=IrisNode.Divide(data, node.divideType, node.valveValue, 0) ;
            IrisData[] right=IrisNode.Divide(data, node.divideType, node.valveValue, 1) ;
            if (left.length==0) node.leftChild=null ;
            else node.leftChild=examTree(node.leftChild,left) ;
            if(right.length==0) node.rightChild=null ;
            else node.rightChild=examTree(node.rightChild,right) ;
            return node ;
        }else{    // this node is a leaf node
            node.leftChild=null ;
            node.rightChild=null ;
            return node ;
        }
    }

    public double getFinalEntropy(IrisNode input){
        double rs=-1 ;
        if ((input.leftChild==null)||(input.rightChild==null)){
            rs=getIrisDataListEntropy(input.datalist) ;
            return rs ;
        }else{
            double rs_1=getFinalEntropy(input.leftChild) ;
            double len1=input.leftChild.datalist.length ;
            double rs_2=getFinalEntropy(input.rightChild) ;
            double len2=input.rightChild.datalist.length ;
            rs=(rs_1*len1+rs_2*len2)/(len1+len2) ;
            return rs ;
        }
    }

    //Private Methods
    private static double getIrisDataListEntropy(IrisData[] input){ 
        DataProperty dp=new DataProperty() ;
        double rs_entropy=-1 ;
        //通过tempType的值来计算irisdata数组的熵
        //tempType只有3个值,0表示类1,1表示类2,-1表示其他类 一般用于表示异常
        int num1=0,num2=0 ;
        for(int i=0;i<input.length;i++){
            if(input[i].tempType==0) num1++ ;
            if(input[i].tempType==1) num2++ ;
        }
        rs_entropy=dp.getEntropy(num1, num2) ;
        return rs_entropy ;
    }

}

这个类主要用于评估生成的决策树的性能,目前不是很完善,只能计算检验集使用该决策树的错误率。

examTree将检验集应用于之前生成的决策树,重新生成节点的左右子树,对于空的子树,进行封闭处理。

getErrorRatio利用递归的方法计算错误率。


6.ExcelPrint

import java.io.File;

import jxl.Workbook;
import jxl.write.Label;
import jxl.write.WritableSheet;
import jxl.write.WritableWorkbook;


public class ExcelPrint {
    public ExcelPrint(){

    }

    public void PrintIrisDataArray(IrisData[] input,String filename){
        try{

            String rootname="C:\\Users\\multiangle\\Desktop\\" ;  
            String path=rootname+filename+".xls" ;
            File file=new File(path) ;
            WritableSheet sheet ;
            WritableWorkbook book ;
            if (file.exists()) {
                Workbook wb=Workbook.getWorkbook(file) ;
                book=Workbook.createWorkbook(file, wb) ;
                int sheetnum=book.getNumberOfSheets() ;
                sheet=book.createSheet("第"+sheetnum+"页", sheetnum) ;
                System.out.println("正在第"+sheetnum+"页打印IrisData数组");
            }else {
                book=Workbook.createWorkbook(new File(path)) ;
                sheet=book.createSheet("第0页", 0) ;
                System.out.println("正在第0页打印IrisData数组");
            }
            //System.out.println("已获取到需要的表单");

            String[] name={"SetNum","Sepal Length","Sepan Width","Petal Length","Petal Width","Type","tempType"} ;
            for(int i=0;i<7;i++){
                Label temp=new Label(i,0,name[i]) ;
                sheet.addCell(temp);
            }

            int len=input.length ;
            int row=1 ;
            for(int i=0;i<len;i++){
                int col=0 ;
                Label cell1=new Label(col++,row,String.valueOf(input[i].SetNum)) ;
                Label cell2=new Label(col++,row,String.valueOf(input[i].SL)) ;
                Label cell3=new Label(col++,row,String.valueOf(input[i].SW)) ;
                Label cell4=new Label(col++,row,String.valueOf(input[i].PL)) ;
                Label cell5=new Label(col++,row,String.valueOf(input[i].PW)) ;
                Label cell6=new Label(col++,row,String.valueOf(input[i].Type)) ;
                Label cell7=new Label(col++,row,String.valueOf(input[i].tempType)) ;
                sheet.addCell(cell1);
                sheet.addCell(cell2);
                sheet.addCell(cell3);
                sheet.addCell(cell4);
                sheet.addCell(cell5);
                sheet.addCell(cell6);
                sheet.addCell(cell7);

                row++ ;
            }
            book.write() ;
            book.close(); 
        }catch(Exception e){
            System.out.println(e) ;
            System.out.println("ERROR:ExcelPrint") ;
        }

    }

    public void PrintIrisDataArray(IrisData[] input,String filename,String description){
        try{

            String rootname="C:\\Users\\multiangle\\Desktop\\" ;  
            String path=rootname+filename+".xls" ;
            File file=new File(path) ;
            WritableSheet sheet ;
            WritableWorkbook book ;
            if (file.exists()) {
                Workbook wb=Workbook.getWorkbook(file) ;
                book=Workbook.createWorkbook(file, wb) ;
                int sheetnum=book.getNumberOfSheets() ;
                sheet=book.createSheet("第"+sheetnum+"页", sheetnum) ;
                System.out.println("正在第"+sheetnum+"页打印IrisData数组");
            }else {
                book=Workbook.createWorkbook(new File(path)) ;
                sheet=book.createSheet("第0页", 0) ;
                System.out.println("正在第0页打印IrisData数组");
            }
            //System.out.println("已获取到需要的表单");

            Label descrip=new Label(0,0,description) ;
            sheet.addCell(descrip) ;
            String[] name={"SetNum","Sepal Length","Sepan Width","Petal Length","Petal Width","Type","tempType"} ;
            for(int i=0;i<7;i++){
                Label temp=new Label(i,1,name[i]) ;
                sheet.addCell(temp);
            }

            int len=input.length ;
            int row=2 ;
            for(int i=0;i<len;i++){
                int col=0 ;
                Label cell1=new Label(col++,row,String.valueOf(input[i].SetNum)) ;
                Label cell2=new Label(col++,row,String.valueOf(input[i].SL)) ;
                Label cell3=new Label(col++,row,String.valueOf(input[i].SW)) ;
                Label cell4=new Label(col++,row,String.valueOf(input[i].PL)) ;
                Label cell5=new Label(col++,row,String.valueOf(input[i].PW)) ;
                Label cell6=new Label(col++,row,String.valueOf(input[i].Type)) ;
                Label cell7=new Label(col++,row,String.valueOf(input[i].tempType)) ;
                sheet.addCell(cell1);
                sheet.addCell(cell2);
                sheet.addCell(cell3);
                sheet.addCell(cell4);
                sheet.addCell(cell5);
                sheet.addCell(cell6);
                sheet.addCell(cell7);

                row++ ;
            }
            book.write() ;
            book.close(); 
        }catch(Exception e){
            System.out.println(e) ;
            System.out.println("ERROR:ExcelPrint") ;
        }

    }

    public void PrintIrisData(IrisData[] input,String filename,int sheetnum){ //要注意 input sheet num>current sheet num+1的情况可能出现的BUG
        //待定
    }

    private static boolean CreateExcel(String filename){
        try{
            WritableWorkbook book=Workbook.createWorkbook(new File(filename));   //打开文件 
            WritableSheet sheet=book.createSheet("FirstPage",0) ;  //生成名为“FirstPage”的工作表,参数0表示这是第一页 
            Label label=new Label(0,0,"") ;
            sheet.addCell(label);
            book.write();
            book.close();
            return true;
        }catch(Exception e){
            System.out.println("ERROR:CreateExcel") ;
            return false ;
        }
    }

    private static boolean CreateExcel(String filename,String sheetname){
        try{
            WritableWorkbook book=Workbook.createWorkbook(new File(filename));   //打开文件
            WritableSheet sheet=book.createSheet(sheetname,0) ;  //生成名为sheetname的工作表,参数0表示这是第一页 
            Label label=new Label(0,0,"") ;
            sheet.addCell(label);
            book.write();
            book.close(); 
            return true ;
        }catch(Exception e){
            System.out.println("ERROR:CreateExcel") ;
            return false ;
        }
    }

    private static void WriteExcel_Cell(String filename,String value,int cow,int row){
        try {
            // Excel获得文件
            System.out.println("0");
            File file=new File(filename) ;
            Workbook wb = Workbook.getWorkbook(file);
            // 打开一个文件的副本,并且指定数据写回到原文件
            WritableWorkbook book = Workbook.createWorkbook(file,wb);
            System.out.println("1");
            // 添加一个工作表
            WritableSheet sheet = book.getSheet(0) ;  //book.createSheet("第二页", 1); 
            sheet.addCell(new Label(cow, row, value));
            book.write();
            book.close();
        } catch (Exception e) {
            System.out.println(e);
            System.out.println("ERROR:写入失败");
        }
    }



}

用于将IrisData数组输出。这个基本按照之前的Excel操作就能完成,没啥可说的。需要外部类jxl.jar


7.DataProperty


public class DataProperty {

    public double getGini(int[] data){
        int len=data.length ;
        double sum=0 ;
        for(int i=0;i<len;i++)  sum+=data[i] ;
        double pre_gini=0 ;
        for(int i=0;i<len;i++)  pre_gini+= (data[i]/sum)*(data[i]/sum) ;
        double gini=1-pre_gini ;
        return gini ;
    }
    public double getGini(int a,int b){
        double c=a+b ;
        double gini=1-(a/c)*(a/c)-(b/c)*(b/c) ;
        return gini ;
    }

    public double getEntropy(int[] data){
        int len=data.length ;
        double sum=0 ;
        for(int i=0;i<len;i++)  sum+=data[i] ;  //get the summary of all data
        double pre_entro=0 ;
        for(int i=0;i<len;i++) {
            if (data[i]!=0){
                pre_entro+=(data[i]/sum)*Math.log(data[i]/sum)/Math.log(2) ;
            }
        }
        double entro=-pre_entro ;
        return entro ;
    }
    public double getEntropy(int ina,int inb){
        double a=(double)ina ;
        double b=(double)inb ;
        double entro ;
        if((a*b)!=0){
            double c=a+b ;
            double a1=(a/c)*mathLog2(a/c) ;
            double b1=(b/c)*mathLog2(b/c) ;
            entro=-a1-b1 ;
            return entro ;
        }else{  
            entro=0 ;
            return entro ;
        }
    }

    //inner methods----------------------------------------------------
    private static double mathLog(double data,double bottom){
        return Math.log(data)/Math.log(bottom) ;
    }
    private static double mathLog2(double data){
        return Math.log(data)/Math.log(2) ;
    }
}

辅助类,用来计算数组的熵,或者两个数的熵


8.test

/**
 *@author multiangle from SoutheastUniversity
 *@function Hunt Method is used to get the best attribute and the best value
 *          to divide a node into two parts 
 */
public class test {
    public static void main(String[] args){


        IrisData[] dataset=new IrisInfoGet().dataset ;

        ExcelPrint ep=new ExcelPrint() ;

        ep.PrintIrisDataArray(dataset, "test","从数据库搬来的原始数据");

        System.out.println("首先将0与1,2分开来");
        for(int i=0;i<dataset.length;i++){
            if (dataset[i].Type==0) dataset[i].tempType=0 ;
            else dataset[i].tempType=1 ;
        }
        ep.PrintIrisDataArray(dataset, "test","加tempType以后的的dataset");

        //此处完成训练集和检验集的完成。训练集的比例可能比较低,
        //这个不是问题,以后可以随意更改的。
        IrisData[] trainset=new IrisData[75] ;  
        IrisData[] examset=new IrisData[75] ;   
        int index=0 ;
        for(int i=0;i<75;i++){
            trainset[index]=dataset[index*2] ;
            examset[index]=dataset[index*2+1] ;
            index++ ;
        }
        ep.PrintIrisDataArray(trainset, "test", "训练集trainset");
        ep.PrintIrisDataArray(examset, "test", "检验集examset");

        IrisNode root=new IrisNode(trainset,0,"0") ;
        IrisNode node1=root ;
        Estimate es=new Estimate(node1,examset) ;
        System.out.println("得到的决策树使用检验集检验的错误率是: "+String.valueOf(es.ErrorRatio));

        System.out.println("然后将1与2分开来");
        IrisData[] ts2=root.rightChild.datalist ;
        for(int i=0;i<ts2.length;i++){
            if (ts2[i].Type==1) ts2[i].tempType=0 ;
            if (ts2[i].Type==2) ts2[i].tempType=1 ;
        }
        int es2len=0 ;
        for(int i=0;i<examset.length;i++){
            if((examset[i].Type==1)||(examset[i].Type==2)) es2len++ ;
        }
        IrisData[] es2=new IrisData[es2len] ;
        int es2index=0 ;
        for(int i=0;i<examset.length;i++){
            if (examset[i].Type==1) {
                es2[es2index]=examset[i] ;
                es2[es2index++].tempType=0 ;
            }
            if (examset[i].Type==2) {
                es2[es2index]=examset[i] ;
                es2[es2index++].tempType=1 ;
            }
        }
        IrisNode r2=new IrisNode(ts2,0,"0") ;
        Estimate estimate2=new Estimate(r2,es2) ;
        System.out.println("得到的决策树使用检验集检验的错误率是: "+String.valueOf(estimate2.ErrorRatio));

        IrisData[] es2class0=new IrisData[estimate2.list0.size()] ;
        for(int i=0;i<estimate2.list0.size();i++) es2class0[i]=estimate2.list0.get(i) ;
        ep.PrintIrisDataArray(es2class0, "test", "判定为typd1的标本");
        IrisData[] es2class1=new IrisData[estimate2.list1.size()] ;
        for(int i=0;i<estimate2.list1.size();i++) es2class1[i]=estimate2.list1.get(i) ;
        ep.PrintIrisDataArray(es2class1, "test", "判定为typd2的标本");

    }
}   

这个类是整套程序的入口。


IrisData[] dataset=new IrisInfoGet().dataset ;

//负责从数据库得到IrisData数组


ExcelPrint ep=new ExcelPrint() ;

ep.ep.PrintIrisDataArray(IrisData[] name, String filename, String description);

//负责将IrisData数组打印到桌面的filename中,description在第一行作为注释


IrisNode root=new IrisNode(IrisData[] setname, int initialdeep, String tag) ;

//输入IrisData数组,初始深度以及初始标签,得到根节点


Estimate es=new Estimate(IrisNode rule,IrisData[] examset) ;

//利用生成的决策树以及检验集,可以得到Estimate实例es, es.list0存储定义为类0的数据,es.list1存储定义为1的数据; es.ErrorRatio 为检测结果的错误率。



结语

:写这篇文章也算是我自己为这几天写的这个程序的回顾,发现还是有不少问题的,比如结构不够紧凑,分类太多,注释中英文夹杂不规范,变量名使用不规范,以及还有很多功能没有实现等等。这些都是以后有待改进的地方



版权声明:本文为u014595019原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。