Opencv2.4.9源码分析——Extremely randomized trees

  • Post author:
  • Post category:其他


一、原理

ET或Extra-Trees(

Ext

remely

ra

ndomized trees,极端随机树)是由PierreGeurts等人于2006年提出。该算法与随机森林算法十分相似,都是由许多决策树构成。但该算法与随机森林有两点主要的区别:

1、随机森林应用的是Bagging模型,而ET是使用所有的训练样本得到每棵决策树,也就是每棵决策树应用的是相同的全部训练样本;

2、随机森林是在一个随机子集内得到最佳分叉属性,而ET是完全随机的得到分叉值,从而实现对决策树进行分叉的。

对于第2点的不同,我们再做详细的介绍。我们仅以二叉树为例,当特征属性是类别的形式时,随机选择具有某些类别的样本为左分支,而把具有其他类别的样本作为右分支;当特征属性是数值的形式时,随机选择一个处于该特征属性的最大值和最小值之间的任意数,当样本的该特征属性值大于该值时,作为左分支,当小于该值时,作为右分支。这样就实现了在该特征属性下把样本随机分配到两个分支上的目的。然后计算此时的分叉值(如果特征属性是类别的形式,可以应用基尼指数;如果特征属性是数值的形式,可以应用均方误差)。遍历节点内的所有特征属性,按上述方法得到所有特征属性的分叉值,我们选择分叉值最大的那种形式实现对该节点的分叉。从上面的介绍可以看出,这种方法比随机森林的随机性更强。

对于某棵决策树,由于它的最佳分叉属性是随机选择的,因此用它的预测结果往往是不准确的,但多棵决策树组合在一起,就可以达到很好的预测效果。

当ET构建好了以后,我们也可以应用全部的训练样本来得到该ET的预测误差。这是因为尽管构建决策树和预测应用的是同一个训练样本集,但由于最佳分叉属性是随机选择的,所以我们仍然会得到完全不同的预测结果,用该预测结果就可以与样本的真实响应值比较,从而得到预测误差。如果与随机森林相类比的话,在ET中,全部训练样本都是OOB样本,所以计算ET的预测误差,也就是计算这个OOB误差。

在这里,我们仅仅介绍了ET算法与随机森林的不同之处,ET算法的其他内容(如预测、OOB误差的计算)与随机森林是完全相同的,具体内容请看关于随机森林的介绍。

二、源码分析

下面是ET算法的类CvERTrees,它继承于CvRTrees类:

class CV_EXPORTS_W CvERTrees : public CvRTrees
{
public:
    CV_WRAP CvERTrees();
    virtual ~CvERTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams());
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams());
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
protected:
    virtual std::string getName() const;
    virtual bool grow_forest( const CvTermCriteria term_crit );
};

我们从CvERTrees类可以看出,它没有预测函数predict,因此,如果要进行ET的预测,调用的是它的父类CvRTrees内的predict函数。在训练样本时,CvERTrees类与CvRTrees类的训练过程是完全一致的,即在train函数内再调用grow_forest函数,而且两个类的train函数的输入参数的形式也是完全一样的。但在grow_forest函数内会有一点不同,那就是CvERTrees类中的grow_forest函数把全体训练样本都当成OOB样本,因此它不需要随机样本掩码矩阵变量sample_idx_mask_for_tree,而表示样本索引值变量的sample_idx_for_tree保存的就是正常顺序的训练样本的索引值。

ET算法与随机森林算法最大的不同就在于节点的分叉上,而这一点是体现在CvForestERTree类上的:

class CV_EXPORTS CvForestERTree : public CvForestTree
{
protected:
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void split_node_data( CvDTreeNode* n );
};

CvForestERTree类定义了一些专用于ET算法的计算分叉、得到最佳分叉属性的函数,下面我们就逐一介绍这些函数。

按最佳分叉属性标注该节点的所有样本是被分配到左分支还是右分支:

