文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

机器学习利器——决策树分类器深度解析

2024-11-29 19:35

关注

译者 | 朱先忠

审校 | 重楼

简介

决策树在机器学习中无处不在,因其直观的输出而备受喜爱。谁不喜欢简单的“if-then”流程图?尽管它们很受欢迎,但令人惊讶的是,要找到一个清晰、循序渐进的解释来分析决策树是如何工作的,还是一项具有相当挑战性的任务。(实际上,我也很尴尬,我也不知道花了多长时间才真正理解决策树算法的工作原理。)

所以,在本文中,我将重点介绍决策树构建的要点。我们将按照从根到最后一个叶子节点(当然还有可视化效果)的顺序来准确解析每个节点中发生的事情及其原因。

【注意】本文中所有图片均由作者本人使用Canva Pro创建。

决策树分类器定义

决策树分类器通过创建一棵倒置的树来进行预测。具体地讲,这种算法从树的顶部开始,提出一个关于数据中重要特征的问题,然后根据答案进行分支生成。当你沿着这些分支往下走时,每一个节点都会提出另一个问题,从而缩小可能性。这个问答游戏一直持续到你到达树的底部——一个叶子节点——在那里你将得到最终的预测或分类结果。

决策树是最重要的机器学习算法之一——它反映了一系列是或否的问题

示例数据集

在本文中,我们将使用人工高尔夫数据集(受【参考文献1】启发)作为示例数据集。该数据集能够根据天气状况预测一个人是否会去打高尔夫。

上图表格中,主要的数据列含义分别是:

#导入库
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np

#加载数据
dataset_dict = {
    'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy', 'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast', 'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
    'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0, 72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0, 88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
    'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
    'Wind': [False, True, False, False, False, True, True, False, False, False, True, True, False, True, True, False, False, True, False, True, True, False, True, False, False, True, False, False],
    'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)

#预处理数据集
df = pd.get_dummies(df, columns=['Outlook'],  prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)
df['Play'] = (df['Play'] == 'Yes').astype(int)

#重新排列各数据列
df = df[['sunny', 'overcast', 'rainy', 'Temperature', 'Humidity', 'Wind', 'Play']]

# 是否出玩和目标确定
X, y = df.drop(columns='Play'), df['Play']

# 分割数据
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)

# 显示结果
print(pd.concat([X_train, y_train], axis=1), '\n')
print(pd.concat([X_test, y_test], axis=1))

主要原理

决策树分类器基于信息量最大的特征递归分割数据来进行操作。其工作原理如下:

训练步骤

在scikit-learn开源库中,决策树算法被称为CART(分类和回归树)。它可以用于构建一棵二叉树,通常遵循以下步骤:

1. 从根节点中的所有训练样本开始。

从包含所有14个训练样本的根节点开始,我们将找出最佳方式特征和分割数据的最佳点,以开始构建树

2.对于每个特征,执行如下操作:

a.对特征值进行排序。

b.将相邻值之间的所有可能阈值视为潜在的分割点。

在这个根节点中,有23个分割点需要检查。其中,二进制列只有一个分割点。

def potential_split_points(attr_name, attr_values):
    sorted_attr = np.sort(attr_values)
    unique_values = np.unique(sorted_attr)
    split_points = [(unique_values[i] + unique_values[i+1]) / 2 for i in range(len(unique_values) - 1)]
    return {attr_name: split_points}

# Calculate and display potential split points for all columns
for column in X_train.columns:
    splits = potential_split_points(column, X_train[column])
    for attr, points in splits.items():
        print(f"{attr:11}: {points}")

3.对于每个潜在的分割点,执行如下操作:

a.计算当前节点的基尼不纯度。

b.计算不纯度的加权平均值。

例如,对于分割点为0.5的特征“sunny(晴天)”,计算数据集两部分的不纯度(也就是基尼不纯度)

另一个例子是,同样的过程也可以对“Temperature(温度)”等连续特征进行处理。

