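This article shows how to implement decision tree induction based on information gain in Python. The program works in three steps:

- It reads attribute definitions and labelled training records from files.
- It grows a tree by repeatedly choosing the attribute with the highest information gain. A discrete attribute gets one branch per value. A continuous attribute is split in two at the midpoint that minimizes the weighted entropy.
- It counts how many records in a second labelled file the tree classifies incorrectly.

The full code follows.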
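The quantities the code computes are written out below. They are transcribed from the code's comments; the partition notation $D_j$, $D_{\le s}$, $D_{>s}$ is introduced here only for readability.

In the formulas, $C_i$ is the set of records in $D$ that belong to class $i$. A discrete attribute $A$ partitions $D$ into $D_1, \dots, D_v$, one subset per value. For a continuous attribute, $a_1 < \dots < a_k$ are its sorted distinct values in $D$.

$$
\mathrm{Info}(D) = -\sum_{i=1}^{m} p_i \log_2 p_i, \qquad p_i = \frac{|C_i|}{|D|}
$$

$$
\mathrm{Info}_A(D) = \sum_{j=1}^{v} \frac{|D_j|}{|D|}\,\mathrm{Info}(D_j), \qquad \mathrm{Gain}(A) = \mathrm{Info}(D) - \mathrm{Info}_A(D)
$$

$$
\mathrm{Info}_A(D) = \min_{s \in \{ (a_t + a_{t+1})/2 \}} \left( \frac{|D_{\le s}|}{|D|}\,\mathrm{Info}(D_{\le s}) + \frac{|D_{>s}|}{|D|}\,\mathrm{Info}(D_{>s}) \right)
$$

The third line replaces the second line's definition of $\mathrm{Info}_A(D)$ when $A$ is continuous. The minimizing $s$ becomes that attribute's split point.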
```python
import math  # math.log2 replaces matplotlib.mlab.log2, which recent Matplotlib releases no longer provide
from copy import copy

# ---------- Load the input files ----------
# attribute.dat, one line per attribute: attribute_id,is_continuous(yes|no),description
attribute_file_dest = 'F:\\bayes_categorize\\attribute.dat'
# trainning_data.dat, one line per record: rec_id,attr1_value,...,attrn_value,class_id
trainning_data_file_dest = 'F:\\bayes_categorize\\trainning_data.dat'
# class_desc.dat, one line per class: class_id,class_desc
class_desc_file_dest = 'F:\\bayes_categorize\\class_desc.dat'
# labelled records used to test the tree (same format as the training data)
data_to_classify_file_dest = 'F:\\bayes_categorize\\trainning_data_new.dat'


def read_fields(path):
    """Yield the comma-separated fields of every non-blank line of a file."""
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                yield line.split(',')


def load_records(path):
    """Return {rec_id: (a1, a2, a3, class_id)}; a1 and a2 are integers and
    a3 is a float, as in the original data files."""
    return {int(fld[0]): (int(fld[1]), int(fld[2]), float(fld[3]), int(fld[4]))
            for fld in read_fields(path)}


# attribute id -> (is_continuous, description)
root_attr_dict = {int(fld[0]): tuple(fld[1:]) for fld in read_fields(attribute_file_dest)}
# class id -> class description
class_dict = {int(fld[0]): fld[1] for fld in read_fields(class_desc_file_dest)}

trainning_data_dict = load_records(trainning_data_file_dest)
data_to_classify_dict = load_records(data_to_classify_file_dest)

# Class membership and prior probabilities (kept from the original script;
# the decision tree itself does not use them)
class_member_set_dict = {}
for rec_id, rec in trainning_data_dict.items():
    class_member_set_dict.setdefault(rec[3], set()).add(rec_id)
class_possibility_dict = {c_id: len(members) / len(trainning_data_dict)
                          for c_id, members in class_member_set_dict.items()}

'''
Decision tree representation

Each node is a tuple (plan_id, info, children).

plan_id says how the node partitions its data:
    0   leaf node
    1   discrete attribute, one branch per value
    2   continuous attribute, binary split at a split point
    3   discrete attribute with a splitting subset (not implemented)
    -1  artificial root; the real tree hangs under its key 1
info holds what classification needs:
    internal node: (attr_id, set), where the set holds all discrete values (1),
    the split point (2) or the splitting subset (3)
    leaf: (1, set of class labels)
children maps a branch key to a child node:
    for 1 the key is an attribute value; for 2 and 3 it is 1 or 2, where 1 means
    "<= split point" (2) or "in the splitting subset" (3)
'''


def get_entropy(member_list):
    """Entropy of a list of training record ids:
    Info_D = -sum(p_i * log2(p_i)), p_i = |C_i| / |D|, summed over all classes in D."""
    mem_cnt = len(member_list)
    if mem_cnt == 0:
        return 0.0
    class_cnt = {}  # class id -> number of members in that class
    for mem_id in member_list:
        c_id = trainning_data_dict[mem_id][3]
        class_cnt[c_id] = class_cnt.get(c_id, 0) + 1
    tmp_sum = 0.0
    for cnt in class_cnt.values():
        pi = cnt / mem_cnt
        tmp_sum += pi * math.log2(pi)
    return -tmp_sum


def attribute_selection_method(member_list, attribute_dict):
    """Return (plan_id, attr_id, split_point) for the attribute with the highest
    information gain: plan_id is 1 for a discrete attribute and 2 for a continuous
    one (split_point is only used for 2). Returns None if no gain is positive.
    Plan 3 (splitting subset) is not considered yet."""
    info_D = get_entropy(member_list)  # entropy before partitioning
    total_mems = len(member_list)
    max_info_gain = 0.0
    attr_get = None
    best_split_point = None
    for attr_id in attribute_dict:
        idx = attr_id - 1  # attribute n is stored at tuple position n - 1
        if attribute_dict[attr_id][0] == 'yes':
            # Continuous: each midpoint between adjacent distinct sorted values is
            # a candidate split point; keep the one with the smallest weighted entropy.
            values = sorted({trainning_data_dict[m][idx] for m in member_list})
            split_point = None
            info_D_new = float('inf')
            for mid_value in [(a + b) / 2 for a, b in zip(values, values[1:])]:
                less_list = [m for m in member_list if trainning_data_dict[m][idx] <= mid_value]
                more_list = [m for m in member_list if trainning_data_dict[m][idx] > mid_value]
                sum_info = (len(less_list) / total_mems * get_entropy(less_list)
                            + len(more_list) / total_mems * get_entropy(more_list))
                if sum_info < info_D_new:
                    info_D_new, split_point = sum_info, mid_value
        else:
            # Discrete: entropy of each value's partition, weighted by its size.
            attr_value_member_dict = {}  # attribute value -> member list
            for m in member_list:
                attr_value_member_dict.setdefault(trainning_data_dict[m][idx], []).append(m)
            info_D_new = sum(len(sub) / total_mems * get_entropy(sub)
                             for sub in attr_value_member_dict.values())
            split_point = None
        # information gain = original entropy - entropy after partitioning
        info_gain = info_D - info_D_new
        if info_gain > max_info_gain:
            # keep the split point of *this* attribute (the original shared one
            # variable across attributes and could return the wrong split point)
            max_info_gain, attr_get, best_split_point = info_gain, attr_id, split_point
    if attr_get is None:
        return None
    plan_id = 1 if attribute_dict[attr_get][0] == 'no' else 2
    return (plan_id, attr_get, best_split_point)


def majority_classes(member_list):
    """Set of the most frequent class labels (several labels on a tie)."""
    class_cnt_dict = {}
    for mem_id in member_list:
        c_id = trainning_data_dict[mem_id][3]
        class_cnt_dict[c_id] = class_cnt_dict.get(c_id, 0) + 1
    max_cnt = max(class_cnt_dict.values())
    return {c_id for c_id, cnt in class_cnt_dict.items() if cnt == max_cnt}


def get_decision_tree(father_node, key, member_list, attr_dict):
    """Build the subtree for member_list and store it as father_node[2][key]."""
    # All members in one class: create a leaf holding that class label.
    class_set = {trainning_data_dict[m][3] for m in member_list}
    if len(class_set) == 1:
        father_node[2][key] = (0, (1, class_set), {})
        return
    # Compare every discrete attribute and every midpoint of every continuous one.
    split_criterion = attribute_selection_method(member_list, attr_dict) if attr_dict else None
    # No attribute left (or none helps): create a leaf labelled with the majority
    # class; on a tie, print a notice and keep all tied classes.
    if split_criterion is None:
        class_set = majority_classes(member_list)
        if len(class_set) > 1:
            print('more than one class!')
        father_node[2][key] = (0, (1, class_set), {})
        return

    selected_plan_id, selected_attr_id, split_point = split_criterion
    idx = selected_attr_id - 1
    # A discrete attribute is used up once split on; a continuous one may be reused.
    new_attr_dict = copy(attr_dict)
    if selected_plan_id == 1:
        del new_attr_dict[selected_attr_id]

    # Attach new_node as father_node[2][key], then recurse on each branch.
    if selected_plan_id == 1:
        attr_value_member_dict = {}  # one branch per discrete value
        for m in member_list:
            attr_value_member_dict.setdefault(trainning_data_dict[m][idx], []).append(m)
        new_node = (1, (selected_attr_id, set(attr_value_member_dict)), {})
        father_node[2][key] = new_node
        for attr_value, sub_list in attr_value_member_dict.items():
            get_decision_tree(new_node, attr_value, sub_list, new_attr_dict)
    else:
        # binary split: key 1 for values <= split point, key 2 for the rest
        new_node = (2, (selected_attr_id, {split_point}), {})
        father_node[2][key] = new_node
        less_list = [m for m in member_list if trainning_data_dict[m][idx] <= split_point]
        more_list = [m for m in member_list if trainning_data_dict[m][idx] > split_point]
        get_decision_tree(new_node, 1, less_list, new_attr_dict)
        get_decision_tree(new_node, 2, more_list, new_attr_dict)


def get_class_sub(node, tp):
    """Follow tp (a tuple of attribute values) down to a leaf; return its class set,
    or an empty set if tp has a discrete value never seen during training."""
    plan_id = node[0]
    if plan_id == 0:
        return node[1][1]
    attr_value = tp[node[1][0] - 1]
    if plan_id == 1:
        key = attr_value
    elif plan_id == 2:
        split_point = next(iter(node[1][1]))
        key = 1 if attr_value <= split_point else 2
    else:
        print('error')
        return set()
    if key not in node[2]:
        return set()
    return get_class_sub(node[2][key], tp)


def get_class(r_node, tp):
    """Classify tp with the tree stored under the artificial root r_node."""
    if r_node[0] != -1 or 1 not in r_node[2]:
        print('error')
        return set()
    return get_class_sub(r_node[2][1], tp)


if __name__ == '__main__':
    root_node = (-1, set(), {})
    get_decision_tree(root_node, 1, list(trainning_data_dict), root_attr_dict)

    # Test the classifier: count records whose predicted class differs from the label.
    diff_cnt = 0
    for mem_id, rec in data_to_classify_dict.items():
        c_set = get_class(root_node, rec[0:3])
        predicted = min(c_set) if c_set else None  # pick one label on a tie
        if predicted != rec[3]:
            print(predicted, rec[3], 'different')
            diff_cnt += 1
    print(diff_cnt)
```
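The following standalone sketch shows what the selection step computes. It reimplements three pieces on a tiny made-up dataset:

- the entropy,
- the gain of a discrete multiway split,
- the midpoint search for a continuous attribute.

The helper names `entropy`, `gain_discrete` and `best_split` and the toy records are hypothetical; they are not part of the program above. The sketch works on plain label lists instead of record ids.

```python
import math
from collections import Counter


def entropy(labels):
    """Info(D) = -sum(p_i * log2(p_i)) over the classes present in labels."""
    n = len(labels)
    return -sum(c / n * math.log2(c / n) for c in Counter(labels).values())


def gain_discrete(values, labels):
    """Information gain of splitting on a discrete attribute, one branch per value."""
    n = len(labels)
    parts = {}
    for v, y in zip(values, labels):
        parts.setdefault(v, []).append(y)
    info_a = sum(len(p) / n * entropy(p) for p in parts.values())
    return entropy(labels) - info_a


def best_split(values, labels):
    """Best binary split of a continuous attribute at a midpoint between adjacent
    distinct sorted values; returns (gain, split_point), first best on ties."""
    n = len(labels)
    distinct = sorted(set(values))
    best_gain, best_point = float('-inf'), None
    for a, b in zip(distinct, distinct[1:]):
        s = (a + b) / 2
        left = [y for v, y in zip(values, labels) if v <= s]
        right = [y for v, y in zip(values, labels) if v > s]
        info_a = len(left) / n * entropy(left) + len(right) / n * entropy(right)
        gain = entropy(labels) - info_a
        if gain > best_gain:
            best_gain, best_point = gain, s
    return best_gain, best_point


labels = ['no', 'no', 'yes', 'yes']
color = ['red', 'red', 'blue', 'red']  # toy discrete attribute
size = [1.0, 2.0, 3.0, 4.0]            # toy continuous attribute
print(round(gain_discrete(color, labels), 3))  # 0.311
print(best_split(size, labels))                # (1.0, 2.5)
```

On this toy data the continuous attribute wins: gain 1.0 at split point 2.5, versus about 0.311 for the discrete attribute. A tree built with the same rule would therefore split on `size` first.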
Title: How to implement information-gain-based decision tree induction in Python (創(chuàng)新互聯(lián))
Article link: http://jinyejixie.com/article20/dipdco.html