Skip to content

第2章:决策树

决策树是一种直观的分类和回归方法,它通过学习简单的决策规则来预测目标变量的值。本章将详细介绍决策树的原理、应用场景以及如何使用Scikit-Learn实现决策树模型。

2.1 决策树原理

2.1.1 树的构建

决策树通过递归地选择最佳特征和分割点来构建。每个内部节点代表一个特征上的测试,每个分支代表测试的结果,每个叶节点代表一个决策或预测结果。

2.1.2 特征选择

特征选择是决策树构建的关键步骤,常用的方法包括信息增益和基尼不纯度。

2.1.3 树的剪枝

为了防止过拟合,决策树需要进行剪枝,包括预剪枝和后剪枝技术。

2.2 决策树应用

2.2.1 分类问题

决策树可以用于分类问题,通过学习数据的特征来预测样本的类别。

2.2.2 回归问题

决策树也可以用于回归问题,预测连续值。

2.2.3 特征重要性评估

决策树可以评估每个特征对模型的贡献度,即特征的重要性。

2.3 决策树代码实现

2.3.1 数据准备

我们将使用Scikit-Learn的鸢尾花数据集来演示决策树的实现。

python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

2.3.2 模型训练

使用Scikit-Learn的DecisionTreeClassifier类来训练决策树模型。

python
from sklearn.tree import DecisionTreeClassifier

# 创建决策树分类器实例
dt_classifier = DecisionTreeClassifier(random_state=42)

# 训练模型
dt_classifier.fit(X_train, y_train)

2.3.3 模型评估

评估模型在测试集上的性能。

python
from sklearn.metrics import accuracy_score

# 在测试集上进行预测
y_pred = dt_classifier.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")

2.3.4 可视化决策树

使用plot_tree函数可视化决策树。

python
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

# 可视化决策树
plt.figure(figsize=(20,10))
plot_tree(dt_classifier, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.title("Decision Tree Visualization")
plt.show()

2.4 本章小结

本章介绍了决策树的基本原理、应用场景,并使用Scikit-Learn实现了一个简单的决策树模型。通过数据准备、模型训练、评估和可视化,我们可以看到决策树在分类问题中的有效性。理解决策树的工作原理和代码实现对于深入学习机器学习算法至关重要。