译者 | 朱先忠
审校 | 重楼
简介
决策树在机器学习中无处不在,因其直观的输出而备受喜爱。谁不喜欢简单的“if-then”流程图?尽管它们很受欢迎,但令人惊讶的是,要找到一个清晰、循序渐进的解释来分析决策树是如何工作的,还是一项具有相当挑战性的任务。(实际上,我也很尴尬,我也不知道花了多长时间才真正理解决策树算法的工作原理。)
所以,在本文中,我将重点介绍决策树构建的要点。我们将按照从根到最后一个叶子节点(当然还有可视化效果)的顺序来准确解析每个节点中发生的事情及其原因。
【注意】本文中所有图片均由作者本人使用Canva Pro创建。
决策树分类器定义
决策树分类器通过创建一棵倒置的树来进行预测。具体地讲,这种算法从树的顶部开始,提出一个关于数据中重要特征的问题,然后根据答案进行分支生成。当你沿着这些分支往下走时,每一个节点都会提出另一个问题,从而缩小可能性。这个问答游戏一直持续到你到达树的底部——一个叶子节点——在那里你将得到最终的预测或分类结果。
决策树是最重要的机器学习算法之一——它反映了一系列是或否的问题
示例数据集
在本文中,我们将使用人工高尔夫数据集(受【参考文献1】启发)作为示例数据集。该数据集能够根据天气状况预测一个人是否会去打高尔夫。
上图表格中,主要的数据列含义分别是:
- “Outlook(天气状况)”:编码为“晴天(sunny)”、“阴天(overcast)”或者“雨天(rainy)”
- “Temperature(温度)”:对应华氏温度
- “Humidity(湿度)”:用百分数%表示
- “Wind(风)”:是/否有风
- “Play(是否去打高尔夫)”:目标特征
#导入库
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
#加载数据
dataset_dict = {
'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy', 'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast', 'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0, 72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0, 88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
'Wind': [False, True, False, False, False, True, True, False, False, False, True, True, False, True, True, False, False, True, False, True, True, False, True, False, False, True, False, False],
'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)
#预处理数据集
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)
df['Play'] = (df['Play'] == 'Yes').astype(int)
#重新排列各数据列
df = df[['sunny', 'overcast', 'rainy', 'Temperature', 'Humidity', 'Wind', 'Play']]
# 是否出玩和目标确定
X, y = df.drop(columns='Play'), df['Play']
# 分割数据
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)
# 显示结果
print(pd.concat([X_train, y_train], axis=1), '\n')
print(pd.concat([X_test, y_test], axis=1))
主要原理
决策树分类器基于信息量最大的特征递归分割数据来进行操作。其工作原理如下:
- 从根节点处的整个数据集开始。
- 选择最佳特征来分割数据(基于基尼不纯度(Gini impurity)等指标)。
- 为选定特征的每个可能值创建子节点。
- 对每个子节点重复步骤2-3,直到满足停止条件(例如,达到最大深度、每个叶子的最小样本或纯叶子节点)。
- 将主要的类型结果值分配给每个叶子节点。
训练步骤
在scikit-learn开源库中,决策树算法被称为CART(分类和回归树)。它可以用于构建一棵二叉树,通常遵循以下步骤:
1. 从根节点中的所有训练样本开始。
从包含所有14个训练样本的根节点开始,我们将找出最佳方式特征和分割数据的最佳点,以开始构建树
2.对于每个特征,执行如下操作:
a.对特征值进行排序。
b.将相邻值之间的所有可能阈值视为潜在的分割点。
在这个根节点中,有23个分割点需要检查。其中,二进制列只有一个分割点。
def potential_split_points(attr_name, attr_values):
sorted_attr = np.sort(attr_values)
unique_values = np.unique(sorted_attr)
split_points = [(unique_values[i] + unique_values[i+1]) / 2 for i in range(len(unique_values) - 1)]
return {attr_name: split_points}
# Calculate and display potential split points for all columns
for column in X_train.columns:
splits = potential_split_points(column, X_train[column])
for attr, points in splits.items():
print(f"{attr:11}: {points}")
3.对于每个潜在的分割点,执行如下操作:
a.计算当前节点的基尼不纯度。
b.计算不纯度的加权平均值。
例如,对于分割点为0.5的特征“sunny(晴天)”,计算数据集两部分的不纯度(也就是基尼不纯度)
另一个例子是,同样的过程也可以对“Temperature(温度)”等连续特征进行处理。
def gini_impurity(y):
p = np.bincount(y) / len(y)
return 1 - np.sum(p**2)
def weighted_average_impurity(y, split_index):
n = len(y)
left_impurity = gini_impurity(y[:split_index])
right_impurity = gini_impurity(y[split_index:])
return (split_index * left_impurity + (n - split_index) * right_impurity) / n
# 排序“sunny”特征和相应的标签
sunny = X_train['sunny']
sorted_indices = np.argsort(sunny)
sorted_sunny = sunny.iloc[sorted_indices]
sorted_labels = y_train.iloc[sorted_indices]
#查找0.5的分割索引
split_index = np.searchsorted(sorted_sunny, 0.5, side='right')
#计算不纯度
impurity = weighted_average_impurity(sorted_labels, split_index)
print(f"Weighted average impurity for 'sunny' at split point 0.5: {impurity:.3f}")
4.计算完所有特征和分割点的所有不纯度后,选择最低的那一个。
分割点为0.5的“overcast(阴天)”特征给出了最低的不纯度。这意味着,该分割将是所有其他分割点中最纯粹的一个!
def calculate_split_impurities(X, y):
split_data = []
for feature in X.columns:
sorted_indices = np.argsort(X[feature])
sorted_feature = X[feature].iloc[sorted_indices]
sorted_y = y.iloc[sorted_indices]
unique_values = sorted_feature.unique()
split_points = (unique_values[1:] + unique_values[:-1]) / 2
for split in split_points:
split_index = np.searchsorted(sorted_feature, split, side='right')
impurity = weighted_average_impurity(sorted_y, split_index)
split_data.append({
'feature': feature,
'split_point': split,
'weighted_avg_impurity': impurity
})
return pd.DataFrame(split_data)
# 计算所有特征的分割不纯度
calculate_split_impurities(X_train, y_train).round(3)
5.根据所选特征和分割点创建两个子节点:
- 左子对象:特征值<=分割点的样本
- 右侧子对象:具有特征值>分割点的样本
选定的分割点将数据分割成两部分。由于一部分已经是纯的(右侧!这就是为什么它的不纯度很低!),我们只需要在左侧节点上继续迭代树。
对每个子节点递归重复上述步骤2-5。您还可以停止,直到满足停止条件(例如,达到最大深度、每个叶节点的最小样本数或最小不纯度减少)。
#计算选定指标中的分割不纯度
selected_index = [4,8,3,13,7,9,10] # 根据您要检查的索引来更改它
calculate_split_impurities(X_train.iloc[selected_index], y_train.iloc[selected_index]).round(3)
from sklearn.tree import DecisionTreeClassifier
#上面的整个训练阶段都是像这样在sklearn中完成的
dt_clf = DecisionTreeClassifier()
dt_clf.fit(X_train, y_train)
算法结束时的输出树形式
叶子节点的类型标签对应于到达该节点的训练样本的多数类型。
上图中右侧的树表示用于分类的最后那棵树。此时,我们不再需要训练样本了。
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
#打印决策树
plt.figure(figsize=(20, 10))
plot_tree(dt_clf, filled=True, feature_names=X.columns, class_names=['Not Play', 'Play'])
plt.show()
在这个scikit-learn输出中,还存储了非叶子节点的信息,如样本数量和节点中每个类型的数量(值)
分类步骤
接下来,我们来了解训练决策树生成后,预测过程的工作原理,共分4步:
- 从训练好的决策树的根节点开始。
- 评估当前节点的特征和分割条件。
- 在每个后续节点上重复步骤2,直到到达叶子节点。
- 叶子节点的类型标签成为新实例的预测结果。
算法中,我们只需要树运算中所使用的列指标。除了“overcast(阴天)”和“Temperature(温度)”这两个之外,其他值在预测中并不重要
# 开始预测
y_pred = dt_clf.predict(X_test)
print(y_pred)
评估步骤
决策树描述了足够的准确性信息。由于我们的树只检查两个特征,因此这棵树可能无法很好地捕获测试集特征。
# 评估分类器
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
关键参数分析
上面决策树算法中,使用了好几个控制其增长和复杂性的重要参数:
- 最大深度(Max Depth):这个参数用于设置树的最大深度,这个参数可以成为防止训练过拟合的一个很有价值的工具。
友好提示:读者不妨考虑从一棵浅树(可能3-5层深)开始,然后逐渐增加深度。
从一棵浅树开始(例如,深度为3-5),逐渐增加,直到找到模型复杂性和验证数据性能之间的最佳平衡。
- 最小样本分割(Min Samples Split):此参数确定分割内部节点所需的最小样本数。
友好提示:读者不妨将其设置为更高一些的值(约占训练数据的5-10%),这可以帮助防止树创建太多小而特定的分割,从而可能导致无法很好地推广到新数据。
- 最小样本叶节点(Min Samples Leaf):这个参数指定叶子节点所需的最小样本数。
友好提示:选择一个值,确保每个叶子代表数据的一个有意义的子集(大约占训练数据的1-5%)。这种方法有助于避免过于具体的预测。
- 标准(Criterion):这是一个函数参数指标,用于衡量一次分割的质量(通常为代表基尼不纯度的“gini(基尼)”或信息增益的“entropy(熵)”)。
友好提示:虽然基尼不纯度的计算通常更简单、更快,但熵在多类型算法问题方面往往表现得更好。也就是说,这两种技术经常给出类似的运算结果。
分割点为0.5的“sunny(晴天)”的熵计算示例
算法优、缺点分析
与机器学习中的任何算法一样,决策树也有其优势和局限性。
优点:
- 可解释性:易于理解,决策过程能够可视化。
- 无特征缩放:可以在不进行规范化的情况下处理数值和分类数据。
- 处理非线性关系:可以捕获数据中的复杂模式。
- 特征重要性:能够明确指出哪些特征对预测最重要。
缺点:
- 过拟合:容易创建过于复杂的树,这些树往往不能很好地泛化,尤其是在小数据集的情况下。
- 不稳定性:数据中的微小变化可能会导致生成完全不同的树。
- 带有不平衡数据集的偏见:可能偏向于主导类型。
- 无法推断:无法在训练数据范围之外进行预测。
在我们本文提供的高尔夫示例中,决策树可能会根据天气条件创建非常准确和可解释的规则,以决定是否打高尔夫。然而,如果没有正确修剪或数据集很小的话,它可能会过度拟合特定的条件组合。
结论
无论如何,决策树分类器还是解决机器学习中许多类型问题的好工具。它们易于理解,可以处理复杂的数据,并向我们展示它们是如何做出决策的。这使得它们广泛应用于从商业到医学的许多领域。虽然决策树功能强大且可解释,但它们通常被用作更高级集成方法的构建块,如随机森林或梯度增强机等算法中。
简化型决策树分类器算法
下面给出本文示例工程对应的一棵简化型决策树的分类器算法的完整代码。
#导入库
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.tree import plot_tree, DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 加载数据
dataset_dict = {
'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy', 'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast', 'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0, 72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0, 88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
'Wind': [False, True, False, False, False, True, True, False, False, False, True, True, False, True, True, False, False, True, False, True, True, False, True, False, False, True, False, False],
'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)
#准备数据
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)
df['Play'] = (df['Play'] == 'Yes').astype(int)
# 分割数据
X, y = df.drop(columns='Play'), df['Play']
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)
# 训练模型
dt_clf = DecisionTreeClassifier(
max_depth=None, # 树的最大深度
min_samples_split=2, # 分割内部节点所需的最小样本数
min_samples_leaf=1, # 在一个叶节点上所需的最小样本数
criterion='gini' # 测试分割质量的函数
)
dt_clf.fit(X_train, y_train)
# 开始预测
y_pred = dt_clf.predict(X_test)
#评估模型
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
#绘制树
plt.figure(figsize=(20, 10))
plot_tree(dt_clf, filled=True, feature_names=X.columns,
class_names=['Not Play', 'Play'], impurity=False)
plt.show()
相关阅读
有关决策树分类器及其在scikit-learn中的实现的详细说明,读者可以参考官方文档(参考2),其中提供了有关其使用和参数的更为全面的信息。
技术环境
本文示例项目使用了Python 3.7和scikit-learn 1.5开源库。虽然所讨论的概念普遍适用,但具体的代码实现可能因版本不同而略有不同。
关于插图
除非另有说明;否则,文中所有图片均由作者创建,并包含经Canva Pro许可的设计元素。
参考资料
- T. M. Mitchell,Machine Learning(机器学习) (1997),McGraw-Hill Science/Engineering/Math,第59页。
- F. Pedregosa等人,《Scikit-learn: Machine Learning in Python》,Journal of Machine Learning Research(机器学习研究杂志) 2011年,第12卷,第2825-2830页。在线可访问地址:https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html。
译者介绍
朱先忠,51CTO社区编辑,51CTO专家博客、讲师,潍坊一所高校计算机教师,自由编程界老兵一枚。
原文Decision Tree Classifier, Explained: A Visual Guide with Code Examples for Beginners,作者:Samy Baladram