首页 > 技术文章 > ID3和C4.5的理论和应用

yutingmoran 2017-07-17 09:58 原文

 

本文是汽车评估系统的核心算法,利用决策树进行分类,本文对决策树进行了介绍,同时比较C4.5和ID3算法的不同,对C4.5提出随机深林的想法提高分类预测的准确性。

关键词:汽车评估,决策树,C4.5

决策树(Decision tree

它从一组无次序、无规则的元组中推理出决策树表示形式的分类规则。它采用自顶向下的递归方式,在决策树的内部结点进行属性值的比较,并根据不同的属性值从该结点向下分支,叶结点是要学习划分的类。从根到叶结点的一条路径就对应着一条合取规则,整个决策树就对应着一组析取表达式规则。1986年 Quinlan提出了著名的ID3算法。在ID3算法的基础上,1993年Quinlan又提出了C4.5算法。为了适应处理大规模数据集的需要,后来又提出了若干改进的算法,其中SLIQ(super-vised learning in quest)和SPRINT(scalable parallelizableinduction of decision trees)是比较有代表性的两个算法,本文主要用C4.5算法对汽车评估系统建立分类决策树,同时比较C4.5和ID3算法的不同,对C4.5提出随机深林的想法提高分类预测的准确性。

信息熵的含义及分类

信息熵是信息论中的一个重要的指标,是由香农在1948年提出的。香农借用了热力学中熵的概念来描述信息的不确定性。因此信息学中的熵和热力学的熵是有联系的。根据Charles H. Bennett对Maxwell’s Demon的重新解释,对信息的销毁是一个不可逆过程,所以销毁信息是符合热力学第二定律的。而产生信息,则是为系统引入负(热力学)熵的过程。所以信息熵的符号与热力学熵应该是相反的。

简单的说信息熵是衡量信息的指标,更确切的说是衡量信息的不确定性或混乱程度的指标。信息的不确定性越大,熵越大。决定信息的不确定性或者说复杂程度主要因素是概率。决策树中使用的与熵有关的概念有三个:信息熵,条件熵和互信息。下面分别来介绍这三个概念的含义和计算方法。

1信息熵

信息熵是用来衡量一元模型中信息不确定性的指标。信息的不确定性越大,熵的值也就越大。而影响熵值的主要因素是概率。这里所说的一元模型就是指单一事件,而不确定性是一个事件出现不同结果的可能性。例如抛硬币,可能出现的结果有两个,分别是正面和反面。而每次抛硬币的结果是一个非常不确定的信息。因为根据我们的经验或者历史数据来看,一个均匀的硬币出现正面和反面的概率相等,都是50%。因此很难判断下一次出现的是正面还是反面。这时抛硬币这个事件的熵值也很高。而如果历史数据告诉我们这枚硬币在过去的100次试验中99次都是正面,也就是说这枚硬币的质量不均匀,出现正面结果的概率很高。那么我们就很容易判断下一次的结果了。这时的熵值很低,只有0.08。

 

我们把抛硬币这个事件看做一个随机变量S,它可能的取值有2种,分别是正面x1和反面x2。每一种取值的概率分别为P1和P2。我们要获得随机变量S的取值结果至少要进行1次试验,试验次数与随机变量S可能的取值数量(2种)的对数函数Log有联系。Log2=1(以2为底),其计算公式是:

Pi为子集合中不同性(而二元分类即正样例和负样例)的样例的比例。

在抛硬币的例子中,我们借助一元模型自身的概率,也就是前100次的历史数据来消除了判断结果的不确定性。而对于很多现实生活中的问题,则无法仅仅通过自身概率来判断。例如:对于天气情况,我们无法像抛硬币一样通过晴天,雨天和雾霾在历史数据中出现的概率来判断明天的天气,因为天气的种类很多,并且影响天气的因素也有很多。同理,对于网站的用户我们也无法通过他们的历史购买频率来判断这个用户在下一次访问时是否会完成购买。因为用户是的购买行为存在着不确定性,要消除这些不确定性需要更多的信息。例如用户历史行为中的广告创意,促销活动,商品价格,配送时间等信息。因此这里我们不能只借助一元模型来进行判断和预测了,需要获得更多的信息并通过二元模型或更高阶的模型了解用户的购买行为与其他因素间的关系来消除不确定性。衡量这种关系的指标叫做条件熵。

2、条件熵

条件熵是通过获得更多的信息来消除一元模型中的不确定性。也就是通过二元或多元模型来降低一元模型的熵。我们知道的信息越多,信息的不确定性越小。例如,只使用一元模型时我们无法根据用户历史数据中的购买频率来判断这个用户本次是否也会购买。因为不确定性太大。在加入了促销活动,商品价格等信息后,在二元模型中我们可以发现用户购买与促销活动,或者商品价格变化之间的联系。并通过购买与促销活动一起出现的概率,和不同促销活动时购买出现的概率来降低不确定性。

 

计算条件熵时使用到了两种概率,分别是购买与促销活动的联合概率P(c),和不同促销活动出现时购买也出现的条件概率E(c)。以下是条件熵E(T,X)的计算公式。条件熵的值越低说明二元模型的不确定性越小。

3互信息(信息增益)

互信息是用来衡量信息之间相关性的指标。当两个信息完全相关时,互信息为1,不相关时为0。在前面的例子中用户购买与促销活动这两个信息间的相关性究竟有多高,我们可以通过互信息这个指标来度量。具体的计算方法就熵与条件熵之间的差。用户购买的熵E(T)减去促销活动出现时用户购买的熵E(T,X)。以下为计算公式:

熵,条件熵和互信息是构建决策树的三个关键的指标。下面我们将通过信息增益划分决策树,即ID3算法。

ID3算法

在决策树各级节点选择属性时,以信息熵增益(Information gain)作为属性的选择标准,在检测所有属性值时,选择信息熵增益最大的属性产生决策树的节点,由该属性的不同取值作为分支,然后递归建立分支,最后等到一个完整的决策树,可以用来对新的样本进行分类。

信息熵增益定义为样本按照某属性划分时造成熵减少的期望,可以区分训练样本中正负样本的能力,其公式是:

 

ID3算法存在的缺点

1)ID3算法在选择根节点和各内部节点中的分支属性时,采用信息增益作为评价标准。信息增益的缺点是倾向于选择取值较多的属性,在有些情况下这类属性可能不会提供太多有价值的信息。

