推理耗时
推理机制
传统推理方式:
逐 token 生成,无法并行。
过程建模两种方式:
- 矩阵-向量乘法:一个大矩阵(例如
)乘以一个向量,得到另一个向量。 - Attention 计算:利用 KV-cache 进行推理。
瓶颈分析:浮点运算的主要来源
- 矩阵-向量乘法对每个矩阵元素执行一次乘加运算(2 FLOPs)。
- Attention 对每个 key 执行一次乘加,对每个 value 执行一次乘加。
时延计算
计算一个 token 所需要的数据量
在 NVIDIA RTX 4090(1008 GB/s)上,14.2GB (fp16) 需要约 14.1ms 读取,因此可以预期对于位置靠前的 token,每个 token 大约需要 14.1ms(KV-cache 影响可以忽略不计)。如果使用 8bit 权重,需要读取 7.1GB,这需要大约 7.0ms。这些都是理论下限,代表了生成每个 token 的最小可能时间。
参考来源:《LLM inference speed of light》
通俗来说,模型的预测时间可以近似理解为:
其中
这也就是为什么众人都知 CoT 效果好,众人又都不使用 CoT(但是现在 o1、R1 的大模型推理增强还是需要很多 CoT 数据的),因为我们可以几乎下断言“模型的生成速度和生成 token 数量呈正相关”,而 CoT 恰恰又引入了大量的生成 token。
推理 TPS 计算
如何计算 TPS?
部署 LLM 时,每秒生成的 token 数量 TPS(Tokens Per Second)是衡量推理性能的重要指标:
总延迟时间包括两个阶段:
- TTFT(Time To First Token):从输入到生成第一个 token 的延迟时间,主要受 prompt 长度和模型结构影响,也就是在 Prefilling 阶段。
- TPOT(Time Per Output Token):生成每个后续 token 所需的平均时间,也就是在 Decoding 阶段。
总延迟可表示为:
所以 TPS 可以表示为:
TPS 估算方法
-
确定模型参数量:
-
计算 Prefilling 阶段的 FLOPs:
-
计算 Decoding 阶段的 FLOPs:
使用公式:
通过以上分析,我们可以更好地理解推理过程中各个阶段的性能瓶颈,并针对性地进行优化,以提升模型的推理效率。