关键词:机器学习、CART算法、分类回归树、基尼系数、MSE、决策树、Python CART、Java CART、sklearn DecisionTree、代价复杂度剪枝
一句话答案:CART 是唯一同时支持分类与回归的主流决策树算法——分类用基尼系数,回归用最小平方误差,并采用二叉树结构 + 代价复杂度剪枝,是 sklearn、XGBoost 等框架的基石!
如果你在搜索:
那么,这篇文章就是为你写的——从分类到回归,从分裂到剪枝,一步不跳。
CART(Classification and Regression Tree)由 Breiman 等人在 1984 年提出,是首个统一处理分类与回归任务的决策树算法。
特性 | ID3/C4.5 | CART |
|---|---|---|
树结构 | 多叉树(一个特征多个分支) | 严格二叉树(每个节点仅两个子节点) |
分裂标准 | 信息增益 / 增益率 | 分类:基尼系数;回归:MSE |
剪枝方法 | 悲观剪枝(C4.5) | 代价复杂度剪枝(CCP) |
任务支持 | 仅分类 | ✅ 分类 + 回归 |
连续特征处理 | 需预离散化(ID3)或自动分箱(C4.5) | 天然支持,直接找最优切分点 |
💡 CART 的二叉设计使其易于集成(如随机森林、GBDT),成为现代 ML 框架的事实标准。
衡量一个数据集的“混乱程度”:

目标:最小化预测值与真实值的平方误差。

收入(万) | 房产 | 批准? |
|---|---|---|
30 | 无 | 否 |
50 | 有 | 是 |
70 | 有 | 是 |
40 | 无 | 否 |
60 | 无 | 是 |
目标:构建 CART 分类树。

排序收入:30(否), 40(否), 50(是), 60(是), 70(是)
候选切分点:35, 45, 55, 65
✅ 选择 收入 ≤ 45 作为根节点划分。
收入 ≤ 45?
/ \
是 否
(否) (是)💡 即使“房产”是重要特征,CART 仍可能因二叉+全局最优选择数值特征。
面积(m²) | 价格 |
|---|---|
50 | 100 |
70 | 140 |
90 | 180 |
110 | 200 |
130 | 220 |
目标:预测新房子的价格。

