大模型显存占用怎么算:从参数量、KV Cache 到训练开销
什么是显存
显存是位于图形卡上的高速内存。它是更大存储子系统的一个组件,有助于确保您的 GPU 能够访问所需的数据并流畅处理和显示图像。
推理显存占用的基本组成
一般情况,显存的计算公式如下:
主要分为两部分,一部分是模型参数所占用的显存,另一部分是kv_cache所占用的显存.
Part1.VRAM for model weights
这一部分是我们将模型参数加载到GPU中所需要的最小显存,是不可避免的硬性成本.
具体计算公式如下:
这个公式很好理解,==模型权重显存占用=参数量 每个参数所占的字节== .
所以我们只需要知道模型的参数量以及每个参数所占的字节即可.
那么第一个问题:如何得知模型每个参数所占的字节大小?
这通常取决于模型的数据类型也叫“精度”。我们可以在模型的config.json文件的torch_dtype字段下找到该参数。
不同精度所占用的字节情况:
- 16-bit (like
bfloat16orfloat16) = 2 bytes - 32-bit (like
float32) = 4 bytes - 8-bit (quantized) = 1 byte
- 4-bit (quantized) = 0.5 bytes
比如Qwen3-VL-8B-Instruct 的config.json 文件内容如下:
1 | { |
可以发现其精度为bfloat16,因此Qwen3-VL-8B-Instruct 每个参数所占用的字节为2Bytes .
好了,弄懂了第一个问题了,剩下的问题就是计算模型的参数量.
那么第二个问题:如何计算模型的参数量?
对于我们的Qwen3-VL-8B-Instruct模型来说,名字中的8B说明它拥有大约80亿个参数。
利用这个近似值我们可以进行非常方便的粗略估算:
所以仅仅加载这个模型就需要一张大概16GB显存的显卡。
其实到这里对于VRAM for model weights的计算已经讲清楚了。一般情况下,对于我们所要使用的模型,我们是清楚其是多少B的,模型的名称一般也都会带有多少B的字眼,我们只需要找到其
config.json知道这个模型的精度后套用公式计算即可。
不过更精确的数字是多少呢?这就需要我们对模型的参数量有一个更精确的计算。
不需要猜测,一般我们也可以从config.json 中精确计算出模型的参数数量。
由于Qwen3-VL-8B-Instruct 是VLM ,所以要分为语言模型部分和视觉编码器两部分.
语言模型部分:
模型参数总数是嵌入层、注意力模块和多层感知机(前馈)模块的参数之和。
我们可以在上面的config.json 中的text_config找到如下字段:
-
vocab_size: 151936 -
hidden_size: 4096 -
intermediate_size: 12288 -
num_hidden_layers: 36 -
num_attention_heads: 32 -
num_key_value_heads: 8 -
tie_word_embeddings: false
1.Embeddings & LM Head
由于 tie_word_embeddings 为 false,因此嵌入层和最终的语言模型头 LM Head)层是相互独立的。
2. 单层参数量:
注意力模块 GQA: Qwen3-VL-8B-Instruct 使用分组查询注意力机制 (Grouped-Query Attention, GQA)。它拥有 32 个查询头 (num_attention_heads),但只有 8 个Key, 和值 Value 头 (num_key_value_heads)。
首先,我们推导出每个头的维度 head_dim:
现在,我们计算注意力模块中的四个权重矩阵:
多层感知机模块 SwiGLU: 该模块基于 hidden_size 和 intermediate_size 使用三个矩阵。
因此:
视觉编码器部分
视觉编码器通常是一个 Vision Transformer (ViT) 架构,参数总数是图像分块嵌入、位置编码、注意力模块、多层感知机(前馈)模块和归一化层的参数之和。
我们可以在上面的 config.json 中的 vision_config 找到如下字段:
hidden_size: 1152intermediate_size: 4304num_heads: 16depth: 27patch_size: 16temporal_patch_size: 2in_channels: 3num_position_embeddings: 2304
1.Patch & Position Embeddings
首先,模型需要将输入的像素转换为向量。Qwen3-VL 包含了时间维度(视频/多帧支持),所以使用 3D 卷积或线性投影来处理 Patch。
每个 Patch 的展平维度为:
将这个 1536 维的像素块投影到 hidden_size,同时加上可学习的位置编码:
2. 单层参数量:
注意力模块 MHA: 视觉编码器使用的是标准多头注意力机制 (Multi-Head Attention, MHA),而不是文本端的 GQA。因为不是 GQA,Q、K、V 的维度完全相同。
首先,我们推导出每个头的维度 head_dim:
现在,我们计算注意力模块中的四个权重矩阵和偏置项:
多层感知机模块 GELU: 该模块基于 hidden_size 和 intermediate_size 使用两个矩阵。
LayerNorm 层: 每个 Block 通常在 Attention 前和 MLP 前各有一个 LayerNorm,每个 LayerNorm 包含缩放(weight)和平移(bias)两个参数,维度均为 hidden_size。
因此:
通过精确计算,Qwen3-VL 的视觉编码器部分(ViT)参数量约为 4.15 亿 (0.415B)。
3.模型总参数量
所以理论上模型总参数量为:
下面是我在服务器上load模型后观察到的实际显存占用:

