Chengshuo Dai

Hardware-Aware Algorithms: The Magic of FlashAttention

Model Architecture, Performance Optimization

When training large language models, the self-attention mechanism is notoriously expensive. Its time and memory complexity scale quadratically with the sequence length ($O(N^2)$). For a long time, much of the community tried to solve this with "sparse" attention mechanisms: approximations that reduce the theoretical complexity, usually at some cost in model quality. But then FlashAttention came along and showed us that we were looking at the wrong bottleneck.

The true bottleneck wasn't just the number of FLOPs; it was memory access, specifically the cost of moving data between the GPU's comparatively slow, high-capacity High Bandwidth Memory (HBM) and its fast, low-capacity on-chip Static Random Access Memory (SRAM).
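To get a feel for how much data is involved, here is a small back-of-the-envelope sketch. It assumes 2-byte (fp16/bf16) elements and counts a single attention head on a single sequence; the function name is made up for illustration.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to hold one N x N attention score matrix
    (one head, one sequence; 2 bytes per element is an assumption)."""
    return seq_len * seq_len * bytes_per_element


for n in (1024, 8192, 65536):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"N={n:>6}: {mib:,.0f} MiB per head")
# N=  1024: 2 MiB per head
# N=  8192: 128 MiB per head
# N= 65536: 8,192 MiB per head
```

Multiply those figures by the number of heads, layers, and sequences in a batch, and it becomes clear why writing this matrix to HBM and reading it back dominates the runtime.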

Tiling and Recomputation

FlashAttention achieves exact attention (no approximations) while being significantly faster and more memory-efficient. It does this through two key techniques:

  1. Tiling: Instead of computing the entire $N \times N$ attention matrix at once (which requires writing massive intermediate tensors to HBM), FlashAttention breaks the input into blocks (tiles). It loads a block of Queries, Keys, and Values from HBM to SRAM, computes the attention for that block, and writes only the final output back to HBM. Because softmax normalizes over an entire row, it keeps a running maximum and a running sum for each row and rescales the partial results as each new block arrives, so the result is still exact. This drastically reduces memory reads and writes.
  2. Recomputation: During the backward pass, standard attention needs the intermediate attention matrix (which is huge) to compute gradients. FlashAttention doesn't store this matrix. Instead, it keeps only the output and the per-row softmax normalization statistics from the forward pass, and recomputes blocks of the attention matrix in SRAM on the fly during the backward pass. While this requires more FLOPs, it saves so much HBM read/write time that the overall backward pass is actually faster.
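The two ideas above can be sketched in a few lines of NumPy. This is a minimal illustration of the math, not the actual FlashAttention kernel: it runs on the CPU, so nothing here models SRAM or HBM, and the function names, the block size of 16, and the single-head, unmasked setup are all assumptions made for the example. The point is only that processing one block at a time with a running row maximum and row sum gives the same answer as the full $N \times N$ computation, and that a block of attention probabilities can be rebuilt later from the saved per-row log-sum-exp.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference implementation: materializes the full N x N matrix."""
    scores = (q @ k.T) / np.sqrt(q.shape[1])
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=1, keepdims=True)
    return probs @ v


def tiled_attention(q, k, v, block_size=16):
    """Block-by-block attention with a running (online) softmax.

    Returns the output and the per-row log-sum-exp of the scaled scores,
    which is the only softmax statistic kept for the backward pass.
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[1]))
    lse = np.empty(n)
    for qs in range(0, n, block_size):            # one block of queries
        q_blk = q[qs:qs + block_size]
        rows = q_blk.shape[0]
        acc = np.zeros((rows, v.shape[1]))        # unnormalized output
        row_max = np.full((rows, 1), -np.inf)     # running row maximum
        row_sum = np.zeros((rows, 1))             # running softmax denominator
        for ks in range(0, k.shape[0], block_size):   # one block of keys/values
            k_blk = k[ks:ks + block_size]
            v_blk = v[ks:ks + block_size]
            scores = (q_blk @ k_blk.T) * scale    # only a small tile of scores
            new_max = np.maximum(row_max, scores.max(axis=1, keepdims=True))
            correction = np.exp(row_max - new_max)    # rescale earlier partials
            p = np.exp(scores - new_max)
            row_sum = row_sum * correction + p.sum(axis=1, keepdims=True)
            acc = acc * correction + p @ v_blk
            row_max = new_max
        out[qs:qs + rows] = acc / row_sum
        lse[qs:qs + rows] = (row_max + np.log(row_sum))[:, 0]
    return out, lse


def recompute_probs_block(q_blk, k_blk, lse_blk):
    """Rebuild one tile of softmax probabilities from Q, K and the saved
    log-sum-exp, as a backward pass would, without the stored N x N matrix."""
    scale = 1.0 / np.sqrt(q_blk.shape[1])
    return np.exp((q_blk @ k_blk.T) * scale - lse_blk[:, None])


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((50, 8)) for _ in range(3))
out, lse = tiled_attention(q, k, v)
print(np.allclose(out, naive_attention(q, k, v)))
```

Note that the tiled version performs extra exponentials and multiplications for the rescaling step, which is exactly the trade the algorithm makes: a little more arithmetic in exchange for never materializing the full score matrix.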

Personal Reflection

FlashAttention was a massive "aha!" moment for me in my deep learning journey. It fundamentally changed how I view algorithm design. We often get caught up in Big-O notation, assuming that fewer operations always mean faster execution. FlashAttention proved that a hardware-aware algorithm with more FLOPs can be faster than a theoretically "efficient" algorithm if it respects the memory hierarchy of modern accelerators.

It also highlights a gap in how we teach computer science. We teach algorithms in a vacuum, assuming uniform memory access costs. But in the real world, especially on GPUs, the cost of moving a byte of data can be orders of magnitude higher than the cost of multiplying two numbers. FlashAttention is a masterclass in bridging the gap between theoretical math and physical silicon.
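One way to put a number on that imbalance is the ratio of peak compute to memory bandwidth. The sketch below uses round, illustrative figures (300 TFLOP/s of compute and 2 TB/s of HBM bandwidth); these are assumptions chosen to be in the right ballpark for a modern data-center GPU, not the specification of any particular chip, and the function name is made up for the example.

```python
def breakeven_flops_per_byte(peak_flops_per_s: float,
                             mem_bandwidth_bytes_per_s: float) -> float:
    """FLOPs the chip can execute in the time it takes to move one byte
    to or from main memory."""
    return peak_flops_per_s / mem_bandwidth_bytes_per_s


# Illustrative round numbers (assumptions, not a real spec sheet).
PEAK_FLOPS = 300e12      # 300 TFLOP/s
HBM_BANDWIDTH = 2e12     # 2 TB/s

ratio = breakeven_flops_per_byte(PEAK_FLOPS, HBM_BANDWIDTH)
print(f"{ratio:.0f} FLOPs per byte moved")  # 150 FLOPs per byte moved
```

Under these assumptions, any step that does fewer than roughly 150 operations per byte it touches is waiting on memory rather than on arithmetic. Elementwise work such as softmax over a stored score matrix falls far below that line, which is why avoiding the round trip to HBM pays off even when it costs extra FLOPs.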