double CvForestERTree::calc_node_dir( CvDTreeNode* node )
{
    //表示特征属性的种类是属于左分支还是右分支,-1为左分支,1为右分支,如果该特征属性缺失,则为0
    char* dir = (char*)data->direction->data.ptr;
    //n表示该节点的样本数量,vi表示分类的最佳特征属性
    int i, n = node->sample_count, vi = node->split->var_idx;
    double L, R;

    assert( !node->split->inversed );    //确保分叉不反转

    if( data->get_var_type(vi) >= 0 ) // split on categorical var
    //表示该特征属性是种类的形式
    {
        //开辟一块内存空间
        cv::AutoBuffer<uchar> inn_buf(n*sizeof(int)*(!data->have_priors ? 1 : 2));
        int* labels_buf = (int*)(uchar*)inn_buf;
        //labels指向该特征属性中各个样本所对应的种类,get_cat_var_data函数在ER算法中被重新定义
        const int* labels = data->get_cat_var_data( node, vi, labels_buf );
        // subset数组的每一位表示特征属性的种类,左分支的种类位是1,右分支的是0
        const int* subset = node->split->subset;
        if( !data->have_priors )    //无先验概率
        {
            int sum = 0, sum_abs = 0;

            for( i = 0; i < n; i++ )    //遍历该节点的所有样本
            {
                int idx = labels[i];    //表示该样本的特征属性的种类
                //d为-1表示idx(特征属性的种类)属于左分支,为1表示属于右分支,如果没有该特征属性,则d为0
                int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
                    CV_DTREE_CAT_DIR(idx,subset) : 0;
                //sum表示d累加求和,因为d也可能为负值,所以sum的含义为右分支比左分支多出的特征属性种类;sum_abs表示d的绝对值之和,表示的含义为被分叉的特征属性种类
                sum += d; sum_abs += d & 1;
                dir[i] = (char)d;    //赋值
            }
            //L和R分别表示左右分支的特征属性的种类数量
            R = (sum_abs + sum) >> 1;
            L = (sum_abs - sum) >> 1;
        }
        else    //有先验概率
        {
            const double* priors = data->priors_mult->data.db;    //得到先验概率
            double sum = 0, sum_abs = 0;
            int *responses_buf = labels_buf + n;
            //responses指向该节点样本的分类,即响应值
            const int* responses = data->get_class_labels(node, responses_buf);

            for( i = 0; i < n; i++ )    //遍历该节点的所有样本
            {
                int idx = labels[i];    //表示该样本的特征属性的种类
                double w = priors[responses[i]];    //得到该响应值的先验概率
                //d为-1表示idx(特征属性的种类)属于左分支,为1表示属于右分支,如果没有该特征属性,则d为0
                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
                sum += d*w; sum_abs += (d & 1)*w;    //增加了先验概率
                dir[i] = (char)d;
            }
            //L和R分别表示左右分支的特征属性的种类数量
            R = (sum_abs + sum) * 0.5;
            L = (sum_abs - sum) * 0.5;
        }
    }
    else // split on ordered var
    //表示该特征属性是数值的形式
    {
        //得到分叉属性的值split_val,如果样本的分叉属性的值小于该值,则被分叉为左节点,否则为右节点
        float split_val = node->split->ord.c;
        //为该特征属性开辟一块内存空间
        cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(!data->have_priors ? 1 : 2) + sizeof(float)));
        float* val_buf = (float*)(uchar*)inn_buf;    //用于存储各个样本当前特征属性的值
        int* missing_buf = (int*)(val_buf + n);    //表示各个样本是否缺失当前特征属性
        const float* val = 0;    //指向数组val_buf
        const int* missing = 0;    //指向数组missing_buf
        // get_ord_var_data函数在ER算法中被重新定义,各个样本的vi特征属性的值存储在val_buf数组内,各个样本是否缺失该特征属性用missing_buf数组表示
        data->get_ord_var_data( node, vi, val_buf, missing_buf, &val, &missing, 0 );

        if( !data->have_priors )    //无先验概率
        {
            L = R = 0;
            for( i = 0; i < n; i++ )    //遍历所有样本
            {
                if ( missing[i] )    //该样本缺失vi这个特征属性
                    dir[i] = (char)0;    //方向信息赋值为0
                else
                {
                    if ( val[i] < split_val)    //左分支
                    {
                        dir[i] = (char)-1;    //方向信息赋值为-1
                        L++;    //左分支计数
                    }
                    else    //右分支
                    {
                        dir[i] = (char)1;    //方向信息赋值为1
                        R++;    //右分支计数
                    }
                }
            }
        }
        else    //有先验概率
  



版权声明:本文为zhaocj原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。