

这个话题十分有趣,Jimin Lee的文章就详细解释了这一困惑,我们来看看他是怎么说的。

你是否遇到过这样的困惑:精心搭建的深度学习模型,参数仅几百MB,可一点击训练,用nvidia-smi查看,GPU显存瞬间飙升至数GB,甚至直接报OOM(显存溢出)错误。
其实这完全正常,甚至训练时显存占用和模型权重大小差不多,反而说明模型没在真正训练!核心原因很简单:推理只存权重,训练要算导数,必须留存海量中间数据。今天就拆解小模型狂吃显存的底层逻辑,再给你实用优化方案。
我们常说的“模型大小”,仅指权重和偏置,这只是训练显存占用的冰山一角。训练时,显存主要被四部分瓜分:
下面逐一算清显存账。
深度学习默认用float32(单精度浮点数),1个参数占4字节。
但训练要算梯度,每个权重对应1个梯度,梯度同样占4字节,直接翻倍显存:
这是训练的基础开销,躲不掉。
SGD很轻量,但主流的Adam/AdamW是显存“吞噬怪”——它为每个参数保存动量、方差两个状态,记录历史更新趋势,让训练更稳定。
标准Adam下,1个参数总开销:
10亿参数模型,光权重、梯度、优化器状态就占16GB,是纯权重的4倍!
前三项开销和模型参数成正比,而中间激活值才是显存飙升的关键,它和batch size、序列长度、网络深度强相关。
激活值爆炸的三大诱因:
很多小模型参数量小,却因batch size设太大,激活值占满显存。
训练结束、删除张量后,nvidia-smi仍显显存满格?不是内存泄漏,是PyTorch缓存分配器在“囤货”。
框架向系统申请/释放显存很慢,所以PyTorch用完显存不还给系统,留着下次复用。nvidia-smi显示的是“已预留显存”,而非“实际占用显存”。
想看真实占用,用PyTorch原生接口:
import torch
# VRAM actively holding tensors
allocated = torch.cuda.memory_allocated() / 1e9
# Total VRAM PyTorch is hoarding from the OS (including cache)reserved = torch.cuda.memory_reserved() / 1e9
print(f"Allocated: {allocated:.2f} GB")
print(f"Reserved : {reserved:.2f} GB")
# Value in nvidia-smi ≈ reserved
# Actual tensor usage ≈ allocated 别只会把batch size改到1,这4个现代方法更高效:
不用全用float32,激活值等用float16/bfloat16,显存直接砍半,还能提速。PyTorch代码极简:
from torch.amp import autocast, GradScaler
scaler = GradScaler('cuda')
for batch in dataloader:
optimizer.zero_grad()
# Run forward pass in FP16/BF16
with autocast(device_type="cuda", dtype=torch.float16):
output = model(batch)
loss = criterion(output, target)
# Scaled backward pass in FP32
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update() 不存所有激活值,只存少量“检查点”,反向传播时从检查点重算丢失的中间值。
Hugging Face模型一行启用:
model.gradient_checkpointing_enable()3. FlashAttention:Transformer必开
标准注意力机制显存I/O瓶颈严重,FlashAttention用分块技术优化,显存占用骤降、速度飙升。PyTorch 2.0+自动启用:
import torch.nn.functional as F
# PyTorch 2.0+ - Automatically utilizes optimized kernels like FlashAttention
out = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=True,
)
# To explicitly force and verify FlashAttention is being used:
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
out = F.scaled_dot_product_attention(query, key, value) 单卡不够就多分卡:
最后划重点:
小模型狂吃显存不是bug,是深度学习的计算本质决定的。下次遇OOM,先试混合精度、激活检查点、FlashAttention,别盲目缩Batch!
资料参考:https://medium.com/@jiminlee-ai/why-your-tiny-deep-learning-model-is-hogging-all-your-gpu-vram-85bc58ee5050