标准Attention与Safe softmax

标准Attention

在介绍标准Attention之前,我们先定义几个基本变量:假设 batch_size 等于 1,seq_len 等于 Nemb_size 等于 d。在这个部分,我们只关注Attention的计算部分,忽略dropout、mask等其他计算。下面是标准的Attention计算流程:

S=QKTP=softmax(S)

Pasted image 20250428223029.png

标准Safe softmax

在处理浮点数时,特别是对于 float32bfloat16 类型的数据,当 x89 时,exp(x) 就会变成无穷大(inf),从而导致数据上溢的问题。为了避免数值溢出并保证数值稳定性,在计算时通常会减去最大值,这个过程称为Safe softmax。Safe softmax的计算公式如下:

首先,找出最大值 m

m=max(xi)

然后,计算Safe softmax:

softmax(xi)=eximj=1dexjm

通过这样的处理,可以有效避免数值不稳定的问题,提高计算的精度和稳定性。在实际应用中,Safe softmax是一个非常重要的技巧,尤其是在深度学习模型中涉及到概率分布的计算时。
Pasted image 20250428223037.png