def gini_impurity(y):
    p = np.bincount(y) / len(y)
    return 1 - np.sum(p**2)

def weighted_average_impurity(y, split_index):
    n = len(y)
    left_impurity = gini_impurity(y[:split_index])
    right_impurity = gini_impurity(y[split_index:])
    return (split_index * left_impurity + (n - split_index) * right_impurity) / n

# 排序“sunny”特征和相应的标签
sunny = X_train['sunny']
sorted_indices = np.argsort(sunny)
sorted_sunny = sunny.iloc[sorted_indices]
sorted_labels = y_train.iloc[sorted_indices]

#查找0.5的分割索引
split_index = np.searchsorted(sorted_sunny, 0.5, side='right')

#计算不纯度
impurity = weighted_average_impurity(sorted_labels, split_index)

print(f"Weighted average impurity for 'sunny' at split point 0.5: {impurity:.3f}")

4.计算完所有特征和分割点的所有不纯度后,选择最低的那一个。

分割点为0.5的“overcast(阴天)”特征给出了最低的不纯度。这意味着,该分割将是所有其他分割点中最纯粹的一个!

def calculate_split_impurities(X, y):
    split_data = []

    for feature in X.columns:
        sorted_indices = np.argsort(X[feature])
        sorted_feature = X[feature].iloc[sorted_indices]
        sorted_y = y.iloc[sorted_indices]

        unique_values = sorted_feature.unique()
        split_points = (unique_values[1:] + unique_values[:-1]) / 2

        for split in split_points:
            split_index = np.searchsorted(sorted_feature, split, side='right')
            impurity = weighted_average_impurity(sorted_y, split_index)
            split_data.append({
                'feature': feature,
                'split_point': split,
                'weighted_avg_impurity': impurity
            })

    return pd.DataFrame(split_data)

# 计算所有特征的分割不纯度
calculate_split_impurities(X_train, y_train).round(3)

5.根据所选特征和分割点创建两个子节点:

选定的分割点将数据分割成两部分。由于一部分已经是纯的(右侧!这就是为什么它的不纯度很低!),我们只需要在左侧节点上继续迭代树。

对每个子节点递归重复上述步骤2-5。您还可以停止,直到满足停止条件(例如,达到最大深度、每个叶节点的最小样本数或最小不纯度减少)。

#计算选定指标中的分割不纯度
selected_index = [4,8,3,13,7,9,10] # 根据您要检查的索引来更改它
calculate_split_impurities(X_train.iloc[selected_index], y_train.iloc[selected_index]).round(3)

from sklearn.tree import DecisionTreeClassifier

#上面的整个训练阶段都是像这样在sklearn中完成的
dt_clf = DecisionTreeClassifier()
dt_clf.fit(X_train, y_train)

算法结束时的输出树形式

叶子节点的类型标签对应于到达该节点的训练样本的多数类型。

上图中右侧的树表示用于分类的最后那棵树。此时,我们不再需要训练样本了。

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
#打印决策树
plt.figure(figsize=(20, 10))
plot_tree(dt_clf, filled=True, feature_names=X.columns, class_names=['Not Play', 'Play'])
plt.show()

在这个scikit-learn输出中,还存储了非叶子节点的信息,如样本数量和节点中每个类型的数量(值)

分类步骤

接下来,我们来了解训练决策树生成后,预测过程的工作原理,共分4步:

算法中,我们只需要树运算中所使用的列指标。除了“overcast(阴天)”和“Temperature(温度)”这两个之外,其他值在预测中并不重要

# 开始预测
y_pred = dt_clf.predict(X_test)
print(y_pred)

评估步骤

决策树描述了足够的准确性信息。由于我们的树只检查两个特征,因此这棵树可能无法很好地捕获测试集特征。

# 评估分类器
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")

关键参数分析

上面决策树算法中,使用了好几个控制其增长和复杂性的重要参数:

友好提示:读者不妨考虑从一棵浅树(可能3-5层深)开始,然后逐渐增加深度。

