登录
首页精彩阅读决策分类树算法之ID3,C4.5算法系列
决策分类树算法之ID3,C4.5算法系列
2015-12-03
收藏

决策分类树算法之ID3,C4.5算法系列


一、引言

在最开始的时候,我本来准备学习的是C4.5算法,后来发现C4.5算法的核心还是ID3算法,所以又辗转回到学习ID3算法了,因为C4.5是他的一个改进。至于是什么改进,在后面的描述中我会提到。

二、ID3算法

ID3算法是一种分类决策树算法。他通过一系列的规则,将数据最后分类成决策树的形式。分类的根据是用到了熵这个概念。熵在物理这门学科中就已经出现过,表示是一个物质的稳定度,在这里就是分类的纯度的一个概念。公式为:

在ID3算法中,是采用Gain信息增益来作为一个分类的判定标准的。他的定义为:

每次选择属性中信息增益最大作为划分属性,在这里本人实现了一个java版本的ID3算法,为了模拟数据的可操作性,就把数据写到一个input.txt文件中,作为数据源,格式如下:

[java] view plaincopyprint?
  1. Day OutLook Temperature Humidity Wind PlayTennis  
  2. 1 Sunny Hot High Weak No  
  3. 2 Sunny Hot High Strong No  
  4. 3 Overcast Hot High Weak Yes  
  5. 4 Rainy Mild High Weak Yes  
  6. 5 Rainy Cool Normal Weak Yes  
  7. 6 Rainy Cool Normal Strong No  
  8. 7 Overcast Cool Normal Strong Yes  
  9. 8 Sunny Mild High Weak No  
  10. 9 Sunny Cool Normal Weak Yes  
  11. 10 Rainy Mild Normal Weak Yes  
  12. 11 Sunny Mild Normal Strong Yes  
  13. 12 Overcast Mild High Strong Yes  
  14. 13 Overcast Hot Normal Weak Yes  
  15. 14 Rainy Mild High Strong No  

PalyTennis属性为结构属性,是作为类标识用的,中间的OutLool,Temperature,Humidity,Wind才是划分属性,通过将源数据与执行程序分类,这样可以模拟巨大的数据量了。下面是ID3的主程序类,本人将ID3的算法进行了包装,对外只开放了一个构建决策树的方法,在构造函数时候,只需传入一个数据路径文件即可:

[java] view plaincopyprint?
  1. package DataMing_ID3;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.util.ArrayList;  
  8. import java.util.HashMap;  
  9. import java.util.Iterator;  
  10. import java.util.Map;  
  11. import java.util.Map.Entry;  
  12. import java.util.Set;  
  13.   
  14. /** 
  15.  * ID3算法实现类 
  16.  *  
  17.  * @author lyq 
  18.  *  
  19.  */  
  20. public class ID3Tool {  
  21.     // 类标号的值类型  
  22.     private final String YES = "Yes";  
  23.     private final String NO = "No";  
  24.   
  25.     // 所有属性的类型总数,在这里就是data源数据的列数  
  26.     private int attrNum;  
  27.     private String filePath;  
  28.     // 初始源数据,用一个二维字符数组存放模仿表格数据  
  29.     private String[][] data;  
  30.     // 数据的属性行的名字  
  31.     private String[] attrNames;  
  32.     // 每个属性的值所有类型  
  33.     private HashMap<String, ArrayList<String>> attrValue;  
  34.   
  35.     public ID3Tool(String filePath) {  
  36.         this.filePath = filePath;  
  37.         attrValue = new HashMap<>();  
  38.     }  
  39.   
  40.     /** 
  41.      * 从文件中读取数据 
  42.      */  
  43.     private void readDataFile() {  
  44.         File file = new File(filePath);  
  45.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  46.   
  47.         try {  
  48.             BufferedReader in = new BufferedReader(new FileReader(file));  
  49.             String str;  
  50.             String[] tempArray;  
  51.             while ((str = in.readLine()) != null) {  
  52.                 tempArray = str.split(" ");  
  53.                 dataArray.add(tempArray);  
  54.             }  
  55.             in.close();  
  56.         } catch (IOException e) {  
  57.             e.getStackTrace();  
  58.         }  
  59.   
  60.         data = new String[dataArray.size()][];  
  61.         dataArray.toArray(data);  
  62.         attrNum = data[0].length;  
  63.         attrNames = data[0];  
  64.   
  65.         /* 
  66.          * for(int i=0; i<data.length;i++){ for(int j=0; j<data[0].length; j++){ 
  67.          * System.out.print(" " + data[i][j]); } 
  68.          *  
  69.          * System.out.print("\n"); } 
  70.          */  
  71.     }  
  72.   
  73.     /** 
  74.      * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用 
  75.      */  
  76.     private void initAttrValue() {  
  77.         ArrayList<String> tempValues;  
  78.   
  79.         // 按照列的方式,从左往右找  
  80.         for (int j = 1; j < attrNum; j++) {  
  81.             // 从一列中的上往下开始寻找值  
  82.             tempValues = new ArrayList<>();  
  83.             for (int i = 1; i < data.length; i++) {  
  84.                 if (!tempValues.contains(data[i][j])) {  
  85.                     // 如果这个属性的值没有添加过,则添加  
  86.                     tempValues.add(data[i][j]);  
  87.                 }  
  88.             }  
  89.   
  90.             // 一列属性的值已经遍历完毕,复制到map属性表中  
  91.             attrValue.put(data[0][j], tempValues);  
  92.         }  
  93.   
  94.         /* 
  95.          * for(Map.Entry entry : attrValue.entrySet()){ 
  96.          * System.out.println("key:value " + entry.getKey() + ":" + 
  97.          * entry.getValue()); } 
  98.          */  
  99.     }  
  100.   
  101.     /** 
  102.      * 计算数据按照不同方式划分的熵 
  103.      *  
  104.      * @param remainData 
  105.      *            剩余的数据 
  106.      * @param attrName 
  107.      *            待划分的属性,在算信息增益的时候会使用到 
  108.      * @param attrValue 
  109.      *            划分的子属性值 
  110.      * @param isParent 
  111.      *            是否分子属性划分还是原来不变的划分 
  112.      */  
  113.     private double computeEntropy(String[][] remainData, String attrName,  
  114.             String value, boolean isParent) {  
  115.         // 实例总数  
  116.         int total = 0;  
  117.         // 正实例数  
  118.         int posNum = 0;  
  119.         // 负实例数  
  120.         int negNum = 0;  
  121.   
  122.         // 还是按列从左往右遍历属性  
  123.         for (int j = 1; j < attrNames.length; j++) {  
  124.             // 找到了指定的属性  
  125.             if (attrName.equals(attrNames[j])) {  
  126.                 for (int i = 1; i < remainData.length; i++) {  
  127.                     // 如果是父结点直接计算熵或者是通过子属性划分计算熵,这时要进行属性值的过滤  
  128.                     if (isParent  
  129.                             || (!isParent && remainData[i][j].equals(value))) {  
  130.                         if (remainData[i][attrNames.length - 1].equals(YES)) {  
  131.                             // 判断此行数据是否为正实例  
  132.                             posNum++;  
  133.                         } else {  
  134.                             negNum++;  
  135.                         }  
  136.                     }  
  137.                 }  
  138.             }  
  139.         }  
  140.   
  141.         total = posNum + negNum;  
  142.         double posProbobly = (double) posNum / total;  
  143.         double negProbobly = (double) negNum / total;  
  144.   
  145.         if (posProbobly == 1 || posProbobly == 0) {  
  146.             // 如果数据全为同种类型,则熵为0,否则带入下面的公式会报错  
  147.             return 0;  
  148.         }  
  149.   
  150.         double entropyValue = -posProbobly * Math.log(posProbobly)  
  151.                 / Math.log(2.0) - negProbobly * Math.log(negProbobly)  
  152.                 / Math.log(2.0);  
  153.   
  154.         // 返回计算所得熵  
  155.         return entropyValue;  
  156.     }  
  157.   
  158.     /** 
  159.      * 为某个属性计算信息增益 
  160.      *  
  161.      * @param remainData 
  162.      *            剩余的数据 
  163.      * @param value 
  164.      *            待划分的属性名称 
  165.      * @return 
  166.      */  
  167.     private double computeGain(String[][] remainData, String value) {  
  168.         double gainValue = 0;  
  169.         // 源熵的大小将会与属性划分后进行比较  
  170.         double entropyOri = 0;  
  171.         // 子划分熵和  
  172.         double childEntropySum = 0;  
  173.         // 属性子类型的个数  
  174.         int childValueNum = 0;  
  175.         // 属性值的种数  
  176.         ArrayList<String> attrTypes = attrValue.get(value);  
  177.         // 子属性对应的权重比  
  178.         HashMap<String, Integer> ratioValues = new HashMap<>();  
  179.   
  180.         for (int i = 0; i < attrTypes.size(); i++) {  
  181.             // 首先都统一计数为0  
  182.             ratioValues.put(attrTypes.get(i), 0);  
  183.         }  
  184.   
  185.         // 还是按照一列,从左往右遍历  
  186.         for (int j = 1; j < attrNames.length; j++) {  
  187.             // 判断是否到了划分的属性列  
  188.             if (value.equals(attrNames[j])) {  
  189.                 for (int i = 1; i <= remainData.length - 1; i++) {  
  190.                     childValueNum = ratioValues.get(remainData[i][j]);  
  191.                     // 增加个数并且重新存入  
  192.                     childValueNum++;  
  193.                     ratioValues.put(remainData[i][j], childValueNum);  
  194.                 }  
  195.             }  
  196.         }  
  197.   
  198.         // 计算原熵的大小  
  199.         entropyOri = computeEntropy(remainData, value, null, true);  
  200.         for (int i = 0; i < attrTypes.size(); i++) {  
  201.             double ratio = (double) ratioValues.get(attrTypes.get(i))  
  202.                     / (remainData.length - 1);  
  203.             childEntropySum += ratio  
  204.                     * computeEntropy(remainData, value, attrTypes.get(i), false);  
  205.   
  206.             // System.out.println("ratio:value: " + ratio + " " +  
  207.             // computeEntropy(remainData, value,  
  208.             // attrTypes.get(i), false));  
  209.         }  
  210.   
  211.         // 二者熵相减就是信息增益  
  212.         gainValue = entropyOri - childEntropySum;  
  213.         return gainValue;  
  214.     }  
  215.   
  216.     /** 
  217.      * 计算信息增益比 
  218.      *  
  219.      * @param remainData 
  220.      *            剩余数据 
  221.      * @param value 
  222.      *            待划分属性 
  223.      * @return 
  224.      */  
  225.     private double computeGainRatio(String[][] remainData, String value) {  
  226.         double gain = 0;  
  227.         double spiltInfo = 0;  
  228.         int childValueNum = 0;  
  229.         // 属性值的种数  
  230.         ArrayList<String> attrTypes = attrValue.get(value);  
  231.         // 子属性对应的权重比  
  232.         HashMap<String, Integer> ratioValues = new HashMap<>();  
  233.   
  234.         for (int i = 0; i < attrTypes.size(); i++) {  
  235.             // 首先都统一计数为0  
  236.             ratioValues.put(attrTypes.get(i), 0);  
  237.         }  
  238.   
  239.         // 还是按照一列,从左往右遍历  
  240.         for (int j = 1; j < attrNames.length; j++) {  
  241.             // 判断是否到了划分的属性列  
  242.             if (value.equals(attrNames[j])) {  
  243.                 for (int i = 1; i <= remainData.length - 1; i++) {  
  244.                     childValueNum = ratioValues.get(remainData[i][j]);  
  245.                     // 增加个数并且重新存入  
  246.                     childValueNum++;  
  247.                     ratioValues.put(remainData[i][j], childValueNum);  
  248.                 }  
  249.             }  
  250.         }  
  251.   
  252.         // 计算信息增益  
  253.         gain = computeGain(remainData, value);  
  254.         // 计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):  
  255.         for (int i = 0; i < attrTypes.size(); i++) {  
  256.             double ratio = (double) ratioValues.get(attrTypes.get(i))  
  257.                     / (remainData.length - 1);  
  258.             spiltInfo += -ratio * Math.log(ratio) / Math.log(2.0);  
  259.         }  
  260.   
  261.         // 计算机信息增益率  
  262.         return gain / spiltInfo;  
  263.     }  
  264.   
  265.     /** 
  266.      * 利用源数据构造决策树 
  267.      */  
  268.     private void buildDecisionTree(AttrNode node, String parentAttrValue,  
  269.             String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {  
  270.         node.setParentAttrValue(parentAttrValue);  
  271.   
  272.         String attrName = "";  
  273.         double gainValue = 0;  
  274.         double tempValue = 0;  
  275.   
  276.         // 如果只有1个属性则直接返回  
  277.         if (remainAttr.size() == 1) {  
  278.             System.out.println("attr null");  
  279.             return;  
  280.         }  
  281.   
  282.         // 选择剩余属性中信息增益最大的作为下一个分类的属性  
  283.         for (int i = 0; i < remainAttr.size(); i++) {  
  284.             // 判断是否用ID3算法还是C4.5算法  
  285.             if (isID3) {  
  286.                 // ID3算法采用的是按照信息增益的值来比  
  287.                 tempValue = computeGain(remainData, remainAttr.get(i));  
  288.             } else {  
  289.                 // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足  
  290.                 tempValue = computeGainRatio(remainData, remainAttr.get(i));  
  291.             }  
  292.   
  293.             if (tempValue > gainValue) {  
  294.                 gainValue = tempValue;  
  295.                 attrName = remainAttr.get(i);  
  296.             }  
  297.         }  
  298.   
  299.         node.setAttrName(attrName);  
  300.         ArrayList<String> valueTypes = attrValue.get(attrName);  
  301.         remainAttr.remove(attrName);  
  302.   
  303.         AttrNode[] childNode = new AttrNode[valueTypes.size()];  
  304.         String[][] rData;  
  305.         for (int i = 0; i < valueTypes.size(); i++) {  
  306.             // 移除非此值类型的数据  
  307.             rData = removeData(remainData, attrName, valueTypes.get(i));  
  308.   
  309.             childNode[i] = new AttrNode();  
  310.             boolean sameClass = true;  
  311.             ArrayList<String> indexArray = new ArrayList<>();  
  312.             for (int k = 1; k < rData.length; k++) {  
  313.                 indexArray.add(rData[k][0]);  
  314.                 // 判断是否为同一类的  
  315.                 if (!rData[k][attrNames.length - 1]  
  316.                         .equals(rData[1][attrNames.length - 1])) {  
  317.                     // 只要有1个不相等,就不是同类型的  
  318.                     sameClass = false;  
  319.                     break;  
  320.                 }  
  321.             }  
  322.   
  323.             if (!sameClass) {  
  324.                 // 创建新的对象属性,对象的同个引用会出错  
  325.                 ArrayList<String> rAttr = new ArrayList<>();  
  326.                 for (String str : remainAttr) {  
  327.                     rAttr.add(str);  
  328.                 }  
  329.   
  330.                 buildDecisionTree(childNode[i], valueTypes.get(i), rData,  
  331.                         rAttr, isID3);  
  332.             } else {  
  333.                 // 如果是同种类型,则直接为数据节点  
  334.                 childNode[i].setParentAttrValue(valueTypes.get(i));  
  335.                 childNode[i].setChildDataIndex(indexArray);  
  336.             }  
  337.   
  338.         }  
  339.         node.setChildAttrNode(childNode);  
  340.     }  
  341.   
  342.     /** 
  343.      * 属性划分完毕,进行数据的移除 
  344.      *  
  345.      * @param srcData 
  346.      *            源数据 
  347.      * @param attrName 
  348.      *            划分的属性名称 
  349.      * @param valueType 
  350.      *            属性的值类型 
  351.      */  
  352.     private String[][] removeData(String[][] srcData, String attrName,  
  353.             String valueType) {  
  354.         String[][] desDataArray;  
  355.         ArrayList<String[]> desData = new ArrayList<>();  
  356.         // 待删除数据  
  357.         ArrayList<String[]> selectData = new ArrayList<>();  
  358.         selectData.add(attrNames);  
  359.   
  360.         // 数组数据转化到列表中,方便移除  
  361.         for (int i = 0; i < srcData.length; i++) {  
  362.             desData.add(srcData[i]);  
  363.         }  
  364.   
  365.         // 还是从左往右一列列的查找  
  366.         for (int j = 1; j < attrNames.length; j++) {  
  367.             if (attrNames[j].equals(attrName)) {  
  368.                 for (int i = 1; i < desData.size(); i++) {  
  369.                     if (desData.get(i)[j].equals(valueType)) {  
  370.                         // 如果匹配这个数据,则移除其他的数据  
  371.                         selectData.add(desData.get(i));  
  372.                     }  
  373.                 }  
  374.             }  
  375.         }  
  376.   

数据分析咨询请扫描二维码

客服在线
立即咨询