首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >机器学习全能决策树:CART(Classification and Regression Tree)原理、手动计算与Python/Java双代码实战

机器学习全能决策树:CART(Classification and Regression Tree)原理、手动计算与Python/Java双代码实战

原创
作者头像
jack.yang
发布2026-03-29 15:40:39
发布2026-03-29 15:40:39
2370
举报
文章被收录于专栏:大模型系列大模型系列

关键词:机器学习、CART算法、分类回归树、基尼系数、MSE、决策树、Python CART、Java CART、sklearn DecisionTree、代价复杂度剪枝

一句话答案:CART 是唯一同时支持分类与回归的主流决策树算法——分类用基尼系数,回归用最小平方误差,并采用二叉树结构 + 代价复杂度剪枝,是 sklearn、XGBoost 等框架的基石!

如果你在搜索:

  • “CART 和 ID3/C4.5 有什么区别?”
  • “基尼系数怎么算?”
  • “CART 如何做回归预测?”
  • “Python 和 Java 怎么手写 CART?”

那么,这篇文章就是为你写的——从分类到回归,从分裂到剪枝,一步不跳


一、什么是 CART?它为何成为现代决策树的“标准”?

CART(Classification and Regression Tree)由 Breiman 等人在 1984 年提出,是首个统一处理分类与回归任务的决策树算法。

🔑 核心特点(vs ID3/C4.5)

特性

ID3/C4.5

CART

树结构

多叉树(一个特征多个分支)

严格二叉树(每个节点仅两个子节点)

分裂标准

信息增益 / 增益率

分类:基尼系数;回归:MSE

剪枝方法

悲观剪枝(C4.5)

代价复杂度剪枝(CCP)

任务支持

仅分类

✅ 分类 + 回归

连续特征处理

需预离散化(ID3)或自动分箱(C4.5)

天然支持,直接找最优切分点

💡 CART 的二叉设计使其易于集成(如随机森林、GBDT),成为现代 ML 框架的事实标准。


二、数学原理:分类用基尼,回归用 MSE

📌 1. 分类任务:基尼不纯度(Gini Impurity)

衡量一个数据集的“混乱程度”:


📌 2. 回归任务:最小平方误差(MSE)

目标:最小化预测值与真实值的平方误差。


三、手工推演 Part 1:分类任务(基尼系数)

📊 数据集:是否批准贷款

收入(万)

房产

批准?

30

50

70

40

60

目标:构建 CART 分类树。


🔢 步骤1:尝试所有可能的二元划分

候选1:按“房产”划分(天然二元)
候选2:按“收入”划分(需找切分点)

排序收入:30(否), 40(否), 50(是), 60(是), 70(是)

候选切分点:35, 45, 55, 65

  • t=45
    • ≤45:[30(否), 40(否)] → Gini=0
    • 45:[50(是),60(是),70(是)] → Gini=0
    • 加权 Gini = 0 ← 最优!

✅ 选择 收入 ≤ 45 作为根节点划分。


🌲 最终分类树

代码语言:javascript
复制
        收入 ≤ 45?
         /      \
       是        否
      (否)       (是)

💡 即使“房产”是重要特征,CART 仍可能因二叉+全局最优选择数值特征。


四、手工推演 Part 2:回归任务(MSE)

📊 数据集:房屋面积 vs 价格(万元)

面积(m²)

价格

50

100

70

140

90

180

110

200

130

220

目标:预测新房子的价格。


🔢 步骤1:计算根节点 MSE


🔢 步骤2:尝试切分点(面积)

排序后候选切分点:60, 80, 100, 120

t=100:
  • 左(≤100):[50,70,90] → y=[100,140,180] → 均值=140 → MSE_L = ((40^2 + 0 + 40^2)=3200)
  • 右(>100):[110,130] → y=[200,220] → 均值=210 → MSE_R = ((10^2 + 10^2)=200)
  • 总 MSE = 3200 + 200 = 3400
  • ΔMSE = 8080 - 3400 = 4680(最大下降)

✅ 选择 面积 ≤ 100 为根划分。


🌲 最终回归树

代码语言:javascript
复制
      面积 ≤ 100?
       /        \
    是(140)     否(210)

