From Vanilla Transformer to FlashAttention: A Quantum Leap in GPU Efficiency

Manish Agarwal
15 min read · Oct 10, 2023


Introduction

In the world of deep learning and artificial intelligence, there’s a constant drive to make models faster and more efficient. We’ve come a long way from the attention implementation of the original Transformer to something groundbreaking called FlashAttention. But how did we reach this point? What’s so special about FlashAttention? In this blog, we’ll take a simple journey through the history of attention mechanisms in AI and discover how FlashAttention is changing the game when it comes to making GPUs work smarter.

At the core of every Transformer model lies the attention mechanism, a fundamental operation that captures the relationships between elements in a sequence. However, as neural networks have grown larger and sequences have grown longer, standard attention has started to show its limitations.

Standard attention has time and memory complexity of O(N²), where N is the sequence length. As a result, it quickly becomes a bottleneck for long sequences, hindering the training of large models. To address this challenge, researchers have developed various approximate attention methods such as Reformer, SMYRF, and Performer. These methods aim to reduce the computational requirements, primarily in terms of floating-point operations (FLOPs). However, they often overlook the overhead of memory access (I/O), which can also significantly affect overall speed, so many of them do not deliver a wall-clock speedup.

Operations such as dropout, softmax, and masking (element-wise operations, plus the row-wise reduction inside softmax) take up a large share of the computation time because they are memory-bound.

As you can observe in the left bar of the figure, masking, softmax, and dropout are the primary contributors to computation time, exceeding even matrix multiplication, even though most of the floating-point operations (FLOPs) are in the matmul.

However, there’s a silver lining. Memory isn’t a single, uniform entity; it’s hierarchical. In general, faster memory comes at a higher cost and with smaller capacity.

Before we dive into Flash Attention, let’s first grasp the fundamental concepts of GPUs (Graphics Processing Units). This will make our upcoming exploration much smoother.

GPU Memory Hierarchy

Let’s start by looking at the memory hierarchy of a GPU, focusing on the NVIDIA A100 GPU with 40 GB of High Bandwidth Memory (HBM). This GPU has a complex memory hierarchy; a simplified diagram of it is shown above.

At the heart of this hierarchy are 108 streaming multiprocessors (SMs), each equipped with 192KB of SRAM (Static Random-Access Memory).

192 KB × 108 = 20,736 KB ≈ 20 MB of SRAM in total.

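The SRAM is blazingly fast but tiny compared to HBM: the FlashAttention paper estimates SRAM bandwidth at around 19 TB/s, versus 1.5-2.0 TB/s for HBM, while HBM holds about 2,000 times more data (40 GB vs. about 20 MB). This architecture highlights an essential aspect of GPU design: the trade-off between speed and capacity.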

In terms of raw compute power, the A100 GPU can achieve a theoretical peak throughput of 312 TFLOPS (trillion floating-point operations per second) using BFLOAT16 with Tensor Cores. This combination lets you take full advantage of the A100’s computational capabilities for tasks like deep learning, where massive matrix operations are common.

Tensor Cores are hardware units within the GPU optimized for matrix multiplications and other tensor operations, which are fundamental to many machine learning algorithms.

BFLOAT16 (bfloat16) is a 16-bit floating-point format designed for machine learning and deep learning workloads. It keeps the same 8-bit exponent as 32-bit floats (and therefore the same dynamic range) but has only 7 mantissa bits, striking a balance between precision and performance.
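To make the precision/range trade-off concrete, here is a minimal Python sketch that converts a value to bfloat16 by keeping only the top 16 bits of its float32 encoding (sign, 8 exponent bits, 7 mantissa bits). The helper name `to_bfloat16` is hypothetical, and it truncates; real hardware conversions usually round to nearest even instead.

```python
import struct

def to_bfloat16(x):
    """Truncate a float to bfloat16 precision by keeping the top 16 bits of its float32 bits."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # raw float32 bit pattern
    bf16_bits = bits & 0xFFFF0000                        # keep sign, 8-bit exponent, 7 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bf16_bits))[0]

print(to_bfloat16(3.14159))  # only about 2-3 significant decimal digits survive
print(to_bfloat16(1e38))     # still finite: same exponent range as float32
```

The second call is the point of the format: a value near the top of the float32 range stays representable, whereas IEEE float16 tops out at 65504.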

Over time, computational capabilities have advanced significantly, outpacing memory access speeds. As a result, many operations are increasingly bottlenecked by the time it takes to access data from memory (HBM).

GPU Execution

GPUs operate by employing a vast number of threads to perform a specific operation, often referred to as a “kernel.” These kernels load data from HBM into registers and SRAM, process it, and then write the results back to HBM after computation.

To gauge the performance of GPU execution, we rely on a critical metric: arithmetic intensity, the number of arithmetic operations performed per byte of memory accessed. Comparing an operation’s arithmetic intensity with the hardware’s ratio of math bandwidth to memory bandwidth (both defined below) tells us where its bottleneck lies, which divides operations into two categories: compute-bound and memory-bound.

Compute-Bound Operations:

  • In compute-bound operations, the bottleneck is the computation itself. These operations are characterized by having a high number of arithmetic operations relative to memory accesses.
  • Examples include matrix multiplication with a large inner dimension and convolution with a high number of channels.

Memory-Bound Operations:

  • In memory-bound operations, the bottleneck is memory access. These operations are characterized by having a high number of memory accesses relative to computational operations.
  • Examples include most elementwise operations (activation, dropout) and reduction operations (sum, softmax, batch normalization, layer normalization).

Math bandwidth is the rate at which a processor’s math units can perform operations, usually expressed in operations per second (OPS). If the data being processed are floating-point, the unit FLOPS is more commonly used. The math bandwidth can be found in the hardware specification, for example for the NVIDIA A100 GPU and Intel CPUs.

Memory bandwidth is the rate at which a processor can read data from or store data into a semiconductor memory, usually expressed in bytes per second (B/s). It can be found in the hardware specification, as for the NVIDIA A100 GPU, or computed theoretically, as for Intel Core-X CPUs.
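Putting these definitions together: an operation is memory-bound when its arithmetic intensity is below the ratio of math bandwidth to memory bandwidth, and compute-bound when it is above it. The sketch below is a back-of-the-envelope classifier, assuming the 312 TFLOPS BF16 peak quoted above and roughly 1.5 TB/s of HBM bandwidth for the 40 GB A100. The cost formulas count only the unavoidable reads and writes and ignore caching, so treat the numbers as rough estimates.

```python
PEAK_FLOPS = 312e12       # A100 BF16 Tensor Core peak, FLOP/s (from the text above)
HBM_BYTES_PER_S = 1.5e12  # assumed A100 40GB HBM bandwidth, ~1.5 TB/s

def arithmetic_intensity(flops, bytes_moved):
    return flops / bytes_moved

def bound(flops, bytes_moved):
    # "Ridge point": FLOPs per byte at which math time equals memory time (~208 here)
    ridge = PEAK_FLOPS / HBM_BYTES_PER_S
    return "compute-bound" if arithmetic_intensity(flops, bytes_moved) > ridge else "memory-bound"

def matmul_cost(n, bytes_per_elem=2):
    # C = A @ B for n x n matrices: 2n^3 FLOPs; read A and B, write C
    return 2 * n ** 3, 3 * n * n * bytes_per_elem

def elementwise_cost(num_elems, flops_per_elem=1, bytes_per_elem=2):
    # e.g. ReLU or dropout: read each element once and write it once
    return num_elems * flops_per_elem, 2 * num_elems * bytes_per_elem

print(bound(*matmul_cost(4096)))              # large matmul
print(bound(*elementwise_cost(4096 * 4096)))  # elementwise op on the same number of elements
```

A 4096 × 4096 matmul does about 1,365 FLOPs per byte, far above the ridge, while an element-wise op does 0.25, far below it.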

For a deeper look at the math behind these performance concepts, you can refer to Lei Mao’s comprehensive blog post, which covers the topic in detail.

Kernel Fusion: Streamlining GPU Operations

Kernel fusion is a technique used to optimize GPU operations, particularly those involving memory-bound tasks. Instead of performing each operation separately, it combines multiple elementwise operations into one, making the most of data loaded from memory.

Picture it like this: in a standard process, you’d load data, do one operation, save the result, and repeat for each operation, which can be slow for large datasets. Kernel fusion changes the game by loading the data once and executing all the operations in a single step before writing the results back to memory. This significantly speeds up the process, making GPU computations more efficient.

However, keep in mind that during model training, where intermediate values must be stored for the backward pass, kernel fusion’s benefits are somewhat reduced. This is because you still need to write these intermediate values back to memory. Nevertheless, for tasks where memory-bound operations are a bottleneck, kernel fusion is a valuable tool to enhance GPU performance.
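The memory-traffic saving can be simulated in plain Python. In this hypothetical sketch, each “kernel” is a Python loop and “traffic” counts element reads plus writes to a pretend HBM. A real fused kernel would be written in CUDA or generated by a compiler, but the accounting idea is the same.

```python
def run_unfused(x, ops):
    """Each op is its own 'kernel': read the whole array, write the whole array."""
    traffic = 0  # element reads + writes to (simulated) HBM
    buf = list(x)
    for op in ops:
        out = [op(v) for v in buf]
        traffic += len(buf) + len(out)
        buf = out
    return buf, traffic

def run_fused(x, ops):
    """All ops in one 'kernel': each element is read once and written once."""
    out = []
    for v in x:
        for op in ops:  # intermediates stay in a local variable (think: registers)
            v = op(v)
        out.append(v)
    return out, len(x) + len(out)

ops = [lambda v: v * 2.0, lambda v: max(v, 0.0), lambda v: v + 1.0]  # scale, ReLU, bias
print(run_unfused([-1.0, 0.5, 2.0], ops))
print(run_fused([-1.0, 0.5, 2.0], ops))
```

Both return the same values, but with k operations the unfused version moves k times as much data.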

Background — Traditional Attention Mechanisms

If you’re familiar with transformers, you’ve likely encountered this equation before: A = softmax(QKᵀ / √h) V.

In this equation, we have sequences represented by matrices Q, K, and V, where ’n’ is the sequence length, and ‘h’ is the head dimension. The result of the attention operation, denoted as ‘A,’ can be understood step by step.

Scaled dot-product attention is at the heart of this equation. In typical attention implementations, the intermediate matrices S = QKᵀ and P = softmax(S) are n × n, so the memory they occupy in HBM grows quadratically with n. Additionally, most of the operations involved are memory-bound and element-wise: for instance, masking and softmax are applied to S, and dropout is applied to P.

This standard approach has its drawbacks, primarily in terms of memory usage and the computational time it demands. The large memory footprint and the memory-bound nature of operations can slow down the overall processing time.

Notation remark: Q — queries, K — keys, V — values, S — scores, P — probabilities, O — outputs.
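For reference, here is a minimal pure-Python sketch of the standard implementation, with masking and dropout omitted and the usual 1/√d scaling assumed (the function name is hypothetical). Note how it builds the full N × N matrices S and P before computing O; on a GPU, each of those would be written to and read back from HBM.

```python
import math

def standard_attention(Q, K, V):
    """Standard (unfused) attention: materializes S and P in full. Q, K, V are lists of rows."""
    d = len(Q[0])
    tau = 1.0 / math.sqrt(d)
    # S = tau * Q K^T: an N x N matrix (written out to HBM in a real implementation)
    S = [[tau * sum(q * k for q, k in zip(qrow, krow)) for krow in K] for qrow in Q]
    # P = softmax(S), row-wise: another N x N matrix (read S back, write P)
    P = []
    for row in S:
        mx = max(row)
        e = [math.exp(s - mx) for s in row]
        total = sum(e)
        P.append([x / total for x in e])
    # O = P V: the N x d output (read P and V back, write O)
    O = [[sum(P[i][j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))]
         for i in range(len(P))]
    return O, S, P
```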

The standard implementation ignores how the hardware actually works: it treats High Bandwidth Memory (HBM) loads and stores as if they had zero cost. In other words, it is not “IO-aware.”

Now, let’s approach this problem from the ground up, considering how we can make this implementation more efficient in terms of both time and memory.

The most straightforward improvement is to eliminate redundant HBM reads and writes. Why would we write the intermediate result S back to HBM, only to load it again for the softmax computation? Instead, we can keep it in SRAM, execute all the intermediate steps, and then write only the final result back to HBM.

This approach aligns with what compiler experts call “kernel fusion,” a fundamental low-level optimization technique in deep learning.

With the foundational context in place, let’s now dive deeper into the Flash Attention algorithm.

Unveiling FlashAttention — A Closer Look at the Algorithm

Materializing the N × N attention matrix in HBM, and repeatedly reading and writing it, is a major bottleneck. To solve this, the paper identifies two things that need to be done:

  1. Computing the softmax reduction without access to the whole input
  2. Not storing the large intermediate attention matrix for the backward pass

Two established techniques, tiling and recomputation, are used to solve this.

Tiling

(used during both forward & backward passes)

Imagine breaking down a large puzzle into smaller pieces to solve it more efficiently. Tiling does something similar for attention computation.

  • Instead of processing the entire input at once, tiling splits it into manageable blocks.
  • Incrementally, these blocks are used to perform the softmax operation, reducing the need to access the entire input in one go.
  • This restructured approach significantly cuts down on memory usage and speeds up computation.

Recomputation

(used in the backward pass only)

  • When solving complex problems, sometimes it’s faster to redo a few calculations rather than storing all intermediate results.
  • Recomputation in FlashAttention takes this approach. It stores the output and the softmax normalization statistics from the forward pass.
  • During the backward pass, instead of reading the entire intermediate attention matrix from HBM, it quickly recalculates attention on-chip using this stored factor.
  • While this may increase the number of floating-point operations (FLOPs), it leads to much faster computation and uses significantly less memory.

The result? FlashAttention not only runs faster, with the paper reporting speedups of up to 7.6× on the attention computation in GPT-2, but it also uses memory more efficiently: its memory usage grows linearly with sequence length instead of quadratically. These innovations make FlashAttention a game-changer in optimizing attention mechanisms for deep learning models, offering both speed and memory efficiency.

Core Concept of the Algorithm

Before delving into the algorithm, it’s essential to grasp the concepts of softmax, safe softmax, and online normalizer calculation for softmax.

Algorithm 1: Naive Softmax

First, let’s grasp the concept of the Naive Softmax Equation before delving into the algorithm. Consider an array, e.g., [0.2, 0.3, 0.4, 0.5], and our goal is to calculate the softmax.

To do this, we start by computing the sum of the exponentials of all the elements, which is approximately 5.712. Then, we compute the probability for each element by dividing its exponential by this sum. The key here is that the probabilities always sum to 1, ensuring a valid probability distribution.

Now, moving to the Naive Implementation (see Algorithm 1), it scans the input vector twice. First, it calculates the normalization term (dV), and then it computes the output values (yi). This process results in three memory accesses per vector element: two loads and one store.

However, there’s a challenge with this approach. On real hardware, where the range of representable numbers is limited, line 3 of Algorithm 1 can encounter overflow or underflow issues due to the exponentials. To address this problem, a modified approach known as “Safe Softmax” is introduced.
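Algorithm 1 can be sketched in Python as below (the function name is hypothetical). One caveat: Python’s `math.exp` raises `OverflowError` for large inputs instead of returning infinity, whereas float32 GPU code would typically produce inf and then NaN. Either way, the naive version breaks.

```python
import math

def naive_softmax(x):
    # Pass 1: normalization term d_V = sum_j exp(x_j)
    d = 0.0
    for v in x:
        d += math.exp(v)
    # Pass 2: y_i = exp(x_i) / d_V
    return [math.exp(v) / d for v in x]

print(naive_softmax([0.2, 0.3, 0.4, 0.5]))  # fine for small inputs
# naive_softmax([1000.0, 1001.0])  -> OverflowError: exp(1000) is not representable
```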

Algorithm 2: Safe Softmax

All major deep learning frameworks, including TensorFlow, PyTorch (with Caffe2), MXNet, Microsoft Cognitive Toolkit, and Chainer, have adopted the safe version of softmax computation.

In this example, let’s consider an array, e.g., [3, 4, 5]. To perform the Safe Softmax operation:

  1. We first identify the maximum value in the array. In our case, it’s 5 (as seen in line 3 of Algorithm 2).
  2. Next, we subtract this maximum value (5) from each element in the array. This step ensures numerical stability by preventing overflow, since the largest exponent becomes 0. So, we get [3 - 5, 4 - 5, 5 - 5], which simplifies to [-2, -1, 0] (as per line 7 in the code).
  3. Now, we calculate the exponential sum for these adjusted values. We apply the exponential function to each of them and sum the results: exp(-2) + exp(-1) + exp(0) ≈ 0.135 + 0.368 + 1 = 1.503.
  4. Finally, we compute the probability distribution. To do this, we divide each element’s exponential value by the exponential sum. This gives us the probabilities associated with each element.
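Here are the same three passes in Python, as a sketch with a hypothetical function name, using the [3, 4, 5] example. Because the largest exponent is always exp(0) = 1, inputs like [1000, 1001] that break the naive version work fine here.

```python
import math

def safe_softmax(x):
    # Pass 1: m_V = max_k x_k
    m = -math.inf
    for v in x:
        m = max(m, v)
    # Pass 2: d_V = sum_j exp(x_j - m_V)
    d = 0.0
    for v in x:
        d += math.exp(v - m)
    # Pass 3: y_i = exp(x_i - m_V) / d_V
    return [math.exp(v - m) / d for v in x]

print(safe_softmax([3.0, 4.0, 5.0]))
print(safe_softmax([1000.0, 1001.0]))  # no overflow
```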

The key takeaway here is that by subtracting the maximum value from all elements before applying the exponential function, we ensure that the computation remains stable and doesn’t encounter issues like overflow or underflow. This is a fundamental step in Safe Softmax computation.

Nonetheless, the Safe Softmax method requires three passes over the input vector: the first finds the maximum value (mV), the second computes the normalization term (dV), and the third derives the final values (yi), all outlined in Algorithm 2. This translates to a total of four memory accesses per vector element: three loads and one store.

There is a way to avoid recomputing everything, though. Suppose we have already processed [3, 4, 5], so we know the old maximum (5) and the old exponential sum, exp(-2) + exp(-1) + exp(0). Now, when you introduce a new element, 10, you don’t need to redo all those steps. Instead, you can follow these simplified actions:
  • Compute the new maximum: max(5, 10) = 10.
  • Multiply the existing exponential sum by e^(5 - 10), i.e., e^(old maximum - new maximum), so that it is expressed relative to the new maximum.
  • Add e^(10 - 10), i.e., e^(new element - new maximum), to the rescaled sum.

This approach is efficient because it reuses the prior calculations for the existing elements and only incorporates the new element’s contribution through a simple rescaling.

This intuition will help us understand the concept of FlashAttention.
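In equation form, with m and d the old maximum and exponential sum and x_new the arriving element, the update above is (worked through with the [3, 4, 5] then 10 example):

```latex
m' = \max(m,\, x_{\text{new}}), \qquad
d' = d \, e^{\,m - m'} + e^{\,x_{\text{new}} - m'}

% Example: m = 5,\; d = e^{-2} + e^{-1} + e^{0} \approx 1.5032,\; x_{\text{new}} = 10
m' = 10, \qquad d' \approx 1.5032 \cdot e^{-5} + e^{0} \approx 1.0101

% Check against computing from scratch:
e^{3-10} + e^{4-10} + e^{5-10} + e^{10-10} = e^{-7} + e^{-6} + e^{-5} + 1 \approx 1.0101
```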

Algorithm 3: Safe Softmax with Online Normalizer Calculation

In the context of online normalizer calculation, the key idea is to compute the maximum and the normalization term together in a single pass over the input vector, applying the rescaling trick above at every element. This is achieved by merging the operations from lines 6 to 8 into lines 2 to 4, resulting in a new algorithm (Algorithm 3). A second pass then computes the outputs.

By doing so, this optimization reduces memory accesses from four to three per vector element (two loads and one store) and eliminates a full pass over the data. As a result, the computation becomes faster, which is particularly advantageous for large inputs in deep learning models.

Source: Milakov and Gimelshein, “Online normalizer calculation for softmax”, https://arxiv.org/pdf/1805.02867.pdf
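Here is a Python sketch of Algorithm 3 (the function name is hypothetical). The first loop keeps a running maximum m and a running normalizer d, rescaling d whenever the maximum changes; the second loop produces the outputs. It also returns m and d so you can inspect them.

```python
import math

def online_softmax(x):
    m = -math.inf  # running maximum m_j
    d = 0.0        # running normalizer d_j
    # Pass 1: update the maximum and the normalizer together
    for v in x:
        m_new = max(m, v)
        # rescale the old sum to the new maximum, then add the new term
        # (on the first element exp(-inf) == 0.0, so the old d drops out)
        d = d * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    # Pass 2: y_i = exp(x_i - m_V) / d_V
    return [math.exp(v - m) / d for v in x], m, d

y, m, d = online_softmax([3.0, 4.0, 5.0, 10.0])
print(m, d)  # m = 10.0, d is about 1.0101, matching the worked update
```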

So let’s start with the FlashAttention algorithm.

The algorithm’s main idea is to divide the input matrices Q, K, and V into blocks small enough to load from the relatively slow High Bandwidth Memory (HBM) into the much faster on-chip Static Random-Access Memory (SRAM).

Once the blocks are in SRAM, the algorithm computes the attention output for each block. Crucially, each block’s partial output is rescaled with the correct normalization factor before the partial outputs are added together. As a result, the final output is exactly equal to standard attention, not an approximation.

Example Scenario:

This is the core idea behind softmax tiling. By recursively repeating this computation across all of the blocks we end up with the correct softmax output.

The equations may seem alien at first, but don’t worry :) we’ll break them down into simpler terms for better understanding.
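For reference, here are the block-decomposition equations as given in the FlashAttention paper, written out in LaTeX since the original figure is not reproduced here. In the notation of the walk-through that follows, m1 = m(x^(1)), l1 = ℓ(x^(1)), and so on.

```latex
% For a single block x \in \mathbb{R}^{B}:
m(x) = \max_i x_i, \qquad
f(x) = \left[e^{x_1 - m(x)}, \dots, e^{x_B - m(x)}\right], \qquad
\ell(x) = \sum_i f(x)_i, \qquad
\mathrm{softmax}(x) = \frac{f(x)}{\ell(x)}

% For two blocks x = [x^{(1)}, x^{(2)}] \in \mathbb{R}^{2B}:
m(x) = \max\left(m(x^{(1)}),\, m(x^{(2)})\right)
f(x) = \left[e^{m(x^{(1)}) - m(x)} f(x^{(1)}),\; e^{m(x^{(2)}) - m(x)} f(x^{(2)})\right]
\ell(x) = e^{m(x^{(1)}) - m(x)}\, \ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\, \ell(x^{(2)})
\mathrm{softmax}(x) = \frac{f(x)}{\ell(x)}
```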

I’ve used a 1D vector to explain, but the same logic applies row by row to a 2D matrix. Let’s consider a sample array with 6 elements: [1, 2, 3, 4, 5, 6]. Now, divide it into 2 blocks, x1 = [1, 2, 3] and x2 = [4, 5, 6]. We apply the safe softmax steps separately to each block, which gives each block its own maximum and exponential sum: m1 = 3 and l1 for x1, m2 = 6 and l2 for x2. Then we find the new maximum, m_x = max(m1, m2) = 6, and compute l_x = e^(m1 - m_x) · l1 + e^(m2 - m_x) · l2 using the rescaling trick discussed earlier. Finally, we merge the blocks to obtain the final output, O_x, by rescaling each block’s exponentials to m_x and dividing by l_x, using the same logic we used to find probabilities in the previous example.
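The two-block example can be checked in a few lines of Python. This is a sketch with hypothetical helper names; it merges the per-block statistics exactly as described and should match a softmax computed over the whole array at once.

```python
import math

def block_stats(xb):
    """Safe-softmax statistics for one block: max m, shifted exponentials f, sum l."""
    m = max(xb)
    f = [math.exp(v - m) for v in xb]
    return m, f, sum(f)

def merge_blocks(x1, x2):
    m1, f1, l1 = block_stats(x1)
    m2, f2, l2 = block_stats(x2)
    m_x = max(m1, m2)                                  # new global maximum
    a1, a2 = math.exp(m1 - m_x), math.exp(m2 - m_x)    # per-block rescaling factors
    l_x = a1 * l1 + a2 * l2                            # combined exponential sum
    f_x = [a1 * v for v in f1] + [a2 * v for v in f2]  # exponentials relative to m_x
    return [v / l_x for v in f_x]                      # O_x

print(merge_blocks([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))
```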

Great! 😄 We’ve finally grasped the fundamentals of Flash Attention! 🚀

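The attention computation being implemented, including masking and dropout, is S = τQKᵀ, S_masked = MASK(S), P = softmax(S_masked), P_dropped = dropout(P, p_drop), and O = P_dropped V, where τ ∈ ℝ is the softmax scaling factor (typically 1/√d), MASK is a masking function that sets some entries of its input to −∞ and keeps the other entries unchanged, and dropout(x, p) applies dropout to x element-wise (i.e., outputs x/(1−p) with probability 1−p and 0 with probability p for each element x).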

Let’s dive into the Flash Attention algorithm, step by step. 📝💡

Step 0: Memory Allocation

  • Q, K, and V (each N × d) reside in HBM (High Bandwidth Memory), whose capacity is measured in gigabytes (GBs).
  • Storing Q, K, and V is therefore not an issue as long as they fit within the HBM capacity.

Step 1: Define Block Sizes (line-2 algo)

  • Determine the column block size B_c and the row block size B_r.
  • Per the paper, B_c = ⌈M / (4d)⌉ and B_r = min(⌈M / (4d)⌉, d), where M is the on-chip SRAM size and d is the dimensionality of the query, key, and value vectors.
  • The factor of 4 leaves room for the four blocks that must sit in SRAM at once (Q_i, K_j, V_j, O_i), so this choice makes the most of SRAM without overflowing it.
  • Example: suppose M = 800 and d = 4. Then B_c = ⌈800 / (4 × 4)⌉ = 50 and B_r = min(50, 4) = 4. So each iteration works with a block of 50 key and value vectors and a block of 4 query and output vectors. This approach helps minimize the number of reads and writes between HBM and SRAM, making the process more efficient.
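The block-size rule can be written as a tiny helper (names hypothetical). M is treated as the number of elements that fit in SRAM, which is an assumption about units; the number of blocks per matrix follows from the sequence length N.

```python
import math

def block_sizes(M, d):
    """Return (B_r, B_c) for SRAM size M (in elements) and head dimension d."""
    Bc = math.ceil(M / (4 * d))
    Br = min(math.ceil(M / (4 * d)), d)
    return Br, Bc

def num_blocks(N, Br, Bc):
    """Return (T_r, T_c): number of row blocks of Q and column blocks of K/V."""
    return math.ceil(N / Br), math.ceil(N / Bc)

Br, Bc = block_sizes(800, 4)
print(Br, Bc)                     # the example above: 4 and 50
print(num_blocks(1000, Br, Bc))   # hypothetical N = 1000
```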

It’s a good idea to remember this image (it will become clearer shortly):

Step 2: Initialize Matrices (line-3 algo)

  • Initialize the output matrix O (N × d) to all zeros, as it will accumulate results.
  • Initialize l (length N) to zeros; it holds the cumulative softmax denominator (sum of exponentiated scores) for each row.
  • Initialize m (length N) to negative infinity; it stores the running maximum score for each row, so it needs a starting value lower than any real score.

Step 3: Split Vectors (line-4 algo)

  • Split Q into T_r = ⌈N / B_r⌉ blocks Q_1, ..., Q_{T_r} of size B_r × d, and split K and V into T_c = ⌈N / B_c⌉ blocks of size B_c × d, using the block sizes determined in Step 1.

Step 4: Split Matrices (line-5 algo)

  • Similarly, split O into T_r blocks O_i of size B_r × d, and split l and m into T_r blocks l_i and m_i of size B_r, matching the row blocks of Q.

Step 5: Loop Over Columns (line-6 algo)

  • Begin looping over columns, which represent key and value vectors.

Step 6: Load Key and Value Blocks (line-7 algo)

  • Load the K_j and V_j blocks (B_c × d each) from HBM to SRAM. Because this is the outer loop, each K_j and V_j block is read from HBM only once.

Step 7: Loop Over Rows (line-8 algo)

  • Start an inner loop over rows, representing query vectors.

Step 8: Load Query and Output Blocks (line-9 algo)

  • We load Q_i (B_r x d) and O_i (B_r x d), as well as l_i (B_r) and m_i (B_r), into SRAM. You might wonder how l_i and m_i fit into SRAM when the block sizes were chosen to accommodate K_j, V_j, Q_i, and O_i. They are small (B_r values each), and it’s likely that they are kept in registers, which are also part of the GPU memory hierarchy. Please note that this is a simplification; someone with practical CUDA implementation experience might provide more precise details.

Step 9: Compute Scores (line-10 algo)

  • Calculate the dot product between Q_i (B_r x d) and the transpose of K_j (d x B_c) to obtain the scores (B_r x B_c). It’s important to note that we don’t calculate the entire NxN S (scores) matrix. We only compute a part of it, denoted as S_i_j.
  • To illustrate, let’s consider a hypothetical scenario: assume the outer loop index is j (j=3), the inner loop index is i (i=2), N represents 25 tokens in our sequence, and the block size is 5 (assuming 1-based indexing):
  • In this case, we’ve essentially computed the attention scores for tokens 6–10 with respect to tokens 11–15 in our input sequence. These scores are exact and will remain constant throughout (unlike the softmax results, which evolve during the process).
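The index arithmetic behind this example, as a small sketch (1-based indexing, hypothetical helper name):

```python
def token_range(block_idx, block_size):
    """Map a 1-based block index to its inclusive, 1-based token range."""
    start = (block_idx - 1) * block_size + 1
    return start, start + block_size - 1

print(token_range(2, 5))  # query block Q_i with i = 2
print(token_range(3, 5))  # key block K_j with j = 3
```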

Step 10: Compute Intermediate Values (line-11 algo)

  • Compute m~_i_j, l~_i_j, and P~_i_j using the scores computed in the previous step.
  • m~_i_j finds the row-wise maximum for each row in the current block.
  • P~_i_j is computed element-wise by subtracting m~_i_j from each row of S_i_j and exponentiating: P~_i_j = exp(S_i_j - m~_i_j). It is not normalized yet.
  • l~_i_j represents the row-wise sum of the P~_i_j matrix.

Step 11: Update Maximum and Normalization Terms (line-12 algo)

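  • Combine the running statistics with the current block’s statistics: m_i_new = max(m_i, m~_i_j) and l_i_new = e^(m_i - m_i_new) · l_i + e^(m~_i_j - m_i_new) · l~_i_j. This is the same rescaling trick used in the online softmax.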

Step 12: Attention Computation (line-13–15 algo)

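  • Update the output block and write it to HBM: O_i ← diag(l_i_new)^-1 · (diag(l_i) · e^(m_i - m_i_new) · O_i + e^(m~_i_j - m_i_new) · P~_i_j · V_j).
  • The first term rescales the output accumulated so far to the new maximum and undoes its old normalization; the second term adds the contribution of the current K_j, V_j block.
  • After the last K_j, V_j block has been processed, O_i equals exact softmax attention for its rows, computed without ever materializing the full N × N matrix.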

Step 13: Write Statistics to Memory (line-16 algo)

  • Write the newly computed cumulative statistics (l_i and m_i) back to HBM.
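Putting Steps 0 to 13 together, here is a minimal pure-Python sketch of the forward pass, assuming no masking and no dropout and using τ = 1/√d (function names are hypothetical). It is for understanding only: “loading into SRAM” is just list slicing, the rows of each Q_i block are processed one at a time (they are independent, so this gives the same result as the matrix form), and a real implementation would be a fused CUDA kernel. A plain reference implementation is included so you can check that the tiled result is exact.

```python
import math

def flash_attention_forward(Q, K, V, M):
    """Tiled attention forward pass (no mask, no dropout). M: SRAM size in elements (assumed)."""
    N, d, dv = len(Q), len(Q[0]), len(V[0])
    tau = 1.0 / math.sqrt(d)                        # softmax scaling factor
    Bc = math.ceil(M / (4 * d))                     # Step 1: block sizes
    Br = min(math.ceil(M / (4 * d)), d)
    O = [[0.0] * dv for _ in range(N)]              # Step 2: O = 0, l = 0, m = -inf
    l = [0.0] * N
    m = [-math.inf] * N
    for j0 in range(0, len(K), Bc):                 # Step 5: outer loop over K/V blocks
        Kj, Vj = K[j0:j0 + Bc], V[j0:j0 + Bc]      # Step 6: "load" K_j, V_j
        for i0 in range(0, N, Br):                  # Step 7: inner loop over Q blocks
            for g in range(i0, min(i0 + Br, N)):    # rows of Q_i, O_i, l_i, m_i
                # Step 9: one row of S_ij = tau * Q_i K_j^T
                s = [tau * sum(q * k for q, k in zip(Q[g], krow)) for krow in Kj]
                # Step 10: block statistics m~, P~ (unnormalized), l~
                m_t = max(s)
                p_t = [math.exp(v - m_t) for v in s]
                l_t = sum(p_t)
                # Step 11: new running max and denominator
                m_new = max(m[g], m_t)
                a = math.exp(m[g] - m_new)          # rescales old stats (0.0 on first block)
                b = math.exp(m_t - m_new)           # rescales the current block
                l_new = a * l[g] + b * l_t
                # Step 12: O_i <- (l_i * a * O_i + b * P~_ij V_j) / l_new
                pv = [sum(p * vrow[c] for p, vrow in zip(p_t, Vj)) for c in range(dv)]
                O[g] = [(l[g] * a * O[g][c] + b * pv[c]) / l_new for c in range(dv)]
                # Step 13: store the updated statistics
                l[g], m[g] = l_new, m_new
    return O

def reference_attention(Q, K, V):
    """Plain softmax(tau * Q K^T) V, for checking the tiled version."""
    tau = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [tau * sum(a * b for a, b in zip(q, k)) for k in K]
        mx = max(s)
        e = [math.exp(v - mx) for v in s]
        t = sum(e)
        out.append([sum(e[j] / t * V[j][c] for j in range(len(V))) for c in range(len(V[0]))])
    return out
```

For example, with N = 10, d = 4, and M = 48 (so B_r = B_c = 3 and the last blocks are ragged), `flash_attention_forward` and `reference_attention` agree to within floating-point rounding.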

Complexity Analysis

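  • Extra memory beyond the inputs and output is O(N) (just the statistics l and m), instead of the O(N²) needed to store S and P, which is crucial for handling long sequences.
  • The number of HBM accesses drops from Θ(Nd + N²) for standard attention to Θ(N²d²/M) for FlashAttention, where M is the SRAM size. For typical values of d (64-128) and M (around 100 KB), d² is many times smaller than M, so FlashAttention needs far fewer HBM accesses, which makes the computation faster.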
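To get a feel for these bounds, the sketch below plugs numbers into the two expressions (function names hypothetical). Constants are dropped because these are Θ bounds, and M = 100,000 is treated as an element count; both are assumptions, so only the rough ratio is meaningful.

```python
def hbm_accesses_standard(N, d):
    """Theta(N*d + N^2) with constants dropped."""
    return N * d + N * N

def hbm_accesses_flash(N, d, M):
    """Theta(N^2 * d^2 / M) with constants dropped; M = SRAM size in elements (assumed)."""
    return N * N * d * d / M

N, d, M = 4096, 64, 100_000
ratio = hbm_accesses_standard(N, d) / hbm_accesses_flash(N, d, M)
print(f"standard / flash ~ {ratio:.1f}x")  # roughly 25x fewer accesses in this setting
```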

Further resources

In conclusion, we’ve come to the end of this blog. We’re eager to share more informative content with you in the future, with a commitment to delivering straightforward and comprehensible explanations. If you’d like to stay connected and continue the conversation, feel free to connect with me on LinkedIn. Until next time, stay curious and keep learning!
