Flash Attention
Flash Attention是一种优化Transformer模型中注意力机制的算法,旨在提高计算效率和内存利用率。自注意力(Self-Attention)的缺点之一是计算和空间复杂度巨大,导致在处理长序列时速度变慢且内存需求巨大。
自注意力的空间复杂度
自注意力模块的空间复杂度为
GPU内存层次
- SRAM(Static RAM):静态随机存取存储器,速度快,访问延迟低,用于GPU的片上缓存
- HBM(High Bandwidth Memory):高带宽内存,用于GPU的片外内存,主要作为全局内存
- DRAM(Dynamic RAM):动态随机存取存储器,高密度、低成本,用于系统的主内存
| 内存类型 | 容量 | 带宽 | 延迟 |
|---|---|---|---|
| SRAM | 20 MB | 19 TB/s | ~20 cycles |
| HBM | 80 GB | 3.35 TB/s | ~300 cycles |
| DRAM | 512 GB | 50 GB/s | ~300 cycles |
传统自注意力
对于传统算法,
传统自注意力计算的内存访问复杂度:
- 第一步把
, 读取出来计算出 ,然后把 存回去,内存访问复杂度 - 第二步把
读取出来计算出 ,然后把 存回去,内存访问复杂度 - 第三步把
, 读取出来计算出 ,然后计算出结果 ,内存访问复杂度
综上所述,整体的内存访问复杂度为
传统自注意力面临两个问题:
- 显存占用多,由于计算过程中存储完整注意力矩阵
和 ,需要 的空间 - HBM读写次数多,需要传输的数据量大
Flash Attention原理
将从输入的
分块计算的详细图解
传统注意力 vs Flash Attention
传统注意力(需要O(N²)显存):
┌─────────────────────────────────────────────┐
│ HBM │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │ Q │ │ K │ │ V │ │ O │ │
│ └──┬──┘ └──┬──┘ └──┬──┘ └──▲──┘ │
│ │ │ │ │ │
│ └────────┴────────┼────────┘ │
│ ↓ │
│ ┌─────────┐ │
│ │ S │ O(N²) │
│ │ ↓ │ │
│ │ P │ O(N²) │
│ └────┬────┘ │
└──────────────┼────────────────────────────┘
↓
[SRAM计算]
Flash Attention(O(N)显存):
┌─────────────────────────────────────────────┐
│ HBM │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │ Q │ │ K │ │ V │ │ O │ │
│ └──┬──┘ └──┬──┘ └──┬──┘ └──▲──┘ │
│ │ │ │ │ │
│ └────────┴────────┼────────┘ │
│ ↓ (分块加载) │
│ ┌─────────┐ │
│ │ Q_i │ SRAM │
│ │ K_j │ │
│ │ V_j │ │
│ │ ↓ │ │
│ │ S_ij │ O(B_c × B_r) │
│ │ ↓ │ │
│ │ P_ij │ │
│ │ ↓ │ │
│ │ O_i │ │
│ └─────────┘ │
└─────────────────────────────────────────────┘分块大小选择
SRAM大小为
其中
Softmax的分块计算
对于分块计算,矩阵乘法和逐点操作(scale,mask,dropout)的分块计算是容易实现的。需要注意的是,在进行softmax分块计算时,需要完整的一行作为输入数据,因为其分母需要对完整一行求和。
当处理完一个分块后,显存中只保留
增量计算过程
- 首先计算一个分块的局部softmax值,然后存储起来
- 当处理完下一个分块时,可以根据此时的新的全局最大值和全局exp求和项来更新旧的softmax值
- 当处理完所有分块后,此时的所有分块的softmax值都是"全局的"
Online Softmax更新公式:
Flash Attention伪代码
输入:在HBM中的
- 计算
、 和 、 的分块大小 和 - 初始化最终输出
, 维向量 和 - 将
、 作为外层循环, 、 作为内层循环 - 分块计算Attention score
- 计算局部
和 ,并利用其更新全局 和 - 更新输出
Flash Attention 2改进
Flash Attention 2在Flash Attention基础上进行了以下优化:
- 减少非矩阵乘法运算:将softmax的缩放和加法操作与矩阵乘法融合
- 优化并行化策略:在序列长度维度上并行,而非batch×head维度
- 优化工作分区:减少warp间的同步和通信
性能对比:
| 特性 | Flash Attention | Flash Attention 2 |
|---|---|---|
| 计算效率 | 50-73% TFLOPS | 65-73% TFLOPS |
| 内存访问 | ||
| 并行维度 | batch × head | sequence length |
| 反向传播 | 重计算 | 优化重计算 |
Flash Attention 3改进
Flash Attention 3针对Hopper架构(H100)进行了专门优化:
- 异步执行:利用Tensor Memory Accelerator(TMA)实现异步数据加载
- Warp specialization:将生产者和消费者warp分离
- FP8支持:原生支持FP8量化注意力计算
# Flash Attention 3使用示例
from flash_attn.flash_attn_interface import flash_attn_func
# 标准FP16/BF16
output = flash_attn_func(q, k, v, causal=True)
# FP8量化(H100+)
output = flash_attn_func(q, k, v, causal=True,
out_dtype=torch.float8_e4m3fn)IO复杂度分析
Flash Attention的核心优化是减少HBM访问次数:
传统注意力IO复杂度:
Flash Attention IO复杂度:
其中
Flash Attention总结
为什么加快了计算? 降低了耗时的HBM(显存)访问次数。采用Tiling技术分块从HBM加载数据到SRAM缓存进行融合计算。
为什么节省了内存? 不再对中间矩阵
, 进行存储。在反向的时候重新计算来计算梯度。 为什么是精准注意力? 算法流程只是分块计算,无近似操作。
适用场景:
- 长序列训练(序列长度 > 2048)
- 大批量训练
- 内存受限场景
限制:
- 需要GPU支持(A100/H100效果最佳)
- 分块大小受SRAM限制
- 实现复杂度较高