预测:若面积=80 → 预测价格=140万;面积=120 → 预测=210万。


五、Python 实现(手写 CART 核心)

代码语言:javascript
复制
import numpy as np

class Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature        # 划分特征索引
        self.threshold = threshold    # 划分阈值
        self.left = left              # 左子树
        self.right = right            # 右子树
        self.value = value            # 叶节点预测值(分类:类别;回归:均值)

class CART:
    def __init__(self, task='classification', max_depth=5, min_samples_split=2):
        self.task = task
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.root = None

    def _gini(self, y):
        if len(y) == 0: return 0
        probs = np.bincount(y) / len(y)
        return 1 - np.sum(probs ** 2)

    def _mse(self, y):
        if len(y) == 0: return 0
        return np.sum((y - np.mean(y)) ** 2)

    def _split(self, X, y, feature_idx, threshold):
        left_mask = X[:, feature_idx] <= threshold
        return X[left_mask], y[left_mask], X[~left_mask], y[~left_mask]

    def _best_split(self, X, y):
        best_gain = -1
        best_f, best_t = None, None
        n_features = X.shape[1]

        for f in range(n_features):
            thresholds = np.unique(X[:, f])
            for t in thresholds[:-1]:  # 避免空集
                X_l, y_l, X_r, y_r = self._split(X, y, f, t)
                if len(y_l) == 0 or len(y_r) == 0:
                    continue

                if self.task == 'classification':
                    gain = self._gini(y) - (
                        len(y_l)/len(y) * self._gini(y_l) +
                        len(y_r)/len(y) * self._gini(y_r)
                    )
                else:  # regression
                    gain = self._mse(y) - (self._mse(y_l) + self._mse(y_r))

                if gain > best_gain:
                    best_gain = gain
                    best_f, best_t = f, t

        return best_f, best_t, best_gain

    def _build_tree(self, X, y, depth=0):
        # 停止条件
        if (depth >= self.max_depth or 
            len(y) < self.min_samples_split or 
            len(np.unique(y)) == 1):
            if self.task == 'classification':
                return Node(value=np.bincount(y).argmax())
            else:
                return Node(value=np.mean(y))

        feature, threshold, gain = self._best_split(X, y)
        if feature is None:
            if self.task == 'classification':
                return Node(value=np.bincount(y).argmax())
            else:
                return Node(value=np.mean(y))

        X_l, y_l, X_r, y_r = self._split(X, y, feature, threshold)
        left = self._build_tree(X_l, y_l, depth+1)
        right = self._build_tree(X_r, y_r, depth+1)

        return Node(feature=feature, threshold=threshold, left=left, right=right)

    def fit(self, X, y):
        self.root = self._build_tree(np.array(X), np.array(y))

    def _predict_sample(self, x, node):
        if node.value is not None:
            return node.value
        if x[node.feature] <= node.threshold:
            return self._predict_sample(x, node.left)
        else:
            return self._predict_sample(x, node.right)

    def predict(self, X):
        return [self._predict_sample(x, self.root) for x in X]

# === 测试分类 ===
X_cls = [[30, 0], [50, 1], [70, 1], [40, 0], [60, 0]]
y_cls = [0, 1, 1, 0, 1]  # 0=否, 1=是
cart_cls = CART(task='classification')
cart_cls.fit(X_cls, y_cls)
print("分类预测:", cart_cls.predict())  # [0]

# === 测试回归 ===
X_reg = [[50], [70], [90], [110], [130]]
y_reg = [100, 140, 180, 200, 220]
cart_reg = CART(task='regression')
cart_reg.fit(X_reg, y_reg)
print("回归预测:", cart_reg.predict())  # [140.0]

六、Java 实现(简化版 CART 核心)

代码语言:javascript
复制
import java.util.*;

public class CART {
    static class Node {
        Integer feature;
        Double threshold;
        Node left, right;
        Object value; // Integer for classification, Double for regression
        Node(Object value) { this.value = value; }
        Node(int f, double t) { this.feature = f; this.threshold = t; }
    }

    private String task;
    private int maxDepth;
    private Node root;

    public CART(String task) {
        this.task = task;
        this.maxDepth = 5;
    }

