vLLM 是本号的常客,SGLang 写的不多,主要是我用它也不多,之前偶尔写过 SGLang 怎么用、跑什么模型,也比较浅
最近,吴恩达 DeepLearning 上最新更新了 SGLang 底层原理短课
我只学了自己感兴趣的 L2 和 L3,本文也算是学习笔记

这门课的名字叫 Efficient Inference with SGLang,由 SGLang 的作者团队和 DeepLearning.AI 联合出品
说实话,大部分人用 vLLM、SGLang 部署模型,都是 pip install 然后一行命令启动服务,能跑就行
但你有没有想过:
这门课就是来回答这些问题的
不讲废话,直接手写代码,从 Attention 公式写到 KV Cache,再到 Radix Tree,一步步把原理拆给你看,配上我自己的理解。
大语言模型生成文本是一个 token 一个 token 蹦出来的(自回归生成)
每生成一个新 token,模型都要跑一遍 Attention 机制,用当前 token 的 Query 去和所有之前 token 的 Key 做点积,算出注意力权重,再加权求和所有 Value
关键洞察来了:每个 token 的 Key 和 Value 一旦算出来就不会变
但是如果不缓存,每生成一个新 token,模型就要把之前所有 token 的 K 和 V 重新算一遍
生成 n 个 token,总计算量是 O(n²),这就是推理慢的根本原因。
课程用的是 DeepSeek-R1-Distill-Qwen 1.5B 模型,虽然小,但 Attention 架构和 70B 模型完全一样——Grouped Query Attention(GQA),所有原理直接适用于大模型。
先看核心公式:Attention(Q, K, V) = softmax(Q·K^T / sqrt(d_k)) · V
把这个公式翻译成了 Python 代码:
def_attention_impl(q, k, v, scale, mask):
# 核心:softmax(Q @ K^T / sqrt(d_k)) @ V
# Q @ K^T —— 算每对 (query, key) 的注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
# 因果遮罩——未来位置变成 -inf,softmax 后变 0
scores = scores.masked_fill(~mask, float("-inf"))
# 归一化为注意力权重
probs = torch.softmax(scores, dim=-1)
# 按权重对 Value 加权求和
return torch.matmul(probs, v)
然后还有一个处理 GQA(Grouped Query Attention)的包装函数
GQA 是 DeepSeek、Llama 等现代模型的标配——多个 Query Head 共享同一组 K/V Head,比如 64 个 Query Head 共用 8 个 K/V Head,KV Cache 直接缩小 8 倍,精度损失很小
defsimple_causal_attention(query, key, value, **kwargs):
# 支持 GQA 的因果注意力
Dh = query.shape[-1]
scale = 1.0 / (Dh ** 0.5)
# GQA: 多个 Query Head 共享一组 K/V Head
gqa_group_size = query.shape[1] // key.shape[1]
key = key.repeat_interleave(gqa_group_size, dim=1)
value = value.repeat_interleave(gqa_group_size, dim=1)
# ...后续和标准 Attention 一样
课程做了一件很有意思的事:用 monkey-patch 把 PyTorch 内置的 F.scaled_dot_product_attention 换成自己写的版本,这样模型每一层都跑自己的代码
跑出来的结果和原版完全一致——token 级别一模一样,只是慢很多(纯 Python vs CUDA 内核嘛)
这是课程里最直观的实验
不用 KV Cache(朴素方式):
# 每一步都从头喂入整个序列
text_no_cache = auto_regressive_decode(
tiny_llm, input_text,
max_new_tokens, temperature=0.0
)
# 总计算量:sum(9, 10, 11, ..., 24) = 264 次 token 计算
9 个 prompt token,生成 16 个新 token。每一步都要把之前所有 token 重新过一遍模型
第 1 步处理 9 个 token,第 2 步处理 10 个……第 16 步处理 24 个,加起来 264 次 token 计算。
用 KV Cache(优化方式):
# Prefill 阶段一次性处理所有 prompt token,存下 K/V
# Decode 阶段每步只处理 1 个新 token
text_kv_cache = auto_regressive_decode_with_kv_cache(
tiny_llm, input_text,
max_new_tokens, temperature=0.0
)
# 总计算量:9 (prefill) + 15 (每步 1 个) = 24 次 token 计算
同样 9 个 prompt token + 16 个新 token
Prefill 阶段一口气处理 9 个 token,把所有 K/V 存起来
之后每步只要处理 1 个新 token,从缓存里读之前的 K/V 就行,总计算量 24 次。
264 vs 24,计算量少了 11 倍。
实测在 1.5B 模型上大约 2 倍实际加速,序列越长差距越大——1000 个 token 的输出,没有 KV Cache 需要 50 万次计算,有了 KV Cache 只需要约 1000 次
这就是"实用"和"不可用"之间的距离
而且最关键的——输出完全一样,一个 token 都不差。数学上严格等价,只是不重复做已经做过的工作
两个阶段:
一句话:*算一次,存起来,反复用。
KV Cache 解决了单个请求内的重复计算问题,但有一个更扎心的问题:
两个用户问了同一篇文档的不同问题,模型对同一篇文档的 KV 算了两遍,这是不是浪费?
答案是:当然是
在 RAG 场景里,一个 prompt 可能 90% 是文档内容(几百个 token),只有 10% 是用户问题(几十个 token)
100 个用户对同一篇文档提了 100 个问题,那就是 100 次完全相同的 KV 计算
几万个 token 白算了
SGLang 的 RadixAttention 就是来解决这个问题的
简单说,Radix Tree 就是一个按 token 序列索引 KV Cache 的树形数据结构。
课程里实现了一个简化版的 FlatRadixTree:
classCacheEntry:
# 把 token 序列和对应的 KV Cache 配对存储
def__init__(self, token_ids, kv_cache):
self.token_ids = list(token_ids)
self.kv_cache = kv_cache
classFlatRadixTree:
# 简化版基数树(线性扫描,便于理解)
def__init__(self):
self.entries = []
definsert(self, token_ids, kv_cache):
self.entries.append(CacheEntry(token_ids, kv_cache))
defsearch(self, token_ids):
# 找最长匹配前缀
best_match, best_len = None, 0
for entry inself.entries:
match_len = 0
for a, b inzip(entry.token_ids, token_ids):
if a != b:
break
match_len += 1
if match_len > best_len:
best_len = match_len
best_match = entry
return best_match
生产级 SGLang 用的是 O(log n) 的查找,课程用线性扫描,原理一模一样,就是方便你看懂
每个请求进来,经过四步:
radix.search(token_ids) 在树中查找最长匹配前缀radix.insert() 把新算出来的 KV 存回树里用代码看更清楚:
radix = FlatRadixTree() # 空树
for question in article_questions:
prompt = construct_prompt(article, question)
token_ids = tiny_llm.tokenize(prompt)
# Step 1: 搜索最长匹配前缀
prefix_cache = radix.search(token_ids)
# Steps 2 & 3: 复用缓存的 KV,只计算后缀
output, cached_req = tiny_llm.generate_with_prefix_cache(
prompt,
max_new_tokens=16,
prefix_cache=prefix_cache,
temperature=0,
)
# Step 4: 存回树里,后续请求受益
radix.insert(cached_req.token_ids, cached_req.kv_cache)
作为实验,准备了两篇 SGLang 技术文章(各约 2000 字符),每篇 6 个问题
不用 prefix caching:6 个问题每个都要从头处理整个文档 + 问题,耗时基本一样
用 RadixAttention:Q1 是冷启动(cache miss),跟不缓存一样慢。但 Q2 到 Q6 直接命中缓存——文档部分(约 97% 的 token)全部跳过,只处理问题部分的那几十个 token
实测大约 2 倍加速,总共省了约 20 秒
你可能觉得 2 倍不够震撼?那是因为实验用的文章比较短
在生产环境里,一个 RAG prompt 可能有 2000 个 token 的文档 + 10 个 token 的问题。如果 90% 的 token 都命中缓存,只需要计算 10%,那就是 10 倍 Prefill 加速
真实生产环境不会乖乖按文档分组——用户请求是随机到达的
第三个实验:两篇文章的 12 个问题随机打乱顺序
# 12 个 prompt,随机混合两篇文章的问题
random.seed(42)
random.shuffle(all_prompts)
radix_multi = FlatRadixTree()
for tag, article_name, prompt in all_prompts:
token_ids = tiny_llm.tokenize(prompt)
prefix_cache = radix_multi.search(token_ids)
# ... 生成 + 存储
结果:只有 2 次 cache miss(每篇文章的第一次请求),剩下 10 次全部命中
命中率 83%
这就是 Radix Tree 和普通单条缓存的区别——树可以同时维护多个分支,切换文档不会把另一篇的缓存踢掉
每个请求独立匹配自己的前缀,互不干扰
随着流量增长,2 次冷启动被上千个请求分摊,平均延迟趋近于 cache hit 的延迟
这就是为什么 SGLang 在生产环境里这么快
场景 | 缓存了什么 | 典型加速 |
|---|---|---|
RAG 系统 | 文档上下文 | 5-10x |
聊天机器人 | System Prompt + 对话历史 | 3-5x |
Few-Shot 学习 | 示例样本 | 4-8x |
代码生成 | 仓库上下文 | 3-6x |
核心逻辑都一样:共享前缀越长,加速越大。
L2 和 L3 解决的问题其实是一个递进关系:
两层叠加,效果是乘法级的——单请求内不浪费,跨请求也不浪费
代码模式简单到令人发指——三行搞定:
prefix_cache = radix.search(token_ids) # 搜
output = model.generate(prompt, prefix_cache=prefix_cache) # 用
radix.insert(token_ids, kv_cache) # 存
说实话,之前用 SGLang 就是跑跑 benchmark,知道它快,但不知道为什么快
这门课最大的价值是让你亲手把 KV Cache 和 Radix Tree 写一遍——写完之后,你看 SGLang 的源码就不再是天书了
推荐给两类人:
课程后面还有 L4(SGLang Diffusion,把 caching 思想用到图像生成)和 L5(SGLang Router,多引擎路由),等我学完再写。
#SGLang #KVCache #推理优化 #DeepLearningAI #大模型推理
制作不易,如果这篇文章觉得对你有用,可否点个关注。给我个三连击:点赞、转发和在看。若可以再给我加个🌟,谢谢你看我的文章,我们下篇再见!