排序后候选切分点:60, 80, 100, 120
✅ 选择 面积 ≤ 100 为根划分。
面积 ≤ 100?
/ \
是(140) 否(210)预测:若面积=80 → 预测价格=140万;面积=120 → 预测=210万。
import numpy as np
class Node:
def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
self.feature = feature # 划分特征索引
self.threshold = threshold # 划分阈值
self.left = left # 左子树
self.right = right # 右子树
self.value = value # 叶节点预测值(分类:类别;回归:均值)
class CART:
def __init__(self, task='classification', max_depth=5, min_samples_split=2):
self.task = task
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.root = None
def _gini(self, y):
if len(y) == 0: return 0
probs = np.bincount(y) / len(y)
return 1 - np.sum(probs ** 2)
def _mse(self, y):
if len(y) == 0: return 0
return np.sum((y - np.mean(y)) ** 2)
def _split(self, X, y, feature_idx, threshold):
left_mask = X[:, feature_idx] <= threshold
return X[left_mask], y[left_mask], X[~left_mask], y[~left_mask]
def _best_split(self, X, y):
best_gain = -1
best_f, best_t = None, None
n_features = X.shape[1]
for f in range(n_features):
thresholds = np.unique(X[:, f])
for t in thresholds[:-1]: # 避免空集
X_l, y_l, X_r, y_r = self._split(X, y, f, t)
if len(y_l) == 0 or len(y_r) == 0:
continue
if self.task == 'classification':
gain = self._gini(y) - (
len(y_l)/len(y) * self._gini(y_l) +
len(y_r)/len(y) * self._gini(y_r)
)
else: # regression
gain = self._mse(y) - (self._mse(y_l) + self._mse(y_r))
if gain > best_gain:
best_gain = gain
best_f, best_t = f, t
return best_f, best_t, best_gain
def _build_tree(self, X, y, depth=0):
# 停止条件
if (depth >= self.max_depth or
len(y) < self.min_samples_split or
len(np.unique(y)) == 1):
if self.task == 'classification':
return Node(value=np.bincount(y).argmax())
else:
return Node(value=np.mean(y))
feature, threshold, gain = self._best_split(X, y)
if feature is None:
if self.task == 'classification':
return Node(value=np.bincount(y).argmax())
else:
return Node(value=np.mean(y))
X_l, y_l, X_r, y_r = self._split(X, y, feature, threshold)
left = self._build_tree(X_l, y_l, depth+1)
right = self._build_tree(X_r, y_r, depth+1)
return Node(feature=feature, threshold=threshold, left=left, right=right)
def fit(self, X, y):
self.root = self._build_tree(np.array(X), np.array(y))
def _predict_sample(self, x, node):
if node.value is not None:
return node.value
if x[node.feature] <= node.threshold:
return self._predict_sample(x, node.left)
else:
return self._predict_sample(x, node.right)
def predict(self, X):
return [self._predict_sample(x, self.root) for x in X]
# === 测试分类 ===
X_cls = [[30, 0], [50, 1], [70, 1], [40, 0], [60, 0]]
y_cls = [0, 1, 1, 0, 1] # 0=否, 1=是
cart_cls = CART(task='classification')
cart_cls.fit(X_cls, y_cls)
print("分类预测:", cart_cls.predict()) # [0]
# === 测试回归 ===
X_reg = [[50], [70], [90], [110], [130]]
y_reg = [100, 140, 180, 200, 220]
cart_reg = CART(task='regression')
cart_reg.fit(X_reg, y_reg)
print("回归预测:", cart_reg.predict()) # [140.0]import java.util.*;
public class CART {
static class Node {
Integer feature;
Double threshold;
Node left, right;
Object value; // Integer for classification, Double for regression
Node(Object value) { this.value = value; }
Node(int f, double t) { this.feature = f; this.threshold = t; }
}
private String task;
private int maxDepth;
private Node root;
public CART(String task) {
this.task = task;
this.maxDepth = 5;
}
private double gini(int[] labels) {
if (labels.length == 0) return 0;
Map<Integer, Integer> count = new HashMap<>();
for (int y : labels) count.put(y, count.getOrDefault(y, 0) + 1);
double impurity = 1.0;
for (int c : count.values()) {
double p = (double) c / labels.length;
impurity -= p * p;
}
return impurity;
}
private double mse(double[] y) {
if (y.length == 0) return 0;
double mean = Arrays.stream(y).average().orElse(0);
return Arrays.stream(y).map(v -> (v - mean) * (v - mean)).sum();
}
private Node buildTree(double[][] X, int[] y, int depth) {
if (depth >= maxDepth || y.length < 2) {
if ("classification".equals(task)) {
return new Node(majorityClass(y));
} else {
return new Node(Arrays.stream(y).average().orElse(0));
}
}
// Find best split (simplified: only first feature)
int bestFeature = 0;
double bestThreshold = 0;
double bestGain = -1;
// Sort by feature 0
List<Integer> indices = new ArrayList<>();
for (int i = 0; i < X.length; i++) indices.add(i);
indices.sort(Comparator.comparingDouble(i -> X[i][0]));
for (int i = 0; i < indices.size() - 1; i++) {
int idx1 = indices.get(i), idx2 = indices.get(i + 1);
if (y[idx1] != y[idx2]) {
double t = (X[idx1][0] + X[idx2][0]) / 2;
// Split
List<Integer> leftIdx = new ArrayList<>(), rightIdx = new ArrayList<>();
for (int j = 0; j < X.length; j++) {
if (X[j][0] <= t) leftIdx.add(j); else rightIdx.add(j);
}
if (leftIdx.isEmpty() || rightIdx.isEmpty()) continue;
// Compute gain
double gain;
if ("classification".equals(task)) {
int[] yArr = y;
int[] yLeft = leftIdx.stream().mapToInt(j -> yArr[j]).toArray();
int[] yRight = rightIdx.stream().mapToInt(j -> yArr[j]).toArray();
gain = gini(y) - (leftIdx.size()/(double)y.length * gini(yLeft) +
rightIdx.size()/(double)y.length * gini(yRight));
} else {
double[] yD = Arrays.stream(y).mapToDouble(v -> v).toArray();
double[] yLeft = leftIdx.stream().mapToDouble(j -> yD[j]).toArray();
double[] yRight = rightIdx.stream().mapToDouble(j -> yD[j]).toArray();
gain = mse(yD) - (mse(yLeft) + mse(yRight));
}
if (gain > bestGain) {
bestGain = gain;
bestThreshold = t;
}
}
}
if (bestGain <= 0) {
if ("classification".equals(task)) {
return new Node(majorityClass(y));
} else {
return new Node(Arrays.stream(y).average().orElse(0));
}
}
// Recursively build
List<Integer> leftIdx = new ArrayList<>(), rightIdx = new ArrayList<>();
for (int i = 0; i < X.length; i++) {
if (X[i][0] <= bestThreshold) leftIdx.add(i); else rightIdx.add(i);
}
double[][] XLeft = leftIdx.stream().mapToDouble(i -> X[i]).toArray();
double[][] XRight = rightIdx.stream().mapToDouble(i -> X[i]).toArray();
int[] yLeft = leftIdx.stream().mapToInt(i -> y[i]).toArray();
int[] yRight = rightIdx.stream().mapToInt(i -> y[i]).toArray();
Node node = new Node(bestFeature, bestThreshold);
node.left = buildTree(XLeft, yLeft, depth + 1);
node.right = buildTree(XRight, yRight, depth + 1);
return node;
}
private int majorityClass(int[] y) {
Map<Integer, Integer> count = new HashMap<>();
for (int label : y) count.put(label, count.getOrDefault(label, 0) + 1);
return Collections.max(count.entrySet(), Map.Entry.comparingByValue()).getKey();
}
public void fit(double[][] X, int[] y) {
this.root = buildTree(X, y, 0);
}
// 预测方法略(递归遍历树)
}💡 完整 Java 实现需处理多特征、通用预测等,此处展示核心分裂逻辑。
CART 使用 α 参数 控制树复杂度:

优点 | 缺点 |
|---|---|
✅ 统一支持分类与回归 | ❌ 单棵树易过拟合(需剪枝或集成) |
✅ 二叉结构利于集成(RF/GBDT) | ❌ 对噪声敏感 |
✅ 天然处理连续特征 | ❌ 忽略特征间线性关系 |
✅ 可解释性强 | ❌ 不稳定(数据微变 → 树大变) |
CART 用极简的二叉设计,打通了分类与回归的任督二脉。它不仅是独立模型,更是现代集成学习的基石。
记住:在机器学习中,简单而通用的设计,往往最具生命力。
现在,你已经能:
相关链接
无论你是想写代码调用 API 的开发者,设计 AI 产品的 PM,评估技术路线的管理者,还是单纯好奇智能本质的思考者,这里都有值得你驻足的内容。 不追 hype,只讲逻辑;不谈玄学,专注可复现的认知。 让我们一起,在这场百年一遇的智能革命中,看得更清,走得更稳 https://cloud.tencent.com/developer/column/107314
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。