    private double gini(int[] labels) {
        if (labels.length == 0) return 0;
        Map<Integer, Integer> count = new HashMap<>();
        for (int y : labels) count.put(y, count.getOrDefault(y, 0) + 1);
        double impurity = 1.0;
        for (int c : count.values()) {
            double p = (double) c / labels.length;
            impurity -= p * p;
        }
        return impurity;
    }

    private double mse(double[] y) {
        if (y.length == 0) return 0;
        double mean = Arrays.stream(y).average().orElse(0);
        return Arrays.stream(y).map(v -> (v - mean) * (v - mean)).sum();
    }

    private Node buildTree(double[][] X, int[] y, int depth) {
        if (depth >= maxDepth || y.length < 2) {
            if ("classification".equals(task)) {
                return new Node(majorityClass(y));
            } else {
                return new Node(Arrays.stream(y).average().orElse(0));
            }
        }

        // Find best split (simplified: only first feature)
        int bestFeature = 0;
        double bestThreshold = 0;
        double bestGain = -1;

        // Sort by feature 0
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < X.length; i++) indices.add(i);
        indices.sort(Comparator.comparingDouble(i -> X[i][0]));

        for (int i = 0; i < indices.size() - 1; i++) {
            int idx1 = indices.get(i), idx2 = indices.get(i + 1);
            if (y[idx1] != y[idx2]) {
                double t = (X[idx1][0] + X[idx2][0]) / 2;
                // Split
                List<Integer> leftIdx = new ArrayList<>(), rightIdx = new ArrayList<>();
                for (int j = 0; j < X.length; j++) {
                    if (X[j][0] <= t) leftIdx.add(j); else rightIdx.add(j);
                }
                if (leftIdx.isEmpty() || rightIdx.isEmpty()) continue;

                // Compute gain
                double gain;
                if ("classification".equals(task)) {
                    int[] yArr = y;
                    int[] yLeft = leftIdx.stream().mapToInt(j -> yArr[j]).toArray();
                    int[] yRight = rightIdx.stream().mapToInt(j -> yArr[j]).toArray();
                    gain = gini(y) - (leftIdx.size()/(double)y.length * gini(yLeft) +
                                      rightIdx.size()/(double)y.length * gini(yRight));
                } else {
                    double[] yD = Arrays.stream(y).mapToDouble(v -> v).toArray();
                    double[] yLeft = leftIdx.stream().mapToDouble(j -> yD[j]).toArray();
                    double[] yRight = rightIdx.stream().mapToDouble(j -> yD[j]).toArray();
                    gain = mse(yD) - (mse(yLeft) + mse(yRight));
                }

                if (gain > bestGain) {
                    bestGain = gain;
                    bestThreshold = t;
                }
            }
        }

        if (bestGain <= 0) {
            if ("classification".equals(task)) {
                return new Node(majorityClass(y));
            } else {
                return new Node(Arrays.stream(y).average().orElse(0));
            }
        }

        // Recursively build
        List<Integer> leftIdx = new ArrayList<>(), rightIdx = new ArrayList<>();
        for (int i = 0; i < X.length; i++) {
            if (X[i][0] <= bestThreshold) leftIdx.add(i); else rightIdx.add(i);
        }
        double[][] XLeft = leftIdx.stream().mapToDouble(i -> X[i]).toArray();
        double[][] XRight = rightIdx.stream().mapToDouble(i -> X[i]).toArray();
        int[] yLeft = leftIdx.stream().mapToInt(i -> y[i]).toArray();
        int[] yRight = rightIdx.stream().mapToInt(i -> y[i]).toArray();

