朴素贝叶斯算法的Python实现

发布时间:2019-09-02 07:41:55编辑:auto阅读(1672)

    注意:1、代码中的注释请不要放在源程序中运行,会报错。

        2、代码中的数据集来源于http://archive.ics.uci.edu/ml/datasets/Car+Evaluation

        3、对于朴素贝叶斯的原理,可以查看我的前面的博客

    # Author :Wenxiang Cui
    # Date :2015/9/11
    # Function: A classifier which using naive Bayesian algorithm 
    
    import math
    
    class Bayesian:
    	def __init__(self):
    		self.dataS = [] # 训练样本集DataSource
    		self.attriList = [] # 属性集合
    		self.desClass = 0 # 分类目标属性在attriList中的位置
    	def loadDataS(self,fileName,decollator):
    		#input: 
    		#		fileName - DataSource 的文件名
    		#		decollator - DataSource 中每个字段之间的分割符,有可能是空格或','
    		#function : 
    		#		从磁盘中读取数据并转化为较好处理的列表	
    		items = []
    		fp = open(filename,'r')
    		lines = fp.readlines()
    		for line in lines:
    			line = line.strip('\n')
    			items.append(line)
    		fp.close()
    
    		i = 0
    		b = []
    		for i in range(len(items)):
    			b.append(items[i].split(decollator))
    		self.dataS = b[:]
    	def getAttriList(self,attributes):
    		#input: 
    		#		attributes - 训练数据集中的属性集合,必须与dataSource中的列相对应
    		#function: 
    		#		获得训练数据集的属性列表
    		self.attriList = attributes[:]
    	def getDesClass(self,loca):
    		#input: 
    		#		loca - 分类目标属性在attriList中的位置
    		#function: 
    		#		获得分类目标属性在attriList中的位置
    		self.desClass = loca
    	def calPriorProb(self):
    		#input: 
    		#		
    		#function: 
    		#		计算类的先验概率
    		dictFreq = {} # 构建频度表,用字典表示
    		desLabel = [] 
    		sampleNum = 0
    		for items  in self.dataS:
    			sampleNum += 1
    			if not items[self.desClass] in dictFreq:
    				dictFreq[items[self.desClass]] = 1
    				desLabel.append(items[self.desClass])
    			else:
    				dictFreq[items[self.desClass]] += 1
    		dictPriorP = {} # 构建先验概率表,用字典表示
    		for item in desLabel:
    			dictPriorP[item] = float(dictFreq[item]) / sampleNum
    		self.PriorP = dictPriorP[:]
    		self.classLabel = desLabel[:]
    	def calProb(self,type,loca):
    		#input: 
    		#		type - 定义属性是连续的还是离散的
    		#   	loca - 该属性在属性集中的位置
    		#output:
    		#		dictPara - 连续属性的样本均值和方差(列表表示)
    		#		dictProb - 离散属性的类条件概率
    		#function: 
    		#		计算某个属性的类条件概率密度
    		if type == 'continuous': 
    			dictData = [] # 提取出样本的类别和当前属性值
    			dictPara = [] # 记录样本的类别和其对应的样本均值和方差
    			for item in self.classLabel:
    				dictData.append([])
    				dictPara.append([])
    			for items in self.dataS:
    				dataIndex = self.classLabel.index(items[self.desLabel]) # 返回当前样本类属性
    				dictData[dataIndex].append(float(items[loca])) # 记录当前属性值及该样本的类属性
    			#计算类属性的样本均值和方差(可以用Numpy包来快速处理)
    			for i in range(len(self.classLabel)):
    				[a,b] = self.calParam(dictData[i])
    				dictPara[i].append(a)
    				dictPara[i].append(b)
    			return dictPara
    		elif type == 'discrete': 
    			dictFreq = {}
    			dictProb = {}
    			for item in self.classLabel:# 构建频度表,用字典表示
    				dictFreq[item] = {}		
    				dictProb[item] = {}	
    			label = []
    			for items in self.dataS:
    				if not items[loca] in label:
    					label.append(items[loca])
    					dictFreq[items[self.desClass]][items[loca]] = 1
    				else:
    					dictFreq[items[self.desClass]][items[loca]] += 1
    			needLaplace = 0
    			for key in dictFreq.keys():
    				for ch in labels:
    					if ch not in dictFreq[key]:
    						dictFreq[key][ch] = 0
    						needLaplace = 1
    				if needLaplace == 1: # 拉普拉斯平滑用于处理类条件概率为0的情况
    					dictFreq[key] = self.LaplaceEstimator(dictFreq[key])	
    					needLaplace = 0
    			for item in self.classLabel:
    				for ch in dictFreq[item]:
    					dictProb[item][ch] = float(dictFreq[item][ch]) / self.dictFreq[item]	
    			return dictProb	
    		else:
    			print 'Wrong type!'
    	def calParam(self,souList):
    		#input: 
    		#		souList - 待计算的列表
    		#output:
    		#		meanVal - 列表元素的均值
    		# 		deviation - 列表元素的标准差
    		#function: 
    		#		计算某个属性的类条件概率密度
    		meanVal = sum(souList) / float(len(souList))
    		deviation = 0
    		tempt = 0
    		for val in souList:
    			tempt += (val - meanVal)**2
    		deviation = math.sqrt(float(tempt)/(len(souList)-1))
    		return meanVal,deviation
    	def LaplaceEstimator(self,souDict):
    		#input: 
    		#		souDict - 待计算的字典
    		#output:
    		#		desDict - 平滑后的字典
    		#function: 
    		#		拉普拉斯平滑
    		desDict = souDict.copy()
    		for key in souDict:
    			desDict[key] = souDict[key] + 1
    		return desDict
    
    class CarBayesian(Bayesian):
    	def __init__(self):
    		Bayesian.__init__(self)
    		self.buying = {}
    		self.maint = {}
    		self.doors = {}
    		self.persons = {}
    		self.lug_boot = {}
    		self.safety = {}
    	def tranning(self):
    		self.Prob = []
    		self.buying = Bayesian.calProb('discrete',0)
    		self.maint = Bayesian.calProb('discrete',1)	
    		self.doors = Bayesian.calProb('discrete',2)
    		self.persons = Bayesian.calProb('discrete',3)
    		self.lug_boot = Bayesian.calProb('discrete',4)
    		self.safety = Bayesian.calProb('discrete',5)
    
    		self.Prob.append(self.buying)
    		self.Prob.append(self.maint)
    		self.Prob.append(self.doors)
    		self.Prob.append(self.persons)
    		self.Prob.append(self.lug_boot)
    		self.Prob.append(self.safety)
    	def classify(self,sample):
    		#input :
    		# 		sample - 一个样本
    		#function:
    		# 		判断输入的这个样本的类别
    		posteriorProb = {}
    		for item in self.classLabel:
    				posteriorProb[item] = self.PriorP[item]
    				for i in range(len(sample)-1):
    					posteriorProb[item] *= self.Prob[i][item][sample[i]]
    		maxVal = posteriorProb[self.classLabel[0]]
    		i = 0
    		for item in posteriorProb:
    			i += 1
    			if posteriorProb[item] > maxVal:
    				maxVal = posteriorProb[item]
    				location = i
    		print "该样本属于的类别是:",self.classLabel[location]
    
    
    filename = "D:\MyDocuments-HnH\DataMining\DataSets\Car\Car_Data.txt"
    MyCar = CarBayesian()
    MyCar.loadDataS(filename,',')
    attributes = ['buying','maint','doors','persons','lug_boot','safety']
    MyCar.getAttriList(attributes)
    MyCar.getDesClass(7-1)
    MyCar.tranning()
    sample = ['vhigh','vhigh','2','2','small','low']


关键字