可以计算相对误差:
可见存在4.5% 的额外显存开销.
Part2.VRAM for KV cache
这一部分是模型在推理生成过程中,为了缓存历史 token 的 Key 和 Value 所需要的显存。它和模型权重不同,不是固定不变的,而是会随着batch_size和seq_length线性增长。
具体计算公式如下:
这个公式也很好理解,==KV Cache显存占用=缓存元素数量 每个元素所占的字节== .
所以我们只需要知道KV Cache的元素数量以及每个元素所占的字节即可.
那么第一个问题:如何得知KV Cache每个元素所占的字节大小?
KV Cache通常使用模型推理时的计算精度进行存储。对于前面提到的Qwen3-VL-8B-Instruct,其config.json中的dtype为bfloat16,因此每个KV Cache元素所占用的字节为2Bytes .
需要注意的是,即使模型权重使用了量化,KV Cache也不一定会跟着量化。除非推理框架显式启用了KV Cache量化,否则一般仍然按照模型的推理精度计算。
那么第二个问题:如何计算KV Cache的元素数量?
KV Cache主要来自语言模型部分的注意力层。对于Qwen3-VL-8B-Instruct,我们可以从上面的config.json中的text_config找到如下字段:
max_position_embeddings: 262144num_hidden_layers: 36num_attention_heads: 32num_key_value_heads: 8head_dim: 128hidden_size: 4096use_cache: true
首先,我们可以再次确认每个注意力头的维度:
Qwen3-VL-8B-Instruct 使用的是 GQA(Grouped-Query Attention),也就是查询头数量和KV头数量并不相同。虽然它有 32 个 Query heads,但是只需要缓存 8 组 Key 和 Value。
因此,单层、单个token需要缓存的元素数量为:
其中的2分别代表Key和Value两个矩阵。
现在我们将其扩展到所有层、所有token以及batch维度:
如果按照单用户、最大上下文长度来估算,也就是batch_size=1,seq_length=max_position_embeddings=262144:
所以在Qwen3-VL-8B-Instruct的最大上下文长度下,单个请求的KV Cache理论上就需要约36GiB显存。
其实到这里对于VRAM for KV cache的计算也已经讲清楚了。一般情况下,我们只需要从
config.json中找到num_hidden_layers、num_key_value_heads、head_dim、max_position_embeddings以及模型精度,然后套用公式即可。实际使用时,如果输入长度没有达到最大上下文,那么KV Cache显存会按照seq_length等比例缩小;如果同时服务多个请求,则会按照batch_size等比例放大。