从一棵浅树开始(例如,深度为3-5),逐渐增加,直到找到模型复杂性和验证数据性能之间的最佳平衡。

友好提示:读者不妨将其设置为更高一些的值(约占训练数据的5-10%),这可以帮助防止树创建太多小而特定的分割,从而可能导致无法很好地推广到新数据。

友好提示:选择一个值,确保每个叶子代表数据的一个有意义的子集(大约占训练数据的1-5%)。这种方法有助于避免过于具体的预测。

友好提示:虽然基尼不纯度的计算通常更简单、更快,但熵在多类型算法问题方面往往表现得更好。也就是说,这两种技术经常给出类似的运算结果。

分割点为0.5的“sunny(晴天)”的熵计算示例

算法优、缺点分析

与机器学习中的任何算法一样,决策树也有其优势和局限性。

优点:

缺点:

在我们本文提供的高尔夫示例中,决策树可能会根据天气条件创建非常准确和可解释的规则,以决定是否打高尔夫。然而,如果没有正确修剪或数据集很小的话,它可能会过度拟合特定的条件组合。

结论

无论如何,决策树分类器还是解决机器学习中许多类型问题的好工具。它们易于理解,可以处理复杂的数据,并向我们展示它们是如何做出决策的。这使得它们广泛应用于从商业到医学的许多领域。虽然决策树功能强大且可解释,但它们通常被用作更高级集成方法的构建块,如随机森林或梯度增强机等算法中。

简化型决策树分类器算法

下面给出本文示例工程对应的一棵简化型决策树的分类器算法的完整代码。

#导入库
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.tree import plot_tree, DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 加载数据
dataset_dict = {
    'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy', 'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast', 'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
    'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0, 72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0, 88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
    'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
    'Wind': [False, True, False, False, False, True, True, False, False, False, True, True, False, True, True, False, False, True, False, True, True, False, True, False, False, True, False, False],
    'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)

#准备数据
df = pd.get_dummies(df, columns=['Outlook'],  prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)
df['Play'] = (df['Play'] == 'Yes').astype(int)

# 分割数据
X, y = df.drop(columns='Play'), df['Play']
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)

# 训练模型
dt_clf = DecisionTreeClassifier(
    max_depth=None,           # 树的最大深度
    min_samples_split=2,      # 分割内部节点所需的最小样本数
    min_samples_leaf=1,       # 在一个叶节点上所需的最小样本数
    criterion='gini'          # 测试分割质量的函数
)
dt_clf.fit(X_train, y_train)

# 开始预测
y_pred = dt_clf.predict(X_test)

#评估模型
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")

#绘制树
plt.figure(figsize=(20, 10))
plot_tree(dt_clf, filled=True, feature_names=X.columns,
          class_names=['Not Play', 'Play'], impurity=False)
plt.show()

相关阅读

有关决策树分类器及其在scikit-learn中的实现的详细说明,读者可以参考官方文档(参考2),其中提供了有关其使用和参数的更为全面的信息。

技术环境

本文示例项目使用了Python 3.7和scikit-learn 1.5开源库。虽然所讨论的概念普遍适用,但具体的代码实现可能因版本不同而略有不同。

关于插图

除非另有说明;否则,文中所有图片均由作者创建,并包含经Canva Pro许可的设计元素。

参考资料

  1. T. M. Mitchell,Machine Learning(机器学习) (1997),McGraw-Hill Science/Engineering/Math,第59页。
  2. F. Pedregosa等人,《Scikit-learn: Machine Learning in Python》,Journal of Machine Learning Research(机器学习研究杂志) 2011年,第12卷,第2825-2830页。在线可访问地址:https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html。

译者介绍

朱先忠,51CTO社区编辑,51CTO专家博客、讲师,潍坊一所高校计算机教师,自由编程界老兵一枚。

原文Decision Tree Classifier, Explained: A Visual Guide with Code Examples for Beginners,作者:Samy Baladram

来源:51CTO内容精选内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