
FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling

ArXiv: 2603.05451

Pitch

As AI hardware evolves, computational resources scale unevenly. On NVIDIA's new Blackwell GPUs, tensor core throughput roughly doubled while shared memory bandwidth and exponential unit throughput stayed flat, creating new bottlenecks that invalidate previous optimization strategies. FlashAttention-4 addresses this 'asymmetric scaling' through algorithm-kernel co-design, introducing software-emulated exponentials and redesigned pipelines that together achieve up to 1.3× speedup over cuDNN and 71% theoretical utilization on B200 GPUs.


1. Executive Summary

FlashAttention-4 addresses the asymmetric hardware scaling problem on NVIDIA Blackwell GPUs (B200/GB200), where tensor core throughput doubles compared to Hopper while shared memory bandwidth and exponential units remain unchanged, creating new performance bottlenecks in attention computation. The paper introduces algorithmic innovations including software-emulated exponentials, conditional softmax rescaling, and 2-CTA MMA modes that collectively achieve up to 1.3× speedup over cuDNN 9.13 and 2.7× over Triton, reaching 1613 TFLOPs/s (71% theoretical utilization) on B200 GPUs. Beyond performance, FlashAttention-4 is implemented entirely in CuTe-DSL embedded in Python, achieving 20-30× faster compile times than prior C++ template-based approaches while maintaining full low-level hardware control.

2. Context and Motivation

The Core Problem: Asymmetric Hardware Scaling Creates New Bottlenecks

The fundamental challenge this paper tackles is that GPU hardware evolution is asymmetric—not all functional units scale at the same rate. The AI industry has rapidly transitioned from NVIDIA Hopper H100 GPUs to Blackwell B200/GB200 systems, which exhibit fundamentally different performance characteristics:

  • Tensor core throughput more than doubled: 2.25 PFLOPS on Blackwell vs. roughly 1 PFLOPS on Hopper for dense FP16/BF16
  • Shared memory bandwidth unchanged: 128 bytes/clock/SM on both architectures
  • Exponential unit throughput unchanged: 16 ops/clock/SM on B200 (same as Hopper)

This asymmetry means that tensor cores are now "too fast" relative to other hardware units. Prior attention kernels optimized for Hopper's balanced architecture now hit bottlenecks in shared memory traffic and exponential operations rather than matrix multiplication compute.

Why This Matters: Attention Is the Computational Backbone of Modern AI

The Transformer architecture underpins nearly all modern AI applications—large language models, vision systems, and multimodal models. Within Transformers, attention is the primary computational bottleneck, exhibiting quadratic scaling in sequence length. Efficient attention directly enables:

  • Long-context reasoning: Processing multiple documents, entire codebases, or high-resolution videos
  • Cost reduction: Attention typically dominates training and inference compute for large models
  • New capabilities: Longer contexts unlock applications previously infeasible (e.g., reasoning over entire code repositories)

The paper's roofline analysis (Sections 3.1.1 and 3.2.1) shows that on Blackwell, shared memory traffic and exponential operations now take 25-60% more cycles than MMA compute. This is a qualitative shift from prior generations, where tensor core throughput was the primary constraint.

Where Existing Approaches Fall Short

FlashAttention-3 targets Hopper, not Blackwell. FlashAttention-3 (Shah et al., 2024) optimized attention for H100 GPUs through asynchronous execution and warp specialization. However:

  • It primarily targets Hopper architecture and cannot leverage Blackwell's new features
  • Hopper MMA instructions lack forward compatibility—simply porting the code doesn't work
  • Hopper's accumulator tiles (64×128) are smaller than Blackwell's (128×128), requiring algorithmic redesign

Hopper-optimized pipelines leave performance on the table. FlashAttention-3 stores accumulators in registers with an interleaved pattern across threads. Blackwell's tensor cores write directly to tensor memory (TMEM)—a fundamentally different execution model that requires new pipeline designs to exploit.

Non-MMA resources emerge as bottlenecks. Prior work assumed tensor core throughput was the limiting factor. This paper demonstrates that on Blackwell:

"Shared memory traffic and exponential operations now dominate execution time, exceeding MMA compute by 25-60%."

This requires explicitly optimizing for non-matmul units—something prior attention kernels did not systematically address.

How This Paper Positions Itself

The paper frames itself as algorithm-hardware co-design rather than straightforward optimization. Instead of treating the GPU as a uniform compute resource, the authors:

  1. Identify shifting bottlenecks through roofline analysis of each hardware unit's throughput
  2. Design algorithms to mitigate non-MMA bottlenecks: software-emulated exponentials (using FMA units in parallel with MUFU), conditional softmax rescaling, and 2-CTA MMA modes to reduce shared memory traffic
  3. Exploit new Blackwell features: tensor memory (TMEM) for accumulator storage, 128×128 MMA tiles, and fully asynchronous tensor core operations

The paper also positions itself against low-precision attention approaches (SageAttention series) by noting these primarily target consumer GPUs, while most AI compute runs on datacenter GPUs where Blackwell is deployed.


3. Technical Approach

3.1 Reader Orientation

FlashAttention-4 is a GPU kernel implementation for the attention mechanism that redesigns both the forward and backward pass algorithms to exploit Blackwell GPU hardware characteristics—specifically addressing the mismatch between tensor core throughput and other functional units.

3.2 Big-Picture Architecture

The system consists of two main computational kernels (forward pass and backward pass) with specialized pipeline designs, supported by an exponential-emulation module, a scheduler, and the CuTe-DSL implementation layer:

  1. Forward Pass Kernel: Processes query-key-value inputs to compute attention outputs. Uses a ping-pong pipeline with two warpgroups alternating between MMA operations and softmax computation, plus a correction warpgroup for output rescaling decoupled from the critical path.

  2. Backward Pass Kernel: Computes gradients for training. Uses 2-CTA MMA mode where a CTA pair cooperatively executes MMAs, reducing shared memory traffic and halving global atomic reductions for dQ accumulation.

  3. Software-Emulated Exponential Module: A polynomial approximation of \(2^x\) that runs on FMA units in parallel with hardware MUFU.EX2 operations, increasing effective exponential throughput.

  4. Scheduler Module: Implements longest-processing-time-first (LPT) scheduling for load balancing across SMs, particularly important for causal masking and variable sequence length scenarios.

  5. CuTe-DSL Implementation Layer: The entire kernel is written in Python-embedded CuTe-DSL, compiled via JIT to PTX then SASS, enabling 20-30× faster compile times than C++ templates.

3.3 Roadmap for the Deep Dive

The explanation proceeds in the following order:

  1. Roofline analysis methodology: How the authors identify bottlenecks through cycle-level modeling of hardware unit throughput
  2. Forward pass innovations: The redesigned pipeline, exponential emulation, and conditional softmax rescaling
  3. Backward pass innovations: 2-CTA MMA mode, TMEM usage, and reduced atomic operations
  4. Scheduling optimizations: LPT scheduling for causal masking and variable sequence lengths
  5. Implementation framework: CuTe-DSL and the productivity gains from Python-embedded kernel development

3.4 Detailed Technical Breakdown

This is an algorithm and systems paper focused on GPU kernel optimization. The core idea is that hardware scaling asymmetry requires fundamentally rethinking which operations to optimize.


Roofline Analysis: Identifying the Bottlenecks

The paper develops a cycle-level analytical cost model for each hardware resource to identify bottlenecks. This analysis precedes the algorithm design and directly motivates all optimization decisions.

Throughput specifications for Blackwell B200:

Hardware unit | Throughput (per SM) | Notes
Tensor cores (BF16 MMA) | 8192 FLOPs/clock | Derived from 2.25 PFLOPS ÷ 1850 MHz ÷ 148 SMs (≈ 8217, rounded to 8192)
Exponential unit (MUFU) | 16 ops/clock | Unchanged from Hopper
Shared memory read | 128 bytes/clock | Unchanged from Hopper

The analysis parameterizes tile dimensions: \(M\) is the sequence length tile dimension for Q, \(N\) is the sequence length tile dimension for K/V, and \(d\) is the head dimension.

Forward pass cost model:

MMA compute time (two MMAs per iteration: \(QK^\top\) and \(PV\)):

\[T_{MMA} = \frac{4MN d}{8192} \text{ cycles}\]

The factor of 4 accounts for 2 MMAs × 2 FLOPs per MMA (multiply and accumulate).

Shared memory traffic (which differs between SS MMAs, where both operands come from shared memory, and TS MMAs, where operand A comes from tensor memory):

\[T_{smem} = \frac{3MN d}{8192} \text{ cycles}\]

This formula assumes \(M, N, d\) are multiples of 128. The derivation accounts for:
  • \(QK^\top\) (SS): \(\lceil M/128 \rceil \times \lceil N/128 \rceil \times 256d\) elements loaded from shared memory
  • \(PV\) (TS): \(\lceil M/128 \rceil \times \lceil d/128 \rceil \times 128N\) elements loaded from shared memory
Each element is 2 bytes (BF16) and shared memory delivers 128 bytes/clock, i.e., 64 elements per clock, so the \(3MNd/128\) elements take \(3MNd/8192\) cycles.

Exponential unit time:

\[T_{exp} = \frac{MN}{16} \text{ cycles}\]

Table 1 summarizes results for two tile configurations:

Resource | \(128^3\) (cycles) | \(256 \times 128^2\) (cycles)
MMA compute | 1024 | 2048
Shared memory | 768 | 1536
Exponential unit | 1024 | 2048

In both configurations, MMA compute and the exponential unit tie (1024 and 2048 cycles), while shared memory needs 25% fewer cycles. Since the exponential unit alone already matches MMA time, exponential throughput is the first non-MMA resource to optimize.
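The per-iteration numbers in Table 1 follow directly from the three throughput figures. The Python sketch below recomputes them; the function and constant names are ours, not the paper's, and it assumes BF16 operands (2 bytes each) with \(M, N, d\) multiples of 128, as the derivation does.

```python
import math

# Per-SM throughputs on B200, taken from the throughput table above.
MMA_FLOPS_PER_CLK = 8192   # BF16 tensor-core FLOPs per clock
SMEM_BYTES_PER_CLK = 128   # shared-memory read bandwidth in bytes per clock
EXP_PER_CLK = 16           # MUFU exponentials per clock
BYTES_PER_ELEM = 2         # BF16 operands

def forward_cycles(M, N, d):
    """Cycle estimates for one forward iteration (QK^T followed by PV)."""
    mma = 4 * M * N * d / MMA_FLOPS_PER_CLK            # 2 MMAs x 2 FLOPs per multiply-add
    # QK^T (SS): Q and K tiles both come from shared memory, 256*d elements per 128x128 tile.
    qk_elems = math.ceil(M / 128) * math.ceil(N / 128) * 256 * d
    # PV (TS): P is read from TMEM, so only the V tile (128*N elements) comes from shared memory.
    pv_elems = math.ceil(M / 128) * math.ceil(d / 128) * 128 * N
    smem = (qk_elems + pv_elems) * BYTES_PER_ELEM / SMEM_BYTES_PER_CLK
    exp = M * N / EXP_PER_CLK                           # one exponential per score
    return {"mma": mma, "smem": smem, "exp": exp}
```

`forward_cycles(128, 128, 128)` returns 1024, 768, and 1024 cycles for MMA, shared memory, and exponentials, matching the first column of Table 1.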

Backward pass cost model:

The backward pass performs five MMA operations per iteration (recomputing S, and computing dV, dP, dQ, dK). The total MMA time:

\[T_{MMA} = \frac{10MN d}{8192} \text{ cycles}\]

Shared memory traffic is higher due to multiple operand loads and intermediate writes:

\[T_{smem} = \frac{4M d + 3N d + MN}{64} + \frac{MN}{64} + \frac{M d}{16} \text{ cycles}\]

For \(M = N = d = 128\): shared memory totals 3328 cycles, exceeding MMA compute at 2560 cycles by ~30%. This identifies shared memory bandwidth as the primary backward pass bottleneck.
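The backward-pass figures can be checked the same way. This sketch (hypothetical function name) evaluates the two formulas above for the 1-CTA kernel; the three shared memory terms line up with the 1-CTA rows of Table 3.

```python
MMA_FLOPS_PER_CLK = 8192  # BF16 tensor-core FLOPs per clock per SM on B200

def backward_cycles_1cta(M, N, d):
    """MMA and shared-memory cycle estimates for one 1-CTA backward iteration."""
    mma = 10 * M * N * d / MMA_FLOPS_PER_CLK               # five MMAs x 2 FLOPs per multiply-add
    smem_operands = (4 * M * d + 3 * N * d + M * N) / 64   # MMA operand loads (64 BF16 elements/clock)
    smem_ds_write = M * N / 64                             # writing dS to shared memory
    smem_dq = M * d / 16                                   # dQ write + read
    return mma, smem_operands + smem_ds_write + smem_dq

mma, smem = backward_cycles_1cta(128, 128, 128)
print(f"MMA {mma:.0f} cycles, shared memory {smem:.0f} cycles, ratio {smem / mma:.2f}")
```

For \(M = N = d = 128\) it reports 2560 MMA cycles and 3328 shared memory cycles, a ratio of 1.3.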


Forward Pass Pipeline Design

The forward pass computes:

\[S = \alpha QK^\top \in \mathbb{R}^{N \times N}\]

\[P = \text{softmax}(S) \in \mathbb{R}^{N \times N}\]

\[O = PV \in \mathbb{R}^{N \times d}\]

where \(\alpha = 1/\sqrt{d}\) is the scaling factor; in these three equations \(N\) denotes the full sequence length rather than a tile dimension.

Key architectural differences from Hopper:

  1. Tensor Memory (TMEM): Blackwell introduces 256 KB of TMEM per SM specifically for tensor core outputs. Unlike registers, TMEM is warp-synchronous and tightly coupled with tensor cores, allowing MMA outputs to write directly without consuming register file bandwidth.

  2. Larger accumulator tiles: Blackwell MMA tiles are 128×128 (double Hopper's 64×128 area).

  3. Fully asynchronous MMAs: Blackwell MMAs write to TMEM asynchronously, enabling better overlap between computation and other operations.

Ping-pong pipeline structure:

The pipeline processes two output tiles per thread block, alternating between:
  • One tile's tensor core operations: the MMAs for \(QK^\top\) and \(PV\)
  • The other tile's softmax computation: row-wise max, exponentials, and normalization

This design maximizes overlap—the critical observation is that softmax and MMA can proceed concurrently if properly staged.

Thread organization:

Two warpgroups of 128 threads each, with each thread processing an entire row. This eliminates inter-warp shuffles for row-max reduction (each thread owns a complete row). The paper explicitly notes:

"This eliminates the need for inter-warp shuffles to reduce the row max, and for multiple statistics registers per thread."

A correction warpgroup handles output rescaling separately from the softmax warpgroups, decoupling this operation from the critical path—a new feature compared to FlashAttention-3.

Tensor memory partitioning:

The pipeline must keep two output (O) tiles resident in TMEM. TMEM holds 128 lanes × 512 columns of 32-bit cells (256 KB), i.e., four 128×128 FP32 accumulator tiles, so two O tiles at head dimension 128 consume half of it. For the remaining half, two partitioning options exist:
  • One tile of S and two tiles of P
  • Two tiles of S that overlap with P

The paper chooses the latter because it allows immediate computation of two S tiles at pipeline start, improving throughput.

Register pressure management:

A practical challenge: each thread must hold an entire 128-element row in registers:
  • 128 registers for the S row (one FP32 score per register)
  • 64 registers for the P row (two BF16 probabilities packed per 32-bit register), plus temporary registers

To reduce pressure, the implementation stages the P storage: the first three quarters trigger MMA operations, while the last quarter is stored separately.


Software-Emulated Exponential Function

The exponential bottleneck:

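The multi-function unit (MUFU) computes exponentials at 16 ops/clock/SM, 512× fewer than the tensor cores' 8192 FLOPs/clock/SM. Softmax requires \(M \times N\) exponential evaluations per tile, while the two MMAs perform \(4MNd\) FLOPs, i.e., \(4d\) FLOPs per exponential. For \(d = 128\) that ratio is exactly 512, so MUFU time equals MMA time (the tie in Table 1) and exponentials become a critical bottleneck.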

Polynomial approximation approach:

The key insight is to decompose the exponential using range reduction:

\[2^x = 2^{\lfloor x \rfloor} \cdot 2^{x - \lfloor x \rfloor}\]

where \(\lfloor x \rfloor\) is the integer part and \(x - \lfloor x \rfloor \in [0, 1)\) is the fractional part.

Integer part computation: Since IEEE 754 floating-point representation stores the exponent field as a power of two, computing \(2^{\lfloor x \rfloor}\) is a shift-and-add operation on exponent bits using integer ALU instructions.

Fractional part approximation: For \(x_{frac} \in [0, 1)\), approximate using a polynomial:

\[2^{x_{frac}} \approx \sum_{i=0}^{n} p_i \cdot x_{frac}^i\]

Coefficients are optimized using the Sollya software package to minimize relative error. Evaluation uses Horner's method with FMA instructions.

Complete algorithm:

  1. Clamp \(x \geq -127\) to avoid underflow
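  2. Compute \(\lfloor x \rfloor\) in round-toward-negative-infinity mode: adding \(2^{23} + 2^{22}\) pushes the fractional bits out of the FP32 mantissa (they are rounded down and discarded), and subtracting the same constant recovers \(\lfloor x \rfloor\)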
  3. Compute fractional part: \(x_{frac} = x - \lfloor x \rfloor\)
  4. Evaluate polynomial for \(2^{x_{frac}}\)
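  5. Combine: shift \(\lfloor x \rfloor\) left by 23 bits (into the FP32 exponent field) and add it as an integer to the bit pattern of \(2^{x_{frac}}\), which multiplies the result by \(2^{\lfloor x \rfloor}\)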
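The five steps can be traced in a NumPy sketch. It is a CPU illustration of the bit-level arithmetic, not the kernel code, and it makes two substitutions that we label as assumptions: the degree-3 coefficients come from a least-squares fit (`np.polyfit`) rather than Sollya's minimax optimization, so the error is somewhat larger than in Table 2, and step 2 uses `np.floor` because NumPy cannot switch the FP32 rounding mode that the magic-constant trick relies on. The name `exp2_emulated` is hypothetical.

```python
import numpy as np

# Degree-3 coefficients for 2**f on [0, 1). Assumption: a least-squares fit
# stands in for the Sollya minimax coefficients used in the paper.
_f = np.linspace(0.0, 1.0, 4097)
COEFFS = np.polyfit(_f, np.exp2(_f), 3).astype(np.float32)  # highest degree first

def exp2_emulated(x):
    """Approximate 2**x in float32 using a polynomial plus exponent-bit arithmetic."""
    x = np.array(x, dtype=np.float32, ndmin=1)
    x = np.maximum(x, np.float32(-127.0))   # 1. clamp to avoid underflow
    j = np.floor(x)                          # 2. integer part (GPU: magic-constant round-down trick)
    f = x - j                                # 3. fractional part in [0, 1)
    p = np.full_like(f, COEFFS[0])           # 4. Horner's method: one multiply-add per degree
    for c in COEFFS[1:]:
        p = p * f + c
    # 5. Add floor(x) to the exponent field: bit 23 is the lowest exponent bit of float32.
    bits = p.view(np.int32) + j.astype(np.int32) * np.int32(1 << 23)
    return bits.view(np.float32)
```

Comparing `exp2_emulated(x)` against `np.exp2` on float64 inputs gives maximum relative errors on the order of \(10^{-4}\), well below the ~\(3.9 \times 10^{-3}\) BF16 rounding error in Table 2.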

Partial emulation strategy:

Using emulation for all exponentials would increase register pressure (coefficients, intermediate values) and could cause spills negating throughput gains. The paper applies emulation to only 10-25% of entries per softmax row, with the rest computed via hardware MUFU.EX2. The exact fraction is tuned empirically based on the MMA/exponential throughput ratio.
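A back-of-the-envelope model shows why a modest emulated fraction can be enough. The sketch below (hypothetical function name) reuses the roofline throughputs and assumes, optimistically, that emulated exponentials cost no MUFU cycles and that their FMA work hides behind other pipes; it ignores the register-pressure limit the paper cites.

```python
def mufu_to_mma_ratio(M, N, d, emulated_fraction):
    """Ratio of MUFU cycles to MMA cycles when a fraction of exponentials is emulated.

    Assumption (hypothetical): emulated exponentials use no MUFU cycles and their
    FMA work overlaps completely with other units.
    """
    mma_cycles = 4 * M * N * d / 8192                     # forward MMA time from the roofline
    mufu_cycles = (1.0 - emulated_fraction) * M * N / 16  # remaining hardware exponentials
    return mufu_cycles / mma_cycles

for frac in (0.0, 0.10, 0.25):
    print(f"{frac:.0%} emulated -> MUFU/MMA = {mufu_to_mma_ratio(128, 128, 128, frac):.2f}")
```

The ratio simplifies to \((1 - f) \cdot 128 / d\) for emulated fraction \(f\): at \(d = 128\), emulating 10-25% of entries brings MUFU time to 75-90% of MMA time, while at \(d = 64\) the exponential unit starts at twice the MMA time.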

Numerical accuracy (Table 2):

Against FP64 reference on 4M random inputs in [0, 1):

Method Max rel error (FP32) Max rel error (after BF16)
Hardware MUFU.EX2 \(1.41 \times 10^{-7}\) \(3.89 \times 10^{-3}\)
Degree 3 polynomial \(8.77 \times 10^{-5}\) \(3.90 \times 10^{-3}\)
Degree 5 polynomial \(1.44 \times 10^{-7}\) \(3.89 \times 10^{-3}\)

The key finding: BF16 quantization error dominates (~3.9×10⁻³), making all polynomial degrees ≥3 essentially equivalent to hardware accuracy after rounding to BF16. A degree-3 polynomial matches hardware to within 1 BF16 ULP on 99% of inputs.


Conditional Softmax Rescaling

Online softmax background:

FlashAttention processes attention in blocks, maintaining running statistics for numerical stability:

\[m_j = \max(m_{j-1}, \text{rowmax}(S_j))\]

\[\ell_j = e^{m_{j-1} - m_j}\ell_{j-1} + \text{rowsum}(e^{S_j - m_j})\]

The output update rescales previous results:

\[O_j = e^{m_{j-1} - m_j}O_{j-1} + e^{S_j - m_j}V_j\]

Two key observations:

  1. Rescaling is only necessary when \(m_j > m_{j-1}\) (new larger values found)
  2. We can tolerate "slack"—only rescale when \(m_j - m_{j-1} > \tau\)

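The threshold \(\tau = \log_2(256) = 8.0\) corresponds to a rescaling factor of \(2^8 = 256\); the kernel works in base 2 (exponentials are computed as \(2^x\) via MUFU.EX2), so \(\tau\) is measured in powers of two.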

Modified algorithm:

\[O_j = \begin{cases} e^{m_{j-1} - m_j}O_{j-1} + e^{S_j - m_j}V_j & \text{if } m_j - m_{j-1} > \tau \\ O_{j-1} + e^{S_j - m_{j-1}}V_j & \text{otherwise} \end{cases}\]

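When the threshold is not exceeded, keep \(m_{j-1}\) as the reference maximum and update \(\ell\) against the same stale reference. Correctness is maintained because \(O\) and \(\ell\) are always accumulated relative to the same reference value, so any skipped factor cancels in the final normalization by \(\ell_{final}\):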

\[\text{Output} = \frac{1}{\ell_{final}}O_{final}\]

This significantly reduces vector multiplications while maintaining numerical accuracy. To avoid warp divergence, rescaling occurs when any thread in the warp needs it.
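The conditional update can be simulated for a single query row. This NumPy sketch is our illustration, not the kernel: it works in base 2 (scores pre-multiplied by \(\log_2 e\)), processes keys in blocks of a hypothetical size 64, and makes the rescale decision per row rather than per warp. Because \(O\) and \(\ell\) always share the same reference maximum, the result matches an exact softmax.

```python
import numpy as np

def online_attention_row(s, v, block=64, tau=8.0):
    """One query row of attention with threshold-based (conditional) rescaling.

    s: (n,) scores already in the log2 domain (multiplied by log2(e)).
    v: (n, d) values. Returns (output row, number of rescales performed).
    """
    m, l = -np.inf, 0.0
    o = np.zeros(v.shape[1])
    rescales = 0
    for start in range(0, len(s), block):
        s_j, v_j = s[start:start + block], v[start:start + block]
        m_new = max(m, float(s_j.max()))
        if m_new - m > tau:              # rescale only when the max grew by more than tau
            scale = np.exp2(m - m_new)   # 0.0 on the first block, since m = -inf
            o *= scale
            l *= scale
            m = m_new
            rescales += 1
        p = np.exp2(s_j - m)             # with a stale m, entries stay <= 2**tau = 256
        l += p.sum()
        o += p @ v_j
    return o / l, rescales
```

Feeding steadily increasing scores shows the effect: with \(\tau = 8\) only some blocks trigger a rescale, whereas \(\tau = 0\) (rescale whenever the maximum grows) rescales on every block, and both produce the same output.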


Backward Pass Pipeline and 2-CTA MMA Mode

The backward pass performs five MMAs per iteration, plus one element-wise step (step 4 below):

  1. \(S = \alpha QK^\top\) (recompute)
  2. \(dV = P^\top dO\)
  3. \(dP = dO V^\top\)
  4. \(dS = P \circ (dP - D)\), where \(D = \text{rowsum}(dO \circ O)\) (element-wise softmax backward, not an MMA)
  5. \(dQ = \alpha dS K\)
  6. \(dK = \alpha dS^\top Q\)
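A dense NumPy reference makes the six steps concrete and can serve as a correctness oracle for a tiled kernel. It is a sketch with hypothetical function names; it uses the standard FlashAttention identity \(dS = P \circ (dP - \text{rowsum}(dO \circ O))\) for step 4 and ignores tiling, TMEM, and 2-CTA execution.

```python
import numpy as np

def softmax_rows(s):
    """Numerically stable row-wise softmax."""
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return p / p.sum(axis=1, keepdims=True)

def attention_backward(q, k, v, do, alpha):
    """Dense reference for the backward steps listed above (no tiling)."""
    s = alpha * q @ k.T                      # 1. recompute S
    p = softmax_rows(s)                      #    ... and P
    o = p @ v
    dv = p.T @ do                            # 2. dV = P^T dO
    dp = do @ v.T                            # 3. dP = dO V^T
    rowdot = np.sum(do * o, axis=1, keepdims=True)
    ds = p * (dp - rowdot)                   # 4. element-wise softmax backward
    dq = alpha * ds @ k                      # 5. dQ = alpha dS K
    dk = alpha * ds.T @ q                    # 6. dK = alpha dS^T Q
    return dq, dk, dv
```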

TMEM enables new scheduling:

FlashAttention-3 stored accumulators in registers, imposing ordering constraints that serialized the compute graph. With TMEM:
  • Multiple accumulator tiles can coexist
  • MMA and non-MMA operations overlap better
  • The softmax computation overlaps with the dQ and dK MMAs from the previous iteration

TMEM allocation strategy:

At most four 128×128 tiles fit in TMEM. The implementation:
  • S and P share the TMEM block at offset 0
  • dP, dS, and dQ share another TMEM block

This enables computing dS (element-wise) for the current tile in parallel with the dQ MMA from the previous tile.

2-CTA MMA mode for shared memory reduction:

Blackwell supports a mode where two CTAs within the same thread block cluster cooperatively execute a single MMA. With tile shape \(M = 256, N = K = 128\):

  • Each CTA stages half of operand B in its own shared memory
  • The accumulator is partitioned across CTAs in the M dimension
  • Hardware consumes the combined B tile during multiply

Impact on shared memory traffic:

Table 3 shows for \(M = N = d = 128\):

Resource | 1-CTA cycles | 2-CTA cycles
MMA compute | 2560 | 2560
Shared memory (MMA operands) | 2048 | 1536
Shared memory (dS write) | 256 | 256
Shared memory (dS DSMEM exchange) | 0 | 384
Shared memory (dQ write + read) | 1024 | 512
Total shared memory | 3328 | 2688
Exponential unit | 1024 | 1024

2-CTA mode reduces shared memory traffic from 3328 cycles to 2688 cycles (~19% reduction).

dQ step restructuring (Figure 3):

The standard reduction axis conflicts with 2-CTA splitting. The solution:
  • Use distributed shared memory (DSMEM) to exchange half of dS between the two CTAs
  • Repack dS so it is partitioned along the non-reduction axis
  • Each CTA owns \(M/2\) rows and holds the full \(2N\)-long reduction
  • Per-CTA dQ MMA tile shape: \((M/2, 2N) \times (2N, d)\), accumulating \((M/2, d)\)

Halving global atomic reductions:

Each CTA writes only half of the dQ tile, performing half as many global atomic reductions. Atomic updates are expensive (memory fences, contention), so this is a significant improvement.
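The dQ repacking can be checked numerically. The sketch below simulates a CTA pair in NumPy under our reading of the scheme: each CTA owns \(N\) of the \(2N\) key rows and has computed an \((M, N)\) block of dS for the same \(M\) query rows; the DSMEM exchange is modeled as array slicing, and the \(\alpha\) factor is omitted. Function and argument names are hypothetical.

```python
import numpy as np

def dq_two_cta(ds_cta0, ds_cta1, k_cta0, k_cta1):
    """Simulate the repacked dQ step for a CTA pair.

    ds_ctaX: (M, N) dS block computed by CTA X for its own N keys.
    k_ctaX:  (N, d) key rows owned by CTA X.
    Returns the per-CTA dQ blocks, each of shape (M/2, d).
    """
    h = ds_cta0.shape[0] // 2
    # Modeled DSMEM exchange: CTA 0 keeps the top M/2 rows of both dS blocks and CTA 1
    # the bottom rows, so each CTA holds the full 2N-long reduction for its own rows.
    ds_top = np.concatenate([ds_cta0[:h], ds_cta1[:h]], axis=1)     # (M/2, 2N)
    ds_bottom = np.concatenate([ds_cta0[h:], ds_cta1[h:]], axis=1)  # (M/2, 2N)
    k_pair = np.concatenate([k_cta0, k_cta1], axis=0)              # combined B operand, (2N, d)
    return ds_top @ k_pair, ds_bottom @ k_pair
```

Each CTA ends up with an \((M/2, d)\) block of dQ contributions for this key block pair instead of an \((M, d)\) partial sum, which is why the pair issues half as many global atomic adds.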


Deterministic Backward Pass

Standard backward passes introduce nondeterminism through inter-CTA reductions in global memory (affecting dQ generally, and dK/dV for GQA). For reproducible training (e.g., RL applications), a deterministic mode is provided.

Lock-based serialization:

Each CTA must acquire a semaphore lock in a predefined order before reducing. This introduces overhead from:
  1. Memory fences for device-wide visibility
  2. Stalls while waiting for previous CTAs

Shortest-processing-time-first (SPT) scheduling:

For causal masking, the implementation:
  • Launches KV blocks in descending order
  • Traverses query blocks in ascending order, starting from the diagonal
  • Orders dQ reductions by descending query block index

This "shortest-processing-time-first" approach minimizes stalls—no CTA waits on its first dQ write.

Performance impact (Figure 7):

The deterministic backward pass reaches up to 75% of the speed of the nondeterministic 1-CTA backward pass, a significant improvement over naive locking approaches.


Scheduling Optimizations

Load imbalance problem:

Attention kernels are naturally load-imbalanced:
  • Causal masking: query blocks near the start of the sequence attend to fewer unmasked key blocks, so their worktiles are shorter
  • Variable sequence length (varlen): sequences in a batch have different context lengths

Longest-processing-time-first (LPT) scheduling:

A classical heuristic from scheduling theory: schedule the longest tasks first to approximately minimize makespan.

For causal masking:
  • The standard order processes worktiles from shortest to longest (inefficient)
  • LPT order traverses mblocks (query blocks) in reverse order, longest first
  • The batch dimension remains outermost to preserve L2 cache locality
  • The head index is swizzled to avoid L2 cache thrashing
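The effect of task order can be seen in a toy list-scheduling model. This Python sketch is not the kernel's tile scheduler: it assigns each task to the currently least-loaded worker, and it models 8 causal query blocks, where block \(i\) visits \(i + 1\) key blocks, on 3 hypothetical SMs.

```python
import heapq

def makespan(task_costs, num_workers):
    """Greedy list scheduling: each task goes to the currently least-loaded worker."""
    loads = [(0, w) for w in range(num_workers)]
    heapq.heapify(loads)
    for cost in task_costs:
        load, w = heapq.heappop(loads)          # least-loaded worker (ties: lowest id)
        heapq.heappush(loads, (load + cost, w))
    return max(load for load, _ in loads)

# Causal masking: query block i (0-based) visits i + 1 key blocks.
causal_costs = [i + 1 for i in range(8)]
shortest_first = makespan(causal_costs, num_workers=3)
lpt = makespan(sorted(causal_costs, reverse=True), num_workers=3)
```

Shortest-first yields a makespan of 15 units versus 13 for LPT; the optimum here is 12 (for example {8, 4}, {7, 5}, {6, 3, 2, 1}).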

Empirical impact:

"For BF16 and head dimension 128 we obtain 4-8% FLOPS gain for MHA and 7-14% for MQA 8 as measured on an H200 GPU."

For varlen: a preprocessing kernel sorts batches by estimated execution time, creating a virtual-to-actual batch index mapping. This metadata is cached, avoiding runtime overhead.
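The varlen preprocessing step amounts to computing a permutation. The sketch below is hypothetical: it estimates each sequence's work as \(\text{seqlen}^2\) and sorts batch indices longest first; the paper's actual cost estimate may differ.

```python
def lpt_batch_order(seqlens):
    """Return a virtual-to-actual batch index map, longest estimated work first.

    Assumption: work is estimated as seqlen**2 (attention is quadratic in length).
    """
    return sorted(range(len(seqlens)), key=lambda b: seqlens[b] ** 2, reverse=True)

seqlens = [512, 4096, 1024, 2048]
virtual_to_actual = lpt_batch_order(seqlens)
# The scheduler walks virtual index v and processes actual batch virtual_to_actual[v].
```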


CuTe-DSL Implementation Framework

The compilation bottleneck:

FlashAttention-2 and FlashAttention-3 used complex C++ template metaprograms requiring compilation of hundreds of kernel variants. Table 4 shows:

Method | Forward compile | Backward compile
FlashAttention-3 | 55 s | 45 s
FlashAttention-4 | 2.5 s | 1.4 s
Speedup | 22× | 32×

CuTe-DSL approach:

  • Embedded in Python, lowering to PTX then SASS via JIT compilation
  • Programming model isomorphic to CUTLASS C++
  • Direct PTX access as an escape hatch for custom operations
  • Preserves full expressivity of low-level GPU programming

Productivity benefits:

The paper notes that developers with just months of GPU programming experience have successfully built FlexAttention and block-sparse variants without modifying the core framework. The modular design factors common functionality (masking, varlen, scheduling) into composable primitives.


4. Key Insights and Innovations

Innovation 1: Algorithm-Hardware Co-Design for Asymmetric Scaling

The most fundamental contribution is the recognition that hardware scaling is asymmetric and requires explicit optimization for non-MMA units. Prior work assumed tensor core throughput was the universal bottleneck. The paper's roofline analysis demonstrates that shared memory traffic and exponential operations now take 25-60% more cycles than MMA compute on Blackwell.

This is significant because it changes the optimization strategy from "maximize tensor core utilization" to "balance all hardware resources." The paper systematically addresses each bottleneck:

  • Exponential bottleneck: Software emulation using FMA units
  • Shared memory bottleneck: 2-CTA MMA mode, in which each CTA stages only half of operand B, cutting backward-pass shared memory traffic by ~19%
  • Rescaling overhead: Conditional rescaling skipping unnecessary operations

The approach is a template for future kernel development as hardware asymmetry continues to increase.

Innovation 2: Software-Emulated Exponential Function with Polynomial Approximation

Using FMA units to emulate exponentials is technically novel for attention kernels. The key insight is that BF16 quantization error dominates polynomial approximation error, so a degree-3 polynomial is essentially equivalent to hardware MUFU.EX2 after rounding.

This is not an incremental optimization—it fundamentally increases effective exponential throughput by distributing computation across two hardware units (MUFU and FMA) that can operate in parallel. The partial emulation strategy (10-25% of entries) is a practical innovation balancing throughput gains against register pressure.

Innovation 3: Conditional Softmax Rescaling with Threshold-Based Skipping

Online softmax rescaling is a well-known technique, but the observation that rescaling can be conditionally skipped with a threshold \(\tau\) is novel. The insight that final normalization corrects any intermediate slack is mathematically elegant:

"As long as we keep track of the statistics (the total scaling we have done), we can still get the true denominator at the end to get the right final output."

This removes most of the output-accumulator rescaling multiplications, which would otherwise run after every KV block whose row maximum increases, while maintaining numerical stability.

Innovation 4: 2-CTA MMA Mode with DSMEM Exchange for Backward Pass

Using Blackwell's 2-CTA MMA mode to restructure the dQ computation is a sophisticated systems innovation. The DSMEM-based exchange of dS tiles enables each CTA to form the correct operand shape for the doubled reduction, and 2-CTA mode as a whole cuts backward-pass shared memory traffic by ~19% (Table 3).

The secondary benefit—halving global atomic reductions—addresses a practical bottleneck (atomics introduce nondeterminism and contention). This is a systems-level insight that required deep understanding of both the algorithm (which operations can be restructured) and hardware (how 2-CTA mode works).

Innovation 5: CuTe-DSL Framework for Productive Kernel Development

While not an algorithmic innovation, implementing FlashAttention-4 entirely in Python-embedded CuTe-DSL with 20-30× faster compile times is significant for the research community. Kernel development has historically required C++ template metaprogramming expertise—a high barrier to entry.

The paper demonstrates that productivity and performance are not mutually exclusive. The CuTe-DSL kernels achieve state-of-the-art performance while being more accessible to researchers. The modular primitive design enables rapid development of attention variants (FlexAttention, block-sparse) without sacrificing optimization.


5. Experimental Analysis

Evaluation Methodology

Hardware platform:
  • NVIDIA B200 180GB SXM6 (1000W) GPU

Software versions (Appendix A.1):
  • CUDA 13.1
  • FlashAttention 2.8.3
  • Triton 3.6
  • PyTorch 2.10.0
  • CuTe-DSL 4.4.1
  • cuDNN 9.13 (main comparison) and 9.19.1.2 (latest at time of writing)

Benchmark protocol:
  • 5 warmup runs, then 10 timed runs averaged
  • Sequence lengths: 1k, 2k, ..., 32k
  • Batch size set so that total tokens = 32k
  • Hidden dimension: 2048
  • Head dimensions: 64, 128, and (192, 128) for the DeepSeek V3 configuration

Baselines:
  • PyTorch: standard attention implementation
  • FlashAttention-2: prior-generation IO-aware attention
  • Triton: implementation using B200-specific instructions
  • Gluon: lower-level GPU programming language
  • cuDNN 9.13: vendor library optimized for B200

Metrics: TFLOPs/s, computed from
  • Forward: \(4 \cdot \text{batch} \cdot \text{seqlen}^2 \cdot \text{head\_dim} \cdot \text{num\_heads}\) FLOPs (halved for causal)
  • Backward: forward FLOPs × 2.5 (5 MMAs vs. 2 in the forward pass)
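The reporting metric is easy to reproduce. This Python sketch (hypothetical function names) follows the formulas above and includes the batch size as an explicit factor, so that under the benchmark protocol (batch × seqlen = 32k tokens, num_heads = 2048 / head_dim) it gives the total FLOPs per call.

```python
def attention_flops(batch, seqlen, num_heads, head_dim, causal=False, backward=False):
    """FLOP count used for reporting TFLOPs/s, following the metric definition above."""
    flops = 4 * batch * seqlen**2 * head_dim * num_heads  # two MMAs x 2 FLOPs per multiply-add
    if causal:
        flops //= 2                  # roughly half the score matrix is masked out
    if backward:
        flops = flops * 5 // 2       # 2.5x: five MMAs vs. two in the forward pass
    return flops

def tflops_per_s(flops, seconds):
    """Convert a FLOP count and a measured runtime into TFLOPs/s."""
    return flops / seconds / 1e12
```

For seqlen 8192 and head dimension 128 (batch 4, 16 heads), the non-causal forward pass is about \(2.2 \times 10^{12}\) FLOPs, so 1613 TFLOPs/s corresponds to roughly 1.4 ms per call.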


Main Quantitative Results

Forward Pass Performance (Figures 4 and 5)

Non-causal attention (Figure 4, left):
  • 1.1-1.3× speedup over cuDNN 9.13
  • 2.1-2.7× speedup over Triton
  • Peak performance: ~1600 TFLOPs/s (~71% of theoretical maximum)
  • Consistent gains across all sequence lengths from 1k to 32k

Causal attention (Figure 4, right):
  • Similar speedups: 1.1-1.3× over cuDNN and 2.1-2.7× over Triton
  • Gains attributed in part to the LPT scheduler

DeepSeek V3 configuration (192, 128) (Figure 5):
  • FlashAttention-4 maintains its advantage over cuDNN for causal attention
  • Demonstrates generality across head dimension configurations

The paper notes:

"Since the initial release of our implementation, newer versions of cuDNN have incorporated many of the techniques described in this paper, yielding similar performance to FA4."

This validates the paper's contributions—vendor libraries have adopted the techniques.

Backward Pass Performance (Figure 6)

Non-causal (Figure 6, left):
  • FlashAttention-4 consistently outperforms cuDNN across sequence lengths
  • Demonstrates the effectiveness of 2-CTA mode for reducing shared memory traffic

Causal (Figure 6, right):
  • Similar speedups maintained
  • Gains consistent for sequence lengths of 4k and above

Deterministic Backward Pass (Figure 7)

Performance relative to the nondeterministic 1-CTA backward:
  • Naive approach (no swizzling): significant degradation
  • LPT with swizzling: improved performance
  • SPT scheduling: best deterministic performance, at ~75% of nondeterministic speed

Figure 8 shows additional ablations confirming LPT scheduling benefits for both causal and non-causal cases.


Assessment: Do the Experiments Support the Claims?

Claim 1: FlashAttention-4 achieves 1.3× speedup over cuDNN 9.13. Supported. Figures 4 and 5 show consistent speedups of 1.1-1.3× across configurations. The use of cuDNN 9.13 as the primary baseline is fair—it represents the vendor's best effort. The caveat that newer cuDNN versions have incorporated these techniques actually strengthens the contribution's validation.

Claim 2: 2.7× speedup over Triton. Supported. The Triton baseline uses B200-specific instructions, making this a meaningful comparison. The larger gap (vs. cuDNN) suggests Triton's abstraction layer may not fully exploit Blackwell's features.

Claim 3: 71% theoretical utilization (1613 TFLOPs/s). Supported with context. This is measured performance. However, the paper's own roofline analysis shows the theoretical maximum is constrained by shared memory and exponential bottlenecks, not tensor core throughput. The 71% figure should be interpreted as "% of theoretical tensor core peak," not "% of achievable peak given asymmetric constraints."

Claim 4: 20-30× faster compile times. Strongly supported. Table 4 provides precise numbers: 22× for forward, 32× for backward. This is a direct, unambiguous measurement.


Ablation Studies and Robustness Checks

Deterministic backward ablations (Figures 7 and 8):
  • Compare naive, LPT, and SPT scheduling strategies
  • Validate that SPT achieves the best deterministic performance
  • Show that batch/head swizzling provides consistent benefits

Compile time comparison (Table 4):
  • Direct comparison between FlashAttention-3 (C++) and FlashAttention-4 (CuTe-DSL)
  • Single-kernel compilation times

LPT scheduling validation:
  • Measured on H200 GPU (not B200)
  • 4-8% FLOPS gain for MHA, 7-14% for MQA-8
  • Suggests the scheduling benefit is not specific to Blackwell


Limitations and Considerations

Single GPU architecture. The main benchmarks all run on B200 (the LPT scheduling ablation uses an H200). While the asymmetric scaling trend is general, specific cycle counts and optimal tile sizes may differ on future architectures.

BF16 focus. The exponential emulation accuracy analysis specifically targets BF16 precision. For FP32 or FP8, different polynomial degrees or strategies may be needed.

No comparison to FlashAttention-3 on Hopper. The paper notes FA-3 doesn't run on B200, but a head-to-head comparison on H100 would contextualize the architectural differences.

Compile time comparison is single-kernel. FlashAttention implementations typically precompile hundreds of kernel variants. The 20-30× figure applies per kernel—total compilation savings depend on variant count.

Theoretical maximum interpretation. The 71% utilization is relative to tensor core peak. Given the asymmetric constraints, the "true" peak given shared memory/exp bottlenecks may be lower, making actual efficiency higher than 71% of achievable peak.

6. Limitations and Trade-offs

Assumption: Asymmetric Scaling Trend Continues

The paper's core thesis is that tensor core throughput scales faster than other functional units, creating bottlenecks in shared memory bandwidth and exponential operations. While the roofline analysis clearly demonstrates this for Blackwell (Section 3.1.1), the approach assumes this asymmetry persists in future architectures. If NVIDIA rebalances hardware—doubling shared memory bandwidth or exponential unit throughput on future GPUs—the bottleneck profile would shift, potentially altering which optimizations remain beneficial. The paper acknowledges this implicitly by noting:

"Although optimized for Blackwell GPUs, some of these algorithms can be extended to other accelerators as compute continues to outpace non-matmul units."

However, the specific cycle-count optimizations (tile sizes, emulation percentages) are tuned for Blackwell's exact throughput ratios and would require re-calibration.

BF16-Specific Accuracy Analysis

The exponential emulation accuracy study (Table 2) specifically targets BF16 precision, where quantization error (~3.9×10⁻³) dominates polynomial approximation error. This justifies using degree-3 polynomials. However, for FP32 precision, the hardware MUFU.EX2 achieves ~1.4×10⁻⁷ max relative error while degree-3 emulation is ~600× worse at 8.8×10⁻⁵. The paper does not analyze whether higher-degree polynomials would be necessary for FP32 workloads, nor the throughput/accuracy tradeoff in that regime.

Similarly, for emerging FP8 precision (which Blackwell supports), the quantization error landscape differs, and the paper does not analyze whether the same emulation strategy applies. This is notable given the paper's citation of SageAttention3's FP4 quantization work on Blackwell consumer GPUs.
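To make the accuracy argument concrete, here is a minimal sketch of software exp2 emulation under stated assumptions: range reduction (integer part into the exponent, fractional part through a polynomial), with the polynomial obtained by interpolation at Chebyshev nodes. The paper's actual coefficients and evaluation order are not reproduced here, and Python doubles stand in for the kernel's FP32 registers, so the numbers are illustrative only. The worst relative error is compared against the BF16 unit roundoff of 2⁻⁸ (about 3.9×10⁻³).

```python
import math

BF16_UNIT_ROUNDOFF = 2.0 ** -8  # ~3.9e-3: the BF16 quantization error cited above


def chebyshev_nodes(count, a=0.0, b=1.0):
    """Chebyshev nodes of the first kind mapped to [a, b]."""
    return [0.5 * (a + b) + 0.5 * (b - a) * math.cos((2 * k + 1) * math.pi / (2 * count))
            for k in range(count)]


def newton_coefficients(xs, ys):
    """Divided-difference coefficients of the polynomial interpolating (xs, ys)."""
    c = list(ys)
    for j in range(1, len(xs)):
        for i in range(len(xs) - 1, j - 1, -1):
            c[i] = (c[i] - c[i - 1]) / (xs[i] - xs[i - j])
    return c


def make_exp2(degree):
    """Emulated 2**x: split x = n + f, approximate 2**f on [0, 1), scale by 2**n."""
    xs = chebyshev_nodes(degree + 1)
    coeffs = newton_coefficients(xs, [2.0 ** x for x in xs])

    def exp2(x):
        n = math.floor(x)      # integer part goes straight into the exponent field
        f = x - n              # fractional part in [0, 1)
        acc = coeffs[-1]
        for k in range(degree - 1, -1, -1):
            acc = acc * (f - xs[k]) + coeffs[k]   # nested Newton-form evaluation
        return math.ldexp(acc, n)                 # exact multiply by 2**n

    return exp2


def max_rel_error(exp2, lo=-16.0, hi=0.0, samples=20001):
    """Worst relative error against 2.0 ** x on a uniform grid over [lo, hi]."""
    worst = 0.0
    for i in range(samples):
        x = lo + (hi - lo) * i / (samples - 1)
        ref = 2.0 ** x
        worst = max(worst, abs(exp2(x) - ref) / ref)
    return worst


err3 = max_rel_error(make_exp2(3))
err5 = max_rel_error(make_exp2(5))
print(f"degree 3: {err3:.1e}  degree 5: {err5:.1e}  BF16 roundoff: {BF16_UNIT_ROUNDOFF:.1e}")
```

Degree 3 lands well below the BF16 roundoff; raising the degree shrinks the error further at the cost of more FMA-pipe work per element, which is the FP32 tradeoff the paper leaves unexplored.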

Overhead of Partial Emulation Not Quantified

The partial emulation strategy (10-25% of entries per row) is empirically tuned, but the paper does not provide:

  • Sensitivity analysis of optimal emulation percentage across tile sizes
  • Register pressure measurements with and without emulation
  • Latency comparison between MUFU.EX2 and polynomial evaluation

The paper states that full emulation "would increase register pressure and could cause spills that negate the throughput benefit," but does not quantify this tradeoff. A microbenchmark showing throughput vs. emulation percentage would strengthen this design choice.
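A toy cost model shows why a partial split can beat both extremes: the MUFU and FMA pipes run concurrently, so the time to exponentiate a batch is roughly the maximum of the two pipes' work. The MUFU rate of 16 ops/clock/SM follows from the B300 doubling to 32 noted in this document; the FMA rate, polynomial cost, and other per-element FMA work are placeholder assumptions, not measurements, and register pressure is ignored entirely. The resulting optimum should not be read as confirming the paper's 10-25% range.

```python
def exp_cycles(frac_emulated: float, n: int = 128,
               mufu_per_clk: float = 16.0,   # exp ops/clock/SM on B200 (B300 doubles this to 32)
               fma_per_clk: float = 128.0,   # assumption: FP32 FMA ops/clock/SM
               poly_fma_ops: float = 10.0,   # placeholder: FMA-pipe ops per emulated exp2
               other_fma_ops: float = 4.0    # placeholder: other per-element softmax work on the FMA pipe
               ) -> float:
    """Cycles to exponentiate n entries when a fraction is routed to polynomial emulation.

    The two pipes work concurrently, so the slower one sets the time. Register
    pressure, latency, and instruction issue limits are not modeled.
    """
    n_emulated = round(frac_emulated * n)
    mufu_cycles = (n - n_emulated) / mufu_per_clk
    fma_cycles = (n_emulated * poly_fma_ops + n * other_fma_ops) / fma_per_clk
    return max(mufu_cycles, fma_cycles)


fractions = [i / 20 for i in range(21)]  # 0%, 5%, ..., 100%
best = min(fractions, key=exp_cycles)
for f in (0.0, best, 1.0):
    print(f"{f:4.0%} emulated -> {exp_cycles(f):5.2f} cycles")
```

In this model, full emulation is worse than none because it moves all the work onto the FMA pipe. The paper's stated reason for avoiding full emulation (register pressure and spills) is a separate cost that the model does not capture.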

No Analysis of Edge Cases in Conditional Softmax Rescaling

The conditional rescaling algorithm skips updates when \(m_j - m_{j-1} \leq \tau\) with \(\tau = 8.0\). The paper does not analyze:

  • Adversarial cases: What if scores are designed to trigger the threshold repeatedly? The paper notes rescaling occurs "when any thread in the warp needs rescaling" to avoid divergence, but worst-case behavior could approach standard rescaling overhead.
  • Interaction with mixed-precision training: If intermediate accumulators overflow due to skipped rescaling, does BF16 provide sufficient headroom given \(\tau = 8.0\) corresponds to a factor of 256×?

The correctness argument relies on final normalization, but intermediate accumulator overflow risk is not discussed.
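A single-row sketch of online softmax with conditional rescaling makes the correctness argument and the adversarial case concrete. Assumptions: scores are already in log2 units, \(m_{j-1}\) is read as the offset currently in use (which may be stale), and warp-level voting is ignored, so each row decides on its own. Because the final division uses the same offset as every accumulated term, the output matches an exact softmax however many rescales were skipped.

```python
import math
import random


def attention_row(scores, values, block=4, tau=8.0):
    """One query row: online softmax with conditional rescaling.

    scores are in log2 units, so weights are 2 ** (s - offset). The offset m is
    raised only when a block's max exceeds it by more than tau; otherwise the
    stale offset is kept and unnormalized weights may reach 2 ** tau.
    Returns (output vector, number of rescales performed).
    """
    dim = len(values[0])
    m = -math.inf          # offset currently in use (may be stale)
    denom = 0.0            # running sum of weights, relative to m
    acc = [0.0] * dim      # running weighted sum of values, relative to m
    rescales = 0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_blk = max(s_blk)
        if m_blk - m > tau:                  # rescale only on a large jump
            factor = 2.0 ** (m - m_blk)      # 0.0 on the first block (m = -inf)
            denom *= factor
            acc = [a * factor for a in acc]
            m = m_blk
            rescales += 1
        for s, v in zip(s_blk, v_blk):
            p = 2.0 ** (s - m)               # <= 2 ** tau even when m is stale
            denom += p
            acc = [a + p * vi for a, vi in zip(acc, v)]
    return [a / denom for a in acc], rescales   # final normalization cancels the offset


def reference_row(scores, values):
    """Exact softmax-weighted average with a single global max."""
    mx = max(scores)
    w = [2.0 ** (s - mx) for s in scores]
    total = sum(w)
    return [sum(wi * v[k] for wi, v in zip(w, values)) / total for k in range(len(values[0]))]


rng = random.Random(0)
scores = [rng.uniform(-20.0, 20.0) for _ in range(64)]
values = [[rng.uniform(-1.0, 1.0) for _ in range(8)] for _ in range(64)]
out, rescales_tau8 = attention_row(scores, values)            # tau = 8
_, rescales_tau0 = attention_row(scores, values, tau=0.0)     # rescale on any increase
print(max(abs(a - b) for a, b in zip(out, reference_row(scores, values))))
print(rescales_tau8, rescales_tau0)

# Adversarial pattern: block maxima climb by 9 > tau, so every block rescales.
climbing = [9.0 * (i // 4) for i in range(16)]
print(attention_row(climbing, [[1.0]] * 16)[1])
```

The last call hits the worst case raised above: every block triggers a rescale. Between rescales, unnormalized weights stay bounded by \(2^{\tau}\).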

2-CTA Mode Requires Thread Block Cluster Support

The 2-CTA MMA optimization requires CTAs to be launched in fixed pairs within the same thread block cluster (Section 2.2). This imposes scheduling constraints on the overall kernel launch configuration:

  • Not all attention variants can restructure to exploit 2-CTA mode
  • The DSMEM exchange requires peer CTAs to be co-scheduled on the same GPC

The paper does not discuss scenarios where 2-CTA mode is infeasible (e.g., workloads with odd tile counts, head dimensions that don't align with 256 M-dimension tiles). For practical deployment, fallback to 1-CTA mode is implied but not explicitly described.
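One plausible fallback, sketched here as a hypothetical host-side planner (the function name and policy are illustrative, not the paper's): assuming each CTA owns a 128-row M tile so that a 2-CTA pair covers 256 rows, adjacent tiles are paired and an odd leftover tile runs in 1-CTA mode.

```python
def plan_m_tiles(rows: int, tile_m: int = 128):
    """Hypothetical planner: pair M tiles for 2-CTA MMA, leave an odd tile for 1-CTA mode.

    Returns (pairs, singles): pairs are (tile, tile + 1) index tuples launched as
    one 2-CTA cluster; singles are tile indices handled by a 1-CTA fallback.
    """
    n_tiles = -(-rows // tile_m)  # ceiling division
    pairs = [(2 * i, 2 * i + 1) for i in range(n_tiles // 2)]
    singles = [n_tiles - 1] if n_tiles % 2 else []
    return pairs, singles


print(plan_m_tiles(1024))  # 8 tiles: 4 pairs, no fallback
print(plan_m_tiles(1100))  # 9 tiles: 4 pairs plus one 1-CTA tile
```

Padding the last tile to complete a pair is the obvious alternative; which option a real kernel takes, and at what cost, is what the paper leaves unstated.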

Deterministic Mode Overhead Not Fully Characterized

Figure 7 shows deterministic backward achieving ~75% of nondeterministic 1-CTA speed, but this is relative to the 1-CTA baseline. The paper does not report deterministic overhead relative to the 2-CTA mode, which is the primary high-performance configuration. If deterministic 2-CTA performance is significantly worse, practitioners may need to choose between reproducibility and peak throughput, a tradeoff the paper does not quantify.

Additionally, the SPT scheduling for causal masking requires launching KV blocks in descending order and traversing query blocks from the diagonal. The paper does not analyze:

  • Non-causal cases: Does deterministic mode have higher overhead without the diagonal structure to exploit?
  • Varlen interactions: How does batch sorting interact with deterministic scheduling?
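The SPT ordering can be sketched as a small schedule generator. Under the causal mask, KV block j only interacts with query blocks i ≥ j, so it carries n - j units of work; launching KV blocks in descending order therefore runs the shortest jobs first, and each one walks query blocks starting at the diagonal. The helper `accumulate` is illustrative: it shows why deterministic mode must also fix the order in which dQ contributions are summed, since floating-point addition is not associative. How the kernel enforces that order is not modeled here.

```python
def spt_causal_schedule(n_blocks: int):
    """(kv_block, [query blocks in visit order]) for a causal backward pass.

    KV blocks launch in descending order, so the ones with the fewest unmasked
    query blocks (shortest processing time) go first; each starts at the diagonal.
    """
    return [(kv, list(range(kv, n_blocks))) for kv in range(n_blocks - 1, -1, -1)]


def accumulate(contributions):
    """Sum partial gradients in the given order (order changes the rounded result)."""
    total = 0.0
    for c in contributions:
        total += c
    return total


for kv, q_blocks in spt_causal_schedule(4):
    print(f"KV block {kv}: query blocks {q_blocks}")

# Same three partial sums, two orders: results can differ in the last bit.
print(accumulate([0.1, 0.2, 0.3]), accumulate([0.3, 0.2, 0.1]))
```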

Compile Time Comparison Is Per-Kernel

Table 4 shows 22× and 32× compile time reductions for forward and backward kernels respectively. However:

  • FlashAttention implementations typically precompile hundreds of kernel variants for different sequence lengths, head dimensions, and data types
  • The paper does not report total compilation time for a full kernel library
  • The practical gain may be larger or smaller than the 20-30× per-kernel figure, depending on how many variants are needed and on build costs that are not per-kernel

The productivity benefits (Section 4) note that "developers have successfully built FlexAttention and block-sparse attention variants," suggesting the framework enables rapid iteration, but end-to-end build time comparison would be more actionable.
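A back-of-the-envelope build model shows how per-kernel and library-level ratios relate. All inputs are hypothetical placeholders: if compilation is the only cost, a per-kernel speedup carries over unchanged to any number of variants, while a shared fixed cost (assumed identical for both toolchains) pulls the overall ratio below it.

```python
def build_speedup(n_variants: int, old_s_per_kernel: float, new_s_per_kernel: float,
                  fixed_s: float = 0.0, parallel_jobs: int = 1) -> float:
    """Ratio of total build times (old / new) for a library of kernel variants.

    Assumes variants compile independently across parallel_jobs workers and that
    fixed_s is a per-build cost shared by both toolchains.
    """
    old_total = fixed_s + n_variants * old_s_per_kernel / parallel_jobs
    new_total = fixed_s + n_variants * new_s_per_kernel / parallel_jobs
    return old_total / new_total


# Hypothetical numbers: 300 variants, 30 s vs 1 s per kernel.
print(build_speedup(300, 30.0, 1.0))               # compilation only
print(build_speedup(300, 30.0, 1.0, fixed_s=600))  # with a shared fixed cost
```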

No Comparison to Vendor Libraries on Hopper

The paper benchmarks against cuDNN 9.13 on B200, but does not compare FlashAttention-3 (Hopper-optimized) against cuDNN on H100. This makes it difficult to assess:

  • Whether the algorithmic innovations (exponential emulation, conditional rescaling) would benefit Hopper architectures
  • The relative generational improvement from Hopper-optimized to Blackwell-optimized kernels

The LPT scheduling section notes gains measured on H200 (4-8% for MHA, 7-14% for MQA-8), suggesting some techniques generalize, but a systematic comparison is absent.

Roofline Analysis Omissions

Section 3.1.1 explicitly notes:

"We note that this is a simplified analysis that does not consider all resources in the GPU (e.g., floating point math, register bandwidth, L2 bandwidth)."

This simplification is pragmatic but potentially significant:

  • Register bandwidth: The forward pass pipeline requires holding 128-element rows in registers. Register file bandwidth could become limiting if emulation increases register pressure.
  • L2 cache: The scheduling optimizations (LPT, head swizzling) are designed around L2 cache behavior, but L2 bandwidth is not included in the bottleneck analysis.

A more complete model might reveal secondary bottlenecks as tensor core efficiency improves.


7. Implications and Future Directions

A Template for Asymmetric Hardware Optimization

This work establishes that algorithm-hardware co-design must account for non-uniform scaling. As accelerators continue to prioritize matrix multiplication throughput (the primary workload for deep learning), other functional units will increasingly become bottlenecks. The paper provides a systematic methodology:

  1. Roofline analysis per hardware unit: Model throughput of tensor cores, shared memory, exponential units, and other specialized units independently
  2. Identify the bottleneck delta: Quantify how much faster tensor cores are relative to other units
  3. Redesign algorithms to balance: Use software emulation, mode selection, and scheduling to mitigate non-MMA constraints

This approach is directly applicable to future architectures—Blackwell B300/GB300 have already doubled exponential throughput to 32 ops/clock/SM (noted in Section 2.2), which would change the bottleneck profile and require re-optimization.
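Steps 1 and 2 of this methodology can be sketched as a per-unit cycle count for one forward tile. Assumptions: per-SM rates of 16 exponentials/clock (implied by B300's doubling to 32), tensor cores at 512× that (the ratio cited in this section), and 128 bytes/clock of shared memory. The shared memory byte count assumes each Q, K, and V tile is read once, a simplification rather than the paper's accounting, and the "B300-like" profile changes only the exponential rate. Registers, L2, and the FMA pipe are omitted, as in the paper's own simplified analysis.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SmRates:
    mma_flops_per_clk: float   # tensor core FLOPs per clock per SM
    smem_bytes_per_clk: float  # shared memory bytes per clock per SM
    exp_per_clk: float         # exponentials (MUFU.EX2) per clock per SM


B200_LIKE = SmRates(mma_flops_per_clk=512 * 16, smem_bytes_per_clk=128, exp_per_clk=16)
B300_LIKE = SmRates(mma_flops_per_clk=512 * 16, smem_bytes_per_clk=128, exp_per_clk=32)  # only exp doubled


def tile_cycles(hw: SmRates, m: int = 128, n: int = 128, d: int = 128, bytes_per_elem: int = 2):
    """Cycles each unit needs for one tile: S = Q K^T (m x n x d), exp over S, O += P V."""
    mma_flops = 2 * (2 * m * n * d)                       # two GEMMs, 2 FLOPs per multiply-add
    exps = m * n                                          # one exponential per score
    smem_bytes = (m * d + 2 * n * d) * bytes_per_elem     # assumption: Q, K, V tiles read once
    return {
        "tensor_core": mma_flops / hw.mma_flops_per_clk,
        "shared_memory": smem_bytes / hw.smem_bytes_per_clk,
        "exp_unit": exps / hw.exp_per_clk,
    }


def bottleneck(cycles: dict) -> str:
    """Unit with the largest cycle count (ties go to the first listed)."""
    return max(cycles, key=cycles.get)


for name, hw in (("B200-like", B200_LIKE), ("B300-like", B300_LIKE)):
    for d in (64, 128):
        c = tile_cycles(hw, d=d)
        print(name, f"d={d}", {k: round(v) for k, v in c.items()}, "->", bottleneck(c))
```

With these inputs, head dimension 128 gives the exponential unit exactly as many cycles as the tensor cores (1024 each), leaving no slack for exponentials to hide behind MMA, and head dimension 64 is exponential-bound outright. Doubling the exponential rate halves those cycles, so d=128 becomes tensor-core bound and d=64 reaches parity.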

Research Directions Enabled

Extending emulation to other nonlinear functions. The polynomial approximation technique for \(2^x\) can generalize to other functions: GELU, SiLU, layer normalization, and other operations that rely on specialized hardware units. If these units also scale slower than tensor cores, emulation strategies become necessary. A systematic study of precision/throughput tradeoffs for common nonlinearities would be valuable.

Compiler-assisted bottleneck mitigation. The paper manually identifies and addresses bottlenecks. A compiler or auto-tuner could:

  • Automatically detect hardware unit ratios from device properties
  • Generate emulation code for bottleneck functions
  • Select tile sizes and scheduling strategies optimized for specific architectures

CuTe-DSL's JIT compilation infrastructure provides a foundation for such automation.

Cross-architecture portability. The paper notes some techniques (LPT scheduling, conditional rescaling) generalize across architectures. A systematic study of which optimizations are architecture-agnostic vs. Blackwell-specific would help kernel developers prioritize efforts. The scheduling gains measured on H200 (4-14% improvement) suggest broader applicability.

Alternative attention mechanisms. The framework specifically targets standard scaled dot-product attention. Alternative mechanisms such as linear attention and sliding-window attention, as well as KV-sharing layouts such as multi-query attention (MQA) and grouped-query attention (GQA), have different compute/memory ratios. FlashAttention-4's support for DeepSeek V3's (192, 128) configuration (Figure 5) demonstrates extensibility, but a systematic analysis of how asymmetric scaling affects each variant would guide future architecture design.

Practical Applications

Production LLM inference and training. FlashAttention-4 directly benefits deployments on B200/GB200 systems, which are becoming standard for large-scale AI. The up-to-1.3× kernel speedup over cuDNN translates to:

  • Training throughput: Faster attention forward/backward passes reduce training time
  • Inference latency: Long-context applications (document reasoning, code completion) benefit directly
  • Cost reduction: Higher TFLOPs utilization means better hardware ROI
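The kernel-level speedup only applies to the share of runtime spent in attention, so end-to-end gains follow Amdahl's law. A minimal sketch; the attention fractions below are hypothetical, not measurements from the paper.

```python
def end_to_end_speedup(attention_fraction: float, attention_speedup: float = 1.3) -> float:
    """Amdahl's law: overall speedup when only the attention share of runtime gets faster."""
    if not 0.0 <= attention_fraction <= 1.0:
        raise ValueError("attention_fraction must be in [0, 1]")
    return 1.0 / ((1.0 - attention_fraction) + attention_fraction / attention_speedup)


for frac in (0.2, 0.5, 0.8):  # hypothetical attention shares of step time
    print(f"attention = {frac:.0%} of runtime -> {end_to_end_speedup(frac):.3f}x overall")
```

Long-context workloads, where attention dominates runtime, sit closest to the full 1.3×.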

The deterministic mode specifically enables reproducible training for reinforcement learning from human feedback (RLHF) and other applications where gradient nondeterminism causes issues.

Attention variant development. The CuTe-DSL framework lowers the barrier to implementing new attention variants. The paper notes:

"developers have successfully built FlexAttention and block-sparse attention variants on top of FlashAttention-4 without modifying the core framework."

This enables rapid prototyping of research ideas without deep C++ template expertise. For research teams, this could significantly accelerate iteration cycles.

Hardware design feedback loop. The detailed bottleneck analysis provides valuable feedback to GPU architects. The finding that exponential units have 512× lower per-clock throughput than tensor cores, making software emulation necessary, suggests future architectures might benefit from:

  • Higher-throughput specialized units for common nonlinearities
  • Tensor cores with built-in softmax acceleration
  • More balanced scaling across functional units

Integration Guidance

When to prefer FlashAttention-4:

  • Blackwell-based systems: The optimizations are architecture-specific; on Hopper GPUs, FlashAttention-3 remains appropriate
  • BF16 precision: The exponential emulation accuracy analysis specifically targets BF16; FP32 or FP8 workloads require additional validation
  • Long-context workloads: Gains are most significant for sequence lengths 4k and above (Figures 4, 6)
  • Causal masking: LPT scheduling provides additional benefits; deterministic mode enables reproducible training at roughly 75% of the nondeterministic 1-CTA backward speed (Figure 7)

Integration path:

  • FlashAttention-4 is open-sourced under a permissive license at the referenced GitHub repository
  • The authors are working to integrate with popular libraries, lowering adoption barriers
  • cuDNN has already incorporated techniques from this work (versions 9.13+), providing alternative access points

Fallback considerations:

  • For workloads with unusual head dimensions (non-power-of-2, very small), the tile-size assumptions may not hold
  • For FP32 precision, the emulation strategy may require higher-degree polynomials or may not be beneficial
  • For architectures without TMEM or 2-CTA modes (consumer GPUs, older datacenter GPUs), the TMEM- and 2-CTA-dependent optimizations are not applicable, though techniques such as LPT scheduling and conditional rescaling may still carry over

The Productivity-Performance Tradeoff Resolution

Perhaps the most significant long-term implication is the demonstration that developer productivity and peak performance are not mutually exclusive. FlashAttention-4 achieves state-of-the-art performance while being written entirely in a Python-embedded DSL (CuTe-DSL) with 20-30× faster compile times. This challenges the conventional wisdom that high-performance kernels require painful C++ template metaprogramming.

As the paper states:

"By lowering the barrier to entry, our approach enables researchers and engineers with just a few months of GPU programming experience to contribute meaningful extensions without requiring deep expertise in C++ template metaprogramming."

This democratization could accelerate innovation in attention mechanism design, enabling a broader research community to contribute performance optimizations rather than relying on a small set of kernel experts.