小编给大家分享一下Java如何实现决策树算法,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!
具体如下:
决策树算法是一种逼近离散函数值的方法。它是一种典型的分类方法,首先对数据进行处理,利用归纳算法生成可读的规则和决策树,然后使用决策对新数据进行分析。本质上决策树是通过一系列规则对数据进行分类的过程。
决策树构造可以分两步进行。第一步,决策树的生成:由训练样本集生成决策树的过程。一般情况下,训练样本数据集是根据实际需要有历史的、有一定综合程度的,用于数据分析处理的数据集。第二步,决策树的剪枝:决策树的剪枝是对上一阶段生成的决策树进行检验、校正和修下的过程,主要是用新的样本数据集(称为测试数据集)中的数据校验决策树生成过程中产生的初步规则,将那些影响预衡准确性的分枝剪除。
java实现代码如下:
package demo;import java.util.HashMap;import java.util.LinkedList;import java.util.List;import java.util.Map;import java.util.Map.Entry;import java.util.Set;public class DicisionTree { public static void main(String[] args) throws Exception { System.out.print("编程网测试结果:"); String[] attrNames = new String[] { "AGE", "INCOME", "STUDENT", "CREDIT_RATING" }; // 读取样本集 Map<Object, List<Sample>> samples = readSamples(attrNames); // 生成决策树 Object decisionTree = generateDecisionTree(samples, attrNames); // 输出决策树 outputDecisionTree(decisionTree, 0, null); } static Map<Object, List<Sample>> readSamples(String[] attrNames) { // 样本属性及其所属分类(数组中的最后一个元素为样本所属分类) Object[][] rawData = new Object[][] { { "<30 ", "High ", "No ", "Fair ", "0" }, { "<30 ", "High ", "No ", "Excellent", "0" }, { "30-40", "High ", "No ", "Fair ", "1" }, { ">40 ", "Medium", "No ", "Fair ", "1" }, { ">40 ", "Low ", "Yes", "Fair ", "1" }, { ">40 ", "Low ", "Yes", "Excellent", "0" }, { "30-40", "Low ", "Yes", "Excellent", "1" }, { "<30 ", "Medium", "No ", "Fair ", "0" }, { "<30 ", "Low ", "Yes", "Fair ", "1" }, { ">40 ", "Medium", "Yes", "Fair ", "1" }, { "<30 ", "Medium", "Yes", "Excellent", "1" }, { "30-40", "Medium", "No ", "Excellent", "1" }, { "30-40", "High ", "Yes", "Fair ", "1" }, { ">40 ", "Medium", "No ", "Excellent", "0" } }; // 读取样本属性及其所属分类,构造表示样本的Sample对象,并按分类划分样本集 Map<Object, List<Sample>> ret = new HashMap<Object, List<Sample>>(); for (Object[] row : rawData) { Sample sample = new Sample(); int i = 0; for (int n = row.length - 1; i < n; i++) sample.setAttribute(attrNames[i], row[i]); sample.setCategory(row[i]); List<Sample> samples = ret.get(row[i]); if (samples == null) { samples = new LinkedList<Sample>(); ret.put(row[i], samples); } samples.add(sample); } return ret; } static Object generateDecisionTree( Map<Object, List<Sample>> categoryToSamples, String[] attrNames) { // 如果只有一个样本,将该样本所属分类作为新样本的分类 if (categoryToSamples.size() == 1) return categoryToSamples.keySet().iterator().next(); // 如果没有供决策的属性,则将样本集中具有最多样本的分类作为新样本的分类,即投票选举出分类 if (attrNames.length == 0) { int max = 0; Object maxCategory = null; for (Entry<Object, List<Sample>> entry : categoryToSamples .entrySet()) { int cur = entry.getValue().size(); if (cur > max) { max = cur; maxCategory = entry.getKey(); } } return maxCategory; } // 选取测试属性 Object[] rst = chooseBestTestAttribute(categoryToSamples, attrNames); // 决策树根结点,分支属性为选取的测试属性 Tree tree = new Tree(attrNames[(Integer) rst[0]]); // 已用过的测试属性不应再次被选为测试属性 String[] subA = new String[attrNames.length - 1]; for (int i = 0, j = 0; i < attrNames.length; i++) if (i != (Integer) rst[0]) subA[j++] = attrNames[i]; // 根据分支属性生成分支 @SuppressWarnings("unchecked") Map<Object, Map<Object, List<Sample>>> splits = (Map<Object, Map<Object, List<Sample>>>) rst[2]; for (Entry<Object, Map<Object, List<Sample>>> entry : splits.entrySet()) { Object attrValue = entry.getKey(); Map<Object, List<Sample>> split = entry.getValue(); Object child = generateDecisionTree(split, subA); tree.setChild(attrValue, child); } return tree; } static Object[] chooseBestTestAttribute( Map<Object, List<Sample>> categoryToSamples, String[] attrNames) { int minIndex = -1; // 最优属性下标 double minValue = Double.MAX_VALUE; // 最小信息量 Map<Object, Map<Object, List<Sample>>> minSplits = null; // 最优分支方案 // 对每一个属性,计算将其作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和,选取最小为最优 for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) { int allCount = 0; // 统计样本总数的计数器 // 按当前属性构建Map:属性值->(分类->样本列表) Map<Object, Map<Object, List<Sample>>> curSplits = new HashMap<Object, Map<Object, List<Sample>>>(); for (Entry<Object, List<Sample>> entry : categoryToSamples .entrySet()) { Object category = entry.getKey(); List<Sample> samples = entry.getValue(); for (Sample sample : samples) { Object attrValue = sample .getAttribute(attrNames[attrIndex]); Map<Object, List<Sample>> split = curSplits.get(attrValue); if (split == null) { split = new HashMap<Object, List<Sample>>(); curSplits.put(attrValue, split); } List<Sample> splitSamples = split.get(category); if (splitSamples == null) { splitSamples = new LinkedList<Sample>(); split.put(category, splitSamples); } splitSamples.add(sample); } allCount += samples.size(); } // 计算将当前属性作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和 double curValue = 0.0; // 计数器:累加各分支 for (Map<Object, List<Sample>> splits : curSplits.values()) { double perSplitCount = 0; for (List<Sample> list : splits.values()) perSplitCount += list.size(); // 累计当前分支样本数 double perSplitValue = 0.0; // 计数器:当前分支 for (List<Sample> list : splits.values()) { double p = list.size() / perSplitCount; perSplitValue -= p * (Math.log(p) / Math.log(2)); } curValue += (perSplitCount / allCount) * perSplitValue; } // 选取最小为最优 if (minValue > curValue) { minIndex = attrIndex; minValue = curValue; minSplits = curSplits; } } return new Object[] { minIndex, minValue, minSplits }; } static void outputDecisionTree(Object obj, int level, Object from) { for (int i = 0; i < level; i++) System.out.print("|-----"); if (from != null) System.out.printf("(%s):", from); if (obj instanceof Tree) { Tree tree = (Tree) obj; String attrName = tree.getAttribute(); System.out.printf("[%s = ?]\n", attrName); for (Object attrValue : tree.getAttributeValues()) { Object child = tree.getChild(attrValue); outputDecisionTree(child, level + 1, attrName + " = " + attrValue); } } else { System.out.printf("[CATEGORY = %s]\n", obj); } } static class Sample { private Map<String, Object> attributes = new HashMap<String, Object>(); private Object category; public Object getAttribute(String name) { return attributes.get(name); } public void setAttribute(String name, Object value) { attributes.put(name, value); } public Object getCategory() { return category; } public void setCategory(Object category) { this.category = category; } public String toString() { return attributes.toString(); } } static class Tree { private String attribute; private Map<Object, Object> children = new HashMap<Object, Object>(); public Tree(String attribute) { this.attribute = attribute; } public String getAttribute() { return attribute; } public Object getChild(Object attrValue) { return children.get(attrValue); } public void setChild(Object attrValue, Object child) { children.put(attrValue, child); } public Set<Object> getAttributeValues() { return children.keySet(); } }}
运行结果:
以上是“Java如何实现决策树算法”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注编程网行业资讯频道!