2)ID3算法只能对描述属性为离散型属性的数据集构造决策树。

C4.5算法做出的改进

(1)用信息增益率来选择属性,选择信息增益率大的产生决策树的节点,克服了用信息增益来选择属性时偏向选择值多的属性的不足。信息增益率定义为:

 

分子表示信息增益,和ID3算法一样,分母表示分裂因子,代表按照属性A分裂样本集S的广度和均匀性。分裂因子公式如下:

 

(2)可以处理连续数值型属性

C4.5既可以处理离散型描述属性,也可以处理连续性描述属性。在选择某节点上的分枝属性时,对于离散型描述属性,C4.5的处理方法与ID3相同,按照该属性本身的取值个数进行计算;对于某个连续性描述属性Ac,假设在某个结点上的数据集的样本数量为total,C4.5将作以下处理。

Ø 将该结点上的所有数据样本按照连续型描述属性的具体数值,由小到大进行排序,得到属性值的取值序列{A1c,A2c,……Atotalc}。

Ø 在取值序列中生成total-1个分割点。第i(0<i<total)个分割点的取值设置为Vi=(Aic+A(i+1)c)/2,它可以将该节点上的数据集划分为两个子集。

Ø total-1个分割点中选择最佳分割点。对于每一个分割点划分数据集的方式,C4.5计算它的信息增益,并且从中选择信息增益比最大的分割点来划分数据集。

 

(3)采用悲观剪枝(Pessimistic Error Pruning (PEP)),避免树的高度无节制的增长,避免过度拟合数据。

  PEP后剪枝技术是由大师Quinlan提出的。它不需要像REP(错误率降低修剪)样,需要用部分样本作为测试数据,而是完全使用训练数据来生成决策树,又用这些训练数据来完成剪枝。决策树生成和剪枝都使用训练集, 所以会产生错分。现在我们先来介绍几个定义:

符号

含义

T1

决策树T的所有内部节点(非叶子节点)

T2

决策树T的所有叶子节点

T3

决策树T的所有节点,T3=T1∪T2

n(t)

节点t的所有样本数

ni(t)

节点t中类别i的所有样本数

e(t)

t中不属于节点t所标识类别的样本数

 

  在剪枝时,我们使用r(t)=e(t)/n(t),就是当节点被剪枝后在训练集上的错误率,而下面公式表示具体的计算公式,其中s为t节点的叶子节点。

  在此,我们把错误分布看成是二项式分布,由上面“二项分布的正态逼近”相关介绍知道,上面的式子是有偏差的,因此需要连续性修正因子来矫正数据, r‘(t)=[e(t) + 1/2]/n(t)

其中s为t节点的叶子节点,你不认识的那个符号为 t的所有叶子节点的数目。

  为了简单,我们就只使用错误数目而不是错误率了,如下e'(t) = [e(t) + 1/2]:

  接着求e'(Tt)的标准差,由于误差近似看成是二项式分布,根据u = np, σ2=npq可以得到

  

  当节点t满足下面公式是,Tt子树就会被剪掉:

(4)对于缺失值的处理

在某些情况下,可供使用的数据可能缺少某些属性的值。假如〈x,c(x)〉是样本集S中的一个训练实例,但是其属性A的值A(x)未知。处理缺少属性值的一种策略是赋给它结点n所对应的训练实例中该属性的最常见值;另外一种更复杂的策略是为A的每个可能值赋予一个概率。例如,给定一个布尔属性A,如果结点n包含6个已知A=1和4个A=0的实例,那么A(x)=1的概率是0.6,而A(x)=0的概率是0.4。于是,实例x的60%被分配到A=1的分支,40%被分配到另一个分支。这些片断样例(fractional examples)的目的是计算信息增益,另外,如果有第二个缺少值的属性必须被测试,这些样例可以在后继的树分支中被进一步细分。C4.5就是使用这种方法处理缺少的属性值。

C4.5算法的优缺点

优点:产生的分类规则易于理解,准确率较高。

缺点:在构造树的过程中,需要对数据集进行多次的顺序扫描和排序,因而导致算法的低效。此外,C4.5只适合于能够驻留于内存的数据集,当训练集大得无法在内存容纳时程序无法运行。

汽车评估系统决策树的建立

本文所用数据来自某汽车评估系统的一部分,下载地址:http://archive.ics.uci.edu/ml/machine-learning-databases/car/

训练集的数据表示法

变量

变量取值域

变量含义

buying

vhigh, high, med, low

购买价格

maint

vhigh, high, med, low.

维修价格

doors

2, 3, 4, 5more

门有多少

persons

2, 4, more

载人数

lug_boot

small, med, big

载行李能力

safety

low, med, high

安全性

将数据可以分为四类:unacc, acc, good, vgood 

部分训练数据集

buying

maint

doors

persons

lug_boot

safety

class

med

med

5more

more

small

low

unacc

med

med

5more

more

small

med

acc

med

med

5more

more

small

high

acc

med

med

5more

more

med

low

unacc

med

med

5more

more

med

med

acc

med

med

5more

more

med

high

vgood

med

med

5more

more

big

low

unacc

med

med

5more

more

big

med

acc

med

med

5more

more

big

high

vgood

med

low

2

2

med

low

unacc

med

low

2

2

med

med

unacc

med

low

2

2

med

high

unacc

med

low

2

2

big

high

unacc

med

low

2

4

small

low

unacc

med

low

2

4

small

med

acc

med

low

2

4

small

high

good

med

low

2

4

med

low

unacc

med

low

2

4

med

med

acc

med

low

2

4

med

high

good

med

low

2

4

big

low

unacc

med

low

2

4

big

med

good

med

low

2

4

big

high

vgood

