处理小说、法律文件等长文本是大模型的一个重要应用方向,但也面临速度上的挑战。FlashAttention 作者 Tri Dao 等人提出的「Flash-Decoding」通过充分利用 GPU,可以将大模型的长上下文推理速度提高至 8 倍。
首先,将键 / 值分成更小的块; 使用 FlashAttention 并行计算查询与每个这些分块的注意力,为每行和每个分块额外写入一个标量值:注意力值的 log-sum-exp 最后,通过对所有分块进行归约来计算实际输出,使用 log-sum-exp 来调整每个分块的贡献。
Pytorch:使用纯粹的 PyTorch 基元来运行注意力计算(不使用 FlashAttention); FlashAttention v2; FasterTransformer:使用 FasterTransformer 的注意力内核; Flash-Decoding; 以及一个上限值,该值计算了从内存中读取整个模型和 KV-cache 所需的时间
FlashAttention 包,从 v2.2 开始:https://github.com/Dao-AILab/flash-attention/tree/main xFormers 包(搜索 xformers.ops.memory_efficient_attention),从 0.0.22 开始:调度程序将根据问题的大小自动使用 Flash-Decoding 或 FlashAttention 方法。当这些方法不受支持时,它可以调度到一个高效的 triton 内核,该内核实现了 Flash-Decoding 算法。