數據描述
每條數據項儲存在列表中,最后一列儲存結果
多條數據項形成數據集
data=[[d1,d2,d3...dn,result], [d1,d2,d3...dn,result], . . [d1,d2,d3...dn,result]]
決策樹數據結構
class DecisionNode: '''決策樹節點 ''' def __init__(self,col=-1,value=None,results=None,tb=None,fb=None): '''初始化決策樹節點 args: col -- 按數據集的col列劃分數據集 value -- 以value作為劃分col列的參照 result -- 只有葉子節點有,代表最終劃分出的子數據集結果統計信息。{‘結果':結果出現次數} rb,fb -- 代表左右子樹 ''' self.col=col self.value=value self.results=results self.tb=tb self.fb=fb
決策樹分類的最終結果是將數據項劃分出了若干子集,其中每個子集的結果都一樣,所以這里采用{‘結果':結果出現次數}的方式表達每個子集
def pideset(rows,column,value): '''依據數據集rows的column列的值,判斷其與參考值value的關系對數據集進行拆分 返回兩個數據集 ''' split_function=None #value是數值類型 if isinstance(value,int) or isinstance(value,float): #定義lambda函數當row[column]>=value時返回true split_function=lambda row:row[column]>=value #value是字符類型 else: #定義lambda函數當row[column]==value時返回true split_function=lambda row:row[column]==value #將數據集拆分成兩個 set1=[row for row in rows if split_function(row)] set2=[row for row in rows if not split_function(row)] #返回兩個數據集 return (set1,set2) def uniquecounts(rows): '''計算數據集rows中有幾種最終結果,計算結果出現次數,返回一個字典 ''' results={} for row in rows: r=row[len(row)-1] if r not in results: results[r]=0 results[r]+=1 return results def giniimpurity(rows): '''返回rows數據集的基尼不純度 ''' total=len(rows) counts=uniquecounts(rows) imp=0 for k1 in counts: p1=float(counts[k1])/total for k2 in counts: if k1==k2: continue p2=float(counts[k2])/total imp+=p1*p2 return imp def entropy(rows): '''返回rows數據集的熵 ''' from math import log log2=lambda x:log(x)/log(2) results=uniquecounts(rows) ent=0.0 for r in results.keys(): p=float(results[r])/len(rows) ent=ent-p*log2(p) return ent def build_tree(rows,scoref=entropy): '''構造決策樹 ''' if len(rows)==0: return DecisionNode() current_score=scoref(rows) # 最佳信息增益 best_gain=0.0 # best_criteria=None #最佳劃分 best_sets=None column_count=len(rows[0])-1 #遍歷數據集的列,確定分割順序 for col in range(0,column_count): column_values={} # 構造字典 for row in rows: column_values[row[col]]=1 for value in column_values.keys(): (set1,set2)=pideset(rows,col,value) p=float(len(set1))/len(rows) # 計算信息增益 gain=current_score-p*scoref(set1)-(1-p)*scoref(set2) if gain>best_gain and len(set1)>0 and len(set2)>0: best_gain=gain best_criteria=(col,value) best_sets=(set1,set2) # 如果劃分的兩個數據集熵小于原數據集,進一步劃分它們 if best_gain>0: trueBranch=build_tree(best_sets[0]) falseBranch=build_tree(best_sets[1]) return DecisionNode(col=best_criteria[0],value=best_criteria[1], tb=trueBranch,fb=falseBranch) # 如果劃分的兩個數據集熵不小于原數據集,停止劃分 else: return DecisionNode(results=uniquecounts(rows)) def print_tree(tree,indent=''): if tree.results!=None: print(str(tree.results)) else: print(str(tree.col)+':'+str(tree.value)+'? ') print(indent+'T->',end='') print_tree(tree.tb,indent+' ') print(indent+'F->',end='') print_tree(tree.fb,indent+' ') def getwidth(tree): if tree.tb==None and tree.fb==None: return 1 return getwidth(tree.tb)+getwidth(tree.fb) def getdepth(tree): if tree.tb==None and tree.fb==None: return 0 return max(getdepth(tree.tb),getdepth(tree.fb))+1 def drawtree(tree,jpeg='tree.jpg'): w=getwidth(tree)*100 h=getdepth(tree)*100+120 img=Image.new('RGB',(w,h),(255,255,255)) draw=ImageDraw.Draw(img) drawnode(draw,tree,w/2,20) img.save(jpeg,'JPEG') def drawnode(draw,tree,x,y): if tree.results==None: # Get the width of each branch w1=getwidth(tree.fb)*100 w2=getwidth(tree.tb)*100 # Determine the total space required by this node left=x-(w1+w2)/2 right=x+(w1+w2)/2 # Draw the condition string draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0)) # Draw links to the branches draw.line((x,y,left+w1/2,y+100),fill=(255,0,0)) draw.line((x,y,right-w2/2,y+100),fill=(255,0,0)) # Draw the branch nodes drawnode(draw,tree.fb,left+w1/2,y+100) drawnode(draw,tree.tb,right-w2/2,y+100) else: txt=' /n'.join(['%s:%d'%v for v in tree.results.items()]) draw.text((x-20,y),txt,(0,0,0))
新聞熱點
疑難解答