        Node node = new Node(bestFeature, bestThreshold);
        node.left = buildTree(XLeft, yLeft, depth + 1);
        node.right = buildTree(XRight, yRight, depth + 1);
        return node;
    }

    private int majorityClass(int[] y) {
        Map<Integer, Integer> count = new HashMap<>();
        for (int label : y) count.put(label, count.getOrDefault(label, 0) + 1);
        return Collections.max(count.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    public void fit(double[][] X, int[] y) {
        this.root = buildTree(X, y, 0);
    }

    // 预测方法略(递归遍历树)
}

💡 完整 Java 实现需处理多特征、通用预测等,此处展示核心分裂逻辑。


七、CART 的剪枝:代价复杂度剪枝(CCP)

CART 使用 α 参数 控制树复杂度:


八、优缺点 & 适用场景总结

优点

缺点

✅ 统一支持分类与回归

❌ 单棵树易过拟合(需剪枝或集成)

✅ 二叉结构利于集成(RF/GBDT)

❌ 对噪声敏感

✅ 天然处理连续特征

❌ 忽略特征间线性关系

✅ 可解释性强

❌ 不稳定(数据微变 → 树大变)

🎯 最佳应用场景:

  • 快速原型开发(sklearn 一行代码)
  • 作为集成学习基学习器(随机森林、XGBoost)
  • 需要回归预测的场景(房价、销量)
  • 中小结构化数据集

✅ 结语

CART 用极简的二叉设计,打通了分类与回归的任督二脉。它不仅是独立模型,更是现代集成学习的基石

记住:在机器学习中,简单而通用的设计,往往最具生命力

现在,你已经能:

  • 手动计算基尼系数与 MSE 分裂
  • 用 Python/Java 从零实现 CART
  • 理解其在 sklearn 和集成算法中的核心地位

相关链接

  • 📂 大模型技术专栏: 欢迎您到访 「大模型系列」。 在这个由参数驱动、以数据为燃料的新智能时代,大语言模型(LLM)已不再是实验室里的前沿概念,而是正在重塑搜索、办公、编程、教育、医疗乃至整个数字世界的底层引擎。从 GPT 到 Llama,从 Claude 到 Qwen,从推理到多模态,大模型正以前所未有的速度进化——它们既是工具,也是平台,更可能是下一代人机交互的“操作系统”。 本系列将带你:
    • 🔍 深入原理:从 Transformer 架构、注意力机制到训练范式(预训练、微调、RLHF);
    • ⚙️ 动手实践:本地部署、模型微调、RAG 构建、Agent 设计等实战指南;
    • 🧠 理解边界:幻觉、偏见、安全对齐、推理瓶颈与当前能力天花板;
    • 🌍 洞察趋势:开源 vs 闭源、端侧部署、MoE 架构、世界模型与 AGI 路径;
    • 💼 落地应用:如何在企业中安全、高效、低成本地集成大模型能力。

    无论你是想写代码调用 API 的开发者,设计 AI 产品的 PM,评估技术路线的管理者,还是单纯好奇智能本质的思考者,这里都有值得你驻足的内容。 不追 hype,只讲逻辑;不谈玄学,专注可复现的认知。 让我们一起,在这场百年一遇的智能革命中,看得更清,走得更稳 https://cloud.tencent.com/developer/column/107314

  • 👤 关于作者专注技术落地,深耕硬核干货 本文作者致力于大模型相关技术的生态建设与实战落地。不同于浅层的概念科普,作者坚持 “手算 + 代码” 的深度分享模式,主张通过手动推演理解算法本质,结合生产级代码验证理论可行性。 请关注我主页:https://cloud.tencent.com/developer/user/2276240

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、什么是 CART?它为何成为现代决策树的“标准”?
    • 🔑 核心特点(vs ID3/C4.5)
  • 二、数学原理:分类用基尼,回归用 MSE
    • 📌 1. 分类任务:基尼不纯度(Gini Impurity)
    • 📌 2. 回归任务:最小平方误差(MSE)
  • 三、手工推演 Part 1:分类任务(基尼系数)
    • 📊 数据集:是否批准贷款
    • 🔢 步骤1:尝试所有可能的二元划分
      • 候选1:按“房产”划分(天然二元)
      • 候选2:按“收入”划分(需找切分点)
    • 🌲 最终分类树
  • 四、手工推演 Part 2:回归任务(MSE)
    • 📊 数据集:房屋面积 vs 价格(万元)
    • 🔢 步骤1:计算根节点 MSE
    • 🔢 步骤2:尝试切分点(面积)
      • t=100:
    • 🌲 最终回归树
  • 五、Python 实现(手写 CART 核心)
  • 六、Java 实现(简化版 CART 核心)
  • 七、CART 的剪枝:代价复杂度剪枝(CCP)
  • 八、优缺点 & 适用场景总结
    • 🎯 最佳应用场景:
  • ✅ 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档