首页 > 技术文章 > python实现一个朴素贝叶斯分类方法

little-horse 2017-08-07 12:29 原文

1.公式

上式中左边D是需要预测的测试数据属性,h是需要预测的类;右边式子分子是属性的条件概率和类别的先验概率,可以从统计训练数据中得到,分母对于所有实例都一样,可以不考虑,所有只需 ,返回最大概率的那个类别。但是如果测试数据中没有那个属性,整个预测概率会是0;此外,此式针对离散型属性进行训练,针对连续的数值型属性可以考虑分段,也可以假设其满足某种分布,比如正态分布,利用概率密度函数求概率。

2.部分改进

(1).针对测试数据中没有那个属性,可以平滑一下,比如下(针对非数值型属性):

上式中n是某个类别下的实例数,nc是此类别下的属性个数,m是此属性的取值个数,p是此属性取值出现的概率。比如一个属性:性别,取值男或女,则 m=2,p=1/2。

(2).针对连续的数值型属性,可以分段比如年龄0-10为A,10-30为B等;还可以假设它服从高斯分布(正态分布),利分布函数计算概率:

其中uij是某列数值型属性的均值,Qij是某列数值型属性样本标准差,Xi是数值属性。训练的时候只需要统计均值,样本标准差就行了,预测的时候利用。

3.python实现

  1 #!/usr/bin/python
  2 # -*- coding: utf-8 -*-
  3 
  4 import codecs
  5 import math
  6 
  7 class BayesClassifier:
  8 
  9     def __init__(self,dataFormat):
 10         self.prior = {}#类别的先验概率
 11         self.conditional = {}#属性的条件概率
 12         # 输入的数据格式,attr表示非数值型属性,num表示数值型属性,class表示类别
 13         self.format=dataFormat.strip().split('\t')
 14 
 15     #读取数据
 16     def readData(self,dataFile):
 17         total = 0#所有实例数
 18         self.classes = {}#统计类别
 19         self.counts = {}#用来统计
 20         totals={}#统计数值型每列的和
 21         numericValues={}#数值型每列值
 22 
 23         with codecs.open(dataFile,'r','utf-8') as f:
 24             for line in f:
 25                 fields=line.strip().split('\t')
 26                 fieldSize=len(fields)
 27                 vector=[]
 28                 nums=[]
 29                 for i in range(fieldSize):
 30                     if self.format[i]=='num':
 31                         nums.append(float(fields[i]))
 32                     elif self.format[i]=='attr':
 33                         vector.append(fields[i])
 34                     elif self.format[i]=='class':
 35                         category=fields[i]
 36                 total+=1
 37                 self.classes.setdefault(category,0)
 38                 self.counts.setdefault(category,{})
 39                 totals.setdefault(category,{})
 40                 numericValues.setdefault(category,{})
 41                 self.classes[category]+=1
 42                 #统计一条非数值型实例的属性
 43                 col=0
 44                 for columnValue in vector:
 45                     col+=1
 46                     self.counts[category].setdefault(col,{})
 47                     self.counts[category][col].setdefault(columnValue,0)
 48                     self.counts[category][col][columnValue]+=1
 49                 col=0
 50                 for columnValue in nums:
 51                     col+=1
 52                     totals[category].setdefault(col,0)
 53                     totals[category][col]+=columnValue
 54                     numericValues[category].setdefault(col,[])
 55                     numericValues[category][col].append(columnValue)
 56 
 57         #以上统计完成,计算类别先验概率和属性条件概率
 58         #计算类的先验概率=此类的实例数/总的实例数
 59         for category,count in self.classes.items():
 60             self.prior[category]=count/total
 61         #计算属性的条件概率=此类中属性数/此类实例数
 62         for category,columns in self.counts.items():
 63             self.conditional.setdefault(category,{})
 64             for col,valueCounts in columns.items():
 65                 self.conditional[category].setdefault(col,{})
 66                 colSize=len(valueCounts)#这一列属性的取值个数(如性别取值为男和女,则colSize=2)
 67                 for attr,count in valueCounts.items():
 68                     #平滑一下
 69                     self.conditional[category][col][attr]=(count+colSize*1/colSize)/(self.classes[category]+colSize)
 70         #在数值型列中计算均值和样本标准差
 71         #每列的均值
 72         self.means={}
 73         self.totals=totals
 74         for category,columns in totals.items():
 75             self.means.setdefault(category,{})
 76             for col,colSum  in columns.items():
 77                 self.means[category][col]=colSum/self.classes[category]
 78         #每列的标准差
 79         self.std={}
 80         for category,columns in numericValues.items():
 81             self.std.setdefault(category,{})
 82             for col,values in columns.items():
 83                 ssd=0
 84                 mean=self.means[category][col]
 85                 for value in values:
 86                     ssd+=(value-mean)**2
 87                 self.std[category][col]=math.sqrt(ssd/(self.classes[category]-1))
 88 
 89 
 90     #分类,返回分类结果
 91     def classify(self,itemVector):
 92         results=[]
 93         for category,prior in self.prior.items():
 94             prob=prior
 95             col=1
 96             for attrValue in itemVector:
 97                 if self.format[col]=='attr':
 98                     # 如果预测数据没有这个属性,则平滑一下,不是返回0(返回0会导致整个预测结果为0)
 99                     if not attrValue in self.conditional[category][col]:
100                         colSize=len(self.counts[category][col])
101                         prob=prob*(0+colSize*1/colSize)/(self.classes[category]+colSize)
102                     else:
103                         prob=prob*self.conditional[category][col][attrValue]
104                 #针对数值型,我们先得到该列均值与样本标准差,利用正态分布得到概率(假设该列数值满足正态分布)
105                 elif self.format[col]=='num':
106                     mean=self.means[category][col]
107                     std=self.std[category][col]
108                     prob=prob*self.gaussian(mean,std,attrValue)
109                 col+=1
110             results.append((prob,category))
111         return max(results)[1]
112 
113     #高斯分布
114     def gaussian(self,mean,std,x):
115         sqrt2pi = math.sqrt(2 * math.pi)
116         ePart=math.pow(math.e,-(x-mean)**2/(2*std**2))
117         prob=(1.0/sqrt2pi*std)*ePart
118         return prob
119 
120     # 十折验证读取数据,prefix为文件名前缀,i作为测试集编号
121     def tenFoldReadData(self,prefix,testNumber):
122         total = 0  # 所有实例数
123         self.classes = {}  # 统计类别
124         self.counts = {}  # 用来统计
125         totals = {}  # 统计数值型每列的和
126         numericValues = {}  # 数值型每列值
127 
128         for i in range(1,11):
129             if i!=testNumber:
130                 filename='%s-%02s' % (prefix,i)
131                 with codecs.open(filename, 'r', 'utf-8') as f:
132                     for line in f:
133                         fields = line.strip().split('\t')
134                         fieldSize = len(fields)
135                         vector = []
136                         nums = []
137                         for i in range(fieldSize):
138                             if self.format[i] == 'num':
139                                 nums.append(float(fields[i]))
140                             elif self.format[i] == 'attr':
141                                 vector.append(fields[i])
142                             elif self.format[i] == 'class':
143                                 category = fields[i]
144                         total += 1
145                         self.classes.setdefault(category, 0)
146                         self.counts.setdefault(category, {})
147                         totals.setdefault(category, {})
148                         numericValues.setdefault(category, {})
149                         self.classes[category] += 1
150                         # 统计一条非数值型实例的属性
151                         col = 0
152                         for columnValue in vector:
153                             col += 1
154                             self.counts[category].setdefault(col, {})
155                             self.counts[category][col].setdefault(columnValue, 0)
156                             self.counts[category][col][columnValue] += 1
157                         col = 0
158                         for columnValue in nums:
159                             col += 1
160                             totals[category].setdefault(col, 0)
161                             totals[category][col] += columnValue
162                             numericValues[category].setdefault(col, [])
163                             numericValues[category][col].append(columnValue)
164 
165         # 以上统计完成,计算类别先验概率和属性条件概率
166         # 计算类的先验概率=此类的实例数/总的实例数
167         for category, count in self.classes.items():
168             self.prior[category] = count / total
169         # 计算属性的条件概率=此类中属性数/此类实例数
170         for category, columns in self.counts.items():
171             self.conditional.setdefault(category, {})
172             for col, valueCounts in columns.items():
173                 self.conditional[category].setdefault(col, {})
174                 colSize = len(valueCounts)  # 这一列属性的取值个数(如性别取值为男和女,则colSize=2)
175                 for attr, count in valueCounts.items():
176                     # 平滑一下
177                     self.conditional[category][col][attr] = (count + colSize * 1 / colSize) / (
178                     self.classes[category] + colSize)
179         # 在数值型列中计算均值和样本标准差
180         # 每列的均值
181         self.means = {}
182         self.totals = totals
183         for category, columns in totals.items():
184             self.means.setdefault(category, {})
185             for col, colSum in columns.items():
186                 self.means[category][col] = colSum / self.classes[category]
187         # 每列的标准差
188         self.std = {}
189         for category, columns in numericValues.items():
190             self.std.setdefault(category, {})
191             for col, values in columns.items():
192                 ssd = 0
193                 mean = self.means[category][col]
194                 for value in values:
195                     ssd += (value - mean) ** 2
196                 self.std[category][col] = math.sqrt(ssd / (self.classes[category] - 1))
197 
198     #利用十折交叉验证,测试一个桶中的数据,prefix为统计文件名前缀,testNumber为要测试的一个桶中的数据
199     def testOneBucket(self,prefix,testNumber):
200         filename='%s-%02i' % (prefix,testNumber)
201         totals={}
202         with codecs.open(filename,'r','utf-8') as f:
203             for line in f:
204                 data=line.strip().split('\t')
205                 itemVector=[]
206                 classInColumn=-1
207                 for i in range(len(self.format)):
208                     if self.format[i]=='num':
209                         itemVector.append(float(data[i]))
210                     elif self.format[i]=='attr':
211                         itemVector.append(data[i])
212                     elif self.format[i]=='class':
213                         classInColumn=i
214                 realClass=data[classInColumn]#真实的类
215                 classifiedClass=self.classify(itemVector)#预测的类
216                 totals.setdefault(realClass,{})
217                 totals[realClass].setdefault(classifiedClass,0)
218                 totals[realClass][classifiedClass]+=1
219         return totals
220 
221 #十折交叉验证,prefix为十个文件名字的前缀,dataForamt为数据格式
222 def tenfold(prefix,dataFormat):
223     results={}
224     for i in range(1,11):
225         classify=BayesClassifier(dataFormat)
226         classify.tenFoldReadData(prefix,i)
227         totals=classify.testOneBucket(prefix,i)
228         for key,value in totals.items():
229             results.setdefault(key,{})
230             for ckey,cvalue in value.items():
231                 results[key].setdefault(ckey,0)
232                 results[key][ckey]+=cvalue
233     #结果展示
234     classes=list(results.keys())
235     classes.sort()
236     print(      '\n                 classes as: ')
237     header='                '
238     subheader='               +'
239     for cls in classes:
240         header+='%  10s '% cls
241         subheader+='--------+'
242     print(header)
243     print(subheader)
244     total=0.0
245     correct=0.0
246     for cls in classes:
247         row=' %10s   |' % cls
248         for c2 in classes:
249             if c2 in results[cls]:
250                 count=results[cls][c2]
251             else:
252                 count=0
253             row+=' %5i |' % count
254             total+=count
255             if c2==cls:
256                 correct+=count
257         print(row)
258     print(subheader)
259     print('\n%5.3f 正确率' % ((correct*100/total)))
260     print('总共 %i 实例'% total)
261 
262 if __name__=='__main__':
263     #classify=BayesClassifier('num,num,num,num,num,num,num,num,class')
264     #classify.readData('dataFile')
265     #print(classify.classify([2,120,54,0,0,26.8,0.455,27]))
266     tenfold('dataFilePrefix','num,num,num,num,num,num,num,num,class')#十折交叉验证

 

推荐阅读