核心代码:

 

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 struct Node
  4 {
  5     string attribute;
  6     string attribute_value;
  7     vector<Node*> child;
  8     Node()
  9     {
 10         attribute = "";
 11         attribute_value="";
 12     }
 13 };
 14 Node * root = new Node();
 15 vector<string> item;
 16 map<string,vector<string> >item_range;
 17 vector<vector<string> > states;
 18 int item_num = 0;
 19 void Input()
 20 {
 21 
 22     ifstream myfile("C:/Users/Administrator/Desktop/data4.csv");
 23     if(!myfile)
 24     {
 25         cout<<"unalbe to open myfile";
 26         exit(1);
 27     }
 28     char buff[1000];
 29     myfile.getline(buff,1000);
 30     string temp = "";
 31     int bufflen = strlen(buff);
 32     for(int i = 0; i <= bufflen; i++)
 33     {
 34         if(buff[i] == ',' || i == bufflen)
 35         {
 36             item.push_back(temp);
 37             item_num ++;
 38             temp ="";
 39         }
 40         else
 41             temp +=buff[i];
 42     }
 43     while(!myfile.eof())
 44     {
 45         vector<string> row;
 46         myfile.getline(buff,1000);
 47         int bufflen = strlen(buff);
 48         for(int i = 0; i <= bufflen; i++)
 49         {
 50             if(buff[i] == ',' || i == bufflen)
 51             {
 52                 row.push_back(temp);
 53                 temp ="";
 54             }
 55             else
 56                 temp +=buff[i];
 57         }
 58         states.push_back(row);
 59     }
 60 }
 61 void computeAttributeRange()
 62 {
 63     int states_num = states.size();
 64     for(int i = 1; i < item_num; i++)
 65     {
 66         vector<string> valuetemp;
 67         vector<string>::iterator it;
 68         for(int j = 0; j < states_num; j++)
 69         {
 70             it = find(valuetemp.begin(),valuetemp.end(),states[j][i]);
 71             if(it == valuetemp.end())
 72                 valuetemp.push_back(states[j][i]);
 73         }
 74         item_range[item[i]] = valuetemp;
 75     }
 76 }
 77 
 78 double computeEntropy(vector<vector<string> > remain_states,string attribute,string attribute_value)
 79 {
 80     vector<string> lastItem = item_range[item[item_num -1]];
 81     int Size = remain_states.size();
 82     vector<string>::iterator it;
 83     int P[10],cnt,cntnum = 0;
 84     memset(P,0,sizeof(P));
 85     it = find(item.begin(),item.end(),attribute);
 86     if(it != item.end())
 87         cnt = it - item.begin();
 88     for(int i = 0; i < Size; i++)
 89     {
 90         if(remain_states[i][cnt] == attribute_value)
 91         {
 92             cntnum ++;
 93             it = find(lastItem.begin(),lastItem.end(),remain_states[i][item_num - 1]);
 94             if(it != lastItem.end())
 95                 P[it - lastItem.begin()] ++;
 96         }
 97     }
 98     double ans = 0.0;
 99     int lastItem_size = lastItem.size();
100     for(int i = 0; i < lastItem_size ; i++)
101     {
102         double temp = cntnum == 0 ? 0.0 : double(P[i])/cntnum;
103         ans -=temp != 0.0 ? temp * log(temp)/log(2.0) : 0.0;
104     }
105     return ans*cntnum;
106 }
107 double computeEntropy(vector<vector<string> > remain_states)
108 {
109     vector<string> lastItem = item_range[item[item_num -1]];
110     int Size = remain_states.size();
111     vector<string>::iterator it;
112     int P[10];
113     memset(P,0,sizeof(P));
114     for(int i = 0; i < Size; i++)
115     {
116         it = find(lastItem.begin(),lastItem.end(),remain_states[i][item_num - 1]);
117         if(it != lastItem.end())
118             P[it - lastItem.begin()] ++;
119     }
120     double ans = 0.0;
121     int lastItem_size = lastItem.size();
122     for(int i = 0; i < lastItem_size ; i++)
123     {
124 
125         double temp = Size == 0 ? 0.0 :  double(P[i])/Size;
126         ans -=temp == 0 ? 0.0 : temp * log(temp)/log(2.0);
127     }
128     return ans;
129 }
130 
131 double computeGain(vector<vector<string> > remain_states,string attribute)
132 {
133     int Size = remain_states.size();
134     vector<string> cntItem = item_range[attribute];
135     double ans = computeEntropy(remain_states);
136     int cntItem_num = cntItem.size();
137     for(int i = 0; i < cntItem_num; i++)
138     {
139         ans -= Size== 0 ? 0.0 :computeEntropy(remain_states,attribute,cntItem[i])/Size;
140     }
141     return ans;
142 }
143 
144 string allSameOfLastItem(vector<vector<string> > remain_states,bool& ok)
145 {
146     ok = true;
147     string lastItem = remain_states[0][item_num - 1];
148     for(int i = 1; i < remain_states.size(); i++)
149         if(remain_states[i][item_num-1] != lastItem)
150         {
151             ok = false;
152             break;
153         }
154     return lastItem;
155 }
156 
157 string mostCommonValue(vector<vector<string> > remain_states)
158 {
159     int p[10];
160     memset(p,0,sizeof(p));
161     vector<string> lastItems = item_range[item[item_num - 1]];
162     for(int i = 0; i < remain_states.size(); i++)
163         p[ find(lastItems.begin(),lastItems.end(),remain_states[i][item_num - 1]) - lastItems.begin()] ++;
164     int Max = 0,maxIndex = 0,lastItems_num = lastItems.size();
165     for(int i = 0; i < lastItems_num; i++)
166     {
167         if(Max < p[i])
168             p[i] = Max,maxIndex = i;
169     }
170     return lastItems[maxIndex];
171 }
172 
173 Node* BuildDecisionTree(Node * p,vector<vector<string> > remain_states,vector<string> remain_item)
174 {
175     if(p == NULL)
176     {
177         p = new Node();
178     }
179     bool Ok = true;
180     string lastItem = allSameOfLastItem(remain_states,Ok);
181     if(Ok == true)
182     {
183         p->attribute = lastItem;
184         return p;
185     }
186     if(remain_item.size() == 2)
187     {
188         p->attribute = mostCommonValue(remain_states);
189         return p;
190     }
191     double Max = computeGain(remain_states,remain_item[1]);
192     string maxAttribute = remain_item[1];
193     for(int i = 1; i < remain_item.size()- 1; i++)
194     {
195         double temp = computeGain(remain_states,remain_item[i]);
196         if(temp > Max)
197         {
198             Max = temp;
199             maxAttribute = remain_item[i];
200         }
201     }
202     p->attribute = maxAttribute;
203     vector<string> maxAttribute_range = item_range[maxAttribute];
204     int maxAttribute_rangeNum = maxAttribute_range.size();
205     int cnt = find(item.begin(),item.end(),maxAttribute) - item.begin();
206     for(int i = 0; i < maxAttribute_rangeNum; i++)
207     {
208         vector<vector<string> > newRemain_states;
209         Node* childNode = new Node();
210         childNode->attribute_value = maxAttribute_range[i];
211         for(int j = 0; j < remain_states.size(); j++)
212         {
213             if(remain_states[j][cnt] == childNode->attribute_value)
214                 newRemain_states.push_back(remain_states[j]);
215         }
216         if(newRemain_states.size() == 0)
217         {
218             childNode->attribute = mostCommonValue(remain_states);
219         }
220         else
221         {
222 
223             vector<string>::iterator it = find(remain_item.begin(),remain_item.end(),maxAttribute);
224             if(it != remain_item.end())
225                 remain_item.erase(it);
226             BuildDecisionTree(childNode,newRemain_states,remain_item);
227         }
228 
229         p->child.push_back(childNode);
230     }
231     return p;
232 }
233 void printTree(Node* p, int dep)
234 {
235     for(int i = 0; i < dep; i++)
236         printf("\t");
237     if(p->attribute_value != "")
238     {
239         cout<<p->attribute_value<<endl;
240         for(int i = 0; i <dep+1; i++)
241             printf("\t");
242     }
243         cout<<p->attribute<<endl;
244     for(vector<Node*>::iterator it = p->child.begin(); it != p->child.end(); it ++)
245         printTree(*it,dep+1);
246 
247 }
248 bool traceTree(vector<string> state,Node *p)
249 {
250     //cout<<p->attribute<<" ";
251     if(p->child.size() <= 0 )
252         {
253             //cout<<p->attribute<< " " << state[item_num-1]<<endl;
254             //cout<<(p->attribute == state[item_num - 1])<<endl;
255             //cout<<"return"<<endl;
256             return p->attribute == state[item_num - 1];
257         }
258     vector<string>::iterator it = find(item.begin(),item.end(),p->attribute);
259     if(it != item.end())
260     {
261         for(int i = 0; i < p->child.size(); i++)
262          if(p->child[i]->attribute_value == state[it-item.begin()] )
263                 return traceTree(state,p->child[i]);
264     }
265     //cout<<"erro"<<endl;
266     return 0;
267 }
268 void test()
269 {
270     int fm = 0,fz = 0;
271     ifstream testfile("C:/Users/Administrator/Desktop/data5.csv");
272     char buff[1000];
273     testfile.getline(buff,1000);
274     string temp = "";
275     while(!testfile.eof())
276     {
277         vector<string> row;
278         testfile.getline(buff,1000);
279         int bufflen = strlen(buff);
280         for(int i = 0; i <= bufflen; i++)
281         {
282             if(buff[i] == ',' || i == bufflen)
283             {
284                 row.push_back(temp);
285                 temp ="";
286             }
287             else
288                 temp +=buff[i];
289         }
290             if(traceTree(row,root))
291                 fz++;
292         fm++;
293 
294     }
295     cout<<fz<<" "<<fm<<endl;
296     cout<<double(fz)/double(fm)<<endl;
297 }
298 
299 int main()
300 {
301     freopen("C:/Users/Administrator/Desktop/res1.txt", "w", stdout);
302     Input();
303     computeAttributeRange();
304     BuildDecisionTree(root,states,item);
305     printTree(root,0);
306     test();
307     return 0;
308 }
View Code

 

通过随机森林提高准确率

决策树是建立在已知的历史数据及概率上的,一课决策树的预测可能会不太准确,提高准确率最好的方法是构建随机森林(Random Forest)。所谓随机森林就是通过随机抽样的方式从历史数据表中生成多张抽样的历史表,对每个抽样的历史表生成一棵决策树。由于每次生成抽样表后数据都会放回到总表中,因此每一棵决策树之间都是独立的没有关联。将多颗决策树组成一个随机森林。当有一条新的数据产生时,让森林里的每一颗决策树分别进行判断,以投票最多的结果作为最终的判断结果。以此来提高正确的概率。

 

推荐阅读