Vicente Rodríguez

July 3, 2025

Triton library, GPUs, and Flash Attention

This blog post gives a brief overview of some terms related to GPU processing: specifically, the Triton library for writing kernels, how GPUs work under the hood, and how the Flash Attention implementation takes advantage of the GPU design to run more efficiently than the vanilla attention implementation.

Important concepts

Memory

GPUs have two main types of memory: a large-capacity memory called HBM (high bandwidth memory), around 16GB, 32GB or even 100GB, and a much smaller memory called SRAM, around 20MB in total, which has far higher throughput. Computations operate on data held on-chip, so kernels/programs need to load data or tensors from HBM into SRAM before working on them. Moving data between HBM and SRAM is expensive, so much so that in some cases it is more efficient to re-compute operations than to transfer data between the two memories.

SM

In most GPUs, SRAM is divided into blocks, and each block is assigned to a group of cores called a streaming multiprocessor (SM). This design allows parallel computation, where the cores in an SM can only access their own SRAM block.

Block Size

The block size determines the number of elements each kernel instance is assigned.

Tiling

Parallelization is the core strength of GPUs. Tiling puts this idea into practice: a tensor, let's say NxD (where N is the number of elements), is split into B groups, each handled by a separate kernel instance (B = N/BLOCK_SIZE). This approach also helps fit large tensors within the limited SRAM capacity.

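The tensor size does not have to be a multiple of BLOCK_SIZE, and BLOCK_SIZE can even be larger than the input tensor. In these cases, it is important to apply a mask that marks where the input tensor ends, so the kernel does not read or write past it. This comes up often in Triton, because tl.arange only accepts power-of-2 ranges, so BLOCK_SIZE rarely matches the tensor size exactly.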

Grid

The grid determines how many kernel instances, often called programs or blocks, are launched in parallel (tiling). The grid size is calculated as ceil(M/BLOCK_SIZE), where M is the size of the tensor. If M=256 and BLOCK_SIZE=32, 8 kernel instances are launched, each loading BLOCK_SIZE elements, in this case 32.
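The grid arithmetic can be checked with a few lines of plain Python. `cdiv` and `block_ranges` below are hypothetical helpers written for illustration, assuming `cdiv` behaves like `triton.cdiv` (ceiling division). The second case shows why the last instance may own fewer valid elements than BLOCK_SIZE.

```python
def cdiv(m: int, block_size: int) -> int:
    """Ceiling division: number of blocks of size block_size needed to cover m elements."""
    return (m + block_size - 1) // block_size


def block_ranges(m: int, block_size: int) -> list[tuple[int, int]]:
    """Return the [start, end) range of valid elements handled by each kernel instance."""
    ranges = []
    for pid in range(cdiv(m, block_size)):
        start = pid * block_size
        end = min(start + block_size, m)  # the last block may be only partially filled
        ranges.append((start, end))
    return ranges


print(cdiv(256, 32))              # 8 instances, every block full
print(cdiv(250, 32))              # still 8 instances
print(block_ranges(250, 32)[-1])  # (224, 250): the last block has only 26 valid elements
```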

A multidimensional grid can be defined to handle higher-dimensional tensors (batch, sequence, channels, etc.).

A kernel instance can be understood as one execution of a function. All kernel instances run the same function, each on a different part of the matrix.

Moreover, inside a kernel instance (function execution), the block of assigned elements (BLOCK_SIZE) can be further subdivided to perform smaller and faster computations. For instance, in a matrix multiplication, a matrix of shape 256x256 can be loaded in 16 groups of 16 rows (16x256 elements each), with each group computed and accumulated until all 256 rows are processed.

For this reason, it is common to find for loops inside kernel code. By default, a for loop runs sequentially, loading group by group until all BLOCK_SIZE elements are computed. When the iterations are independent, they can also be overlapped or run in parallel (Triton's tl.range, for example, accepts a num_stages argument to pipeline loop iterations). The row-wise computation of the softmax activation is a good example, where each iteration of the for loop computes an independent row of the input matrix.
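A plain-Python sketch of the row-wise softmax pattern, assuming a hypothetical setup where `num_programs` kernel instances each take every `num_programs`-th row, similar in spirit to the loop in Triton's fused-softmax tutorial. The function names are illustrative, not Triton API. Because each row is normalized independently, the assignment of rows to programs does not change the result.

```python
import math


def softmax_row(row: list[float]) -> list[float]:
    """Numerically stable softmax of a single row (subtract the max before exp)."""
    row_max = max(row)
    exps = [math.exp(x - row_max) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def program(pid: int, num_programs: int, matrix: list[list[float]], out: list) -> None:
    """One hypothetical kernel instance: processes rows pid, pid + num_programs, ..."""
    for row_idx in range(pid, len(matrix), num_programs):
        # Each iteration touches a different, independent row.
        out[row_idx] = softmax_row(matrix[row_idx])


def rowwise_softmax(matrix: list[list[float]], num_programs: int = 2) -> list:
    out = [None] * len(matrix)
    # On a GPU these instances would run concurrently; here they run one after another.
    for pid in range(num_programs):
        program(pid, num_programs, matrix, out)
    return out


m = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0], [5.0, 1.0, -1.0]]
result = rowwise_softmax(m, num_programs=2)
print([round(sum(r), 6) for r in result])  # every row sums to 1
```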

Program ID

A unique program ID is assigned to each kernel instance, and it is used to determine which chunk of data to load:

pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)

Triton Example

Now, let’s see an example of Triton code, a kernel that adds two vectors:

import triton
import triton.language as tl

@triton.jit
def add_kernel(
        x_pointer,
        y_pointer,
        output_pointer,
        vector_size,
        BLOCK_SIZE: tl.constexpr,  # compile-time constant; tl.arange needs a power of 2
):
    pid = tl.program_id(axis=0) # Used to obtain the id of the current program/kernel
    # where a program is the execution of one add_kernel function

    # for a vector of size 256 and a BLOCK_SIZE of 64, each program accesses:
    # [0:64] [64:128] [128:192] [192:256]

    block_start = pid * BLOCK_SIZE # offset of the first element of this block
    # offsets contains the positions of all the elements
    # handled by the current kernel/program execution
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # A mask is employed to mask out extra positions beyond vector_size
    mask = offsets < vector_size

    # load the vectors x, y from memory
    x = tl.load(x_pointer + offsets, mask=mask)
    y = tl.load(y_pointer + offsets, mask=mask)

    output = x + y

    # It is our job to write the result of the operation in memory
    # using the output pointers
    tl.store(output_pointer + offsets, output, mask=mask)

Kernel functions take pointers as input instead of complete tensors (a PyTorch tensor passed to a Triton kernel is converted into a pointer to its first element). Each kernel instance then uses its program ID as the base to load its own block of the tensors.

In order to call the previous function, a grid must be defined:

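import torch
import triton

n_elements = 992
BLOCK_SIZE = 128  # example value; 992 is not a multiple of it, so the mask matters
x = torch.rand(n_elements, device="cuda")  # PyTorch tensors must live on the GPU
y = torch.rand(n_elements, device="cuda")
output = torch.empty_like(x)  # buffer the kernel writes into

# One kernel instance per block; the grid must be a tuple (note the trailing comma)
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

add_kernel[grid](
    x, y, output,
    n_elements, BLOCK_SIZE=BLOCK_SIZE
)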

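Where the grid size is ceil(M/BLOCK_SIZE), computed with triton.cdiv, and M is the number of elements (n_elements). Triton expects the grid as a tuple with one entry per grid dimension.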
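To make the grid and mask semantics concrete without a GPU, the launch can be emulated sequentially in plain Python. This is a hypothetical sketch: `add_kernel_emulated` and `launch` are illustrative names, flat Python lists stand in for GPU memory, and integer offsets stand in for pointers. On a real GPU the instances run in parallel and in no guaranteed order.

```python
import random


def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def add_kernel_emulated(pid, x_mem, y_mem, out_mem, vector_size, BLOCK_SIZE):
    """Sequential stand-in for one add_kernel instance."""
    block_start = pid * BLOCK_SIZE
    offsets = [block_start + i for i in range(BLOCK_SIZE)]  # like tl.arange(0, BLOCK_SIZE)
    mask = [off < vector_size for off in offsets]           # like offsets < vector_size
    for off, valid in zip(offsets, mask):
        if valid:  # masked-out lanes neither load nor store
            out_mem[off] = x_mem[off] + y_mem[off]


def launch(grid, *args):
    """Run one emulated instance per program ID of a 1D grid."""
    for pid in range(grid[0]):
        add_kernel_emulated(pid, *args)


n_elements, BLOCK_SIZE = 992, 128
x = [random.random() for _ in range(n_elements)]
y = [random.random() for _ in range(n_elements)]
output = [0.0] * n_elements
grid = (cdiv(n_elements, BLOCK_SIZE),)  # 8 instances; the last covers offsets 896..1023
launch(grid, x, y, output, n_elements, BLOCK_SIZE)
print(grid, all(o == a + b for o, a, b in zip(output, x, y)))
```

Only offsets 896 to 991 of the last instance pass the mask; the remaining 32 lanes are skipped, which is exactly what prevents out-of-bounds accesses in the real kernel.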

Multidimensional Grids

As explained, a multidimensional grid is often needed to better represent the input data. The previous code used a 1D grid, matching the shape of the inputs, one-dimensional vectors.

A more complex situation is a matrix multiplication between 2D matrices of shape AxD and DxB.

In this case, a 2D grid can be used to handle blocks of rows and blocks of columns of the input matrices individually and, more importantly, in parallel:

pid_A = tl.program_id(axis=0) # what row group are we handling?
pid_B = tl.program_id(axis=1) # what column group are we handling?

offsets_A = pid_A * BLOCK_SIZE_A + tl.arange(0, BLOCK_SIZE_A)
offsets_B = pid_B * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B)
grid = (triton.cdiv(A, BLOCK_SIZE_A), triton.cdiv(B, BLOCK_SIZE_B))

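In this way, each kernel instance (pid_A, pid_B) loads a block of BLOCK_SIZE_A rows from AxD and a block of BLOCK_SIZE_B columns from DxB, which together produce one BLOCK_SIZE_A x BLOCK_SIZE_B tile of the output.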
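A small plain-Python sketch of the 2D grid, assuming hypothetical sizes A=8, B=12, BLOCK_SIZE_A=4 and BLOCK_SIZE_B=8. For each pair of program IDs it lists the rows of AxD and the columns of DxB that the instance reads, which together determine the output tile it owns.

```python
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def tiles_2d(A: int, B: int, BLOCK_SIZE_A: int, BLOCK_SIZE_B: int) -> dict:
    """Map each (pid_A, pid_B) to the row and column indices of its output tile."""
    grid = (cdiv(A, BLOCK_SIZE_A), cdiv(B, BLOCK_SIZE_B))
    tiles = {}
    for pid_A in range(grid[0]):        # tl.program_id(axis=0)
        for pid_B in range(grid[1]):    # tl.program_id(axis=1)
            rows = [pid_A * BLOCK_SIZE_A + i for i in range(BLOCK_SIZE_A)]
            cols = [pid_B * BLOCK_SIZE_B + j for j in range(BLOCK_SIZE_B)]
            # Drop indices past the matrix edges (what a mask does inside the kernel).
            tiles[(pid_A, pid_B)] = ([r for r in rows if r < A], [c for c in cols if c < B])
    return tiles


tiles = tiles_2d(A=8, B=12, BLOCK_SIZE_A=4, BLOCK_SIZE_B=8)
for pid, (rows, cols) in tiles.items():
    print(pid, "rows", rows[0], "-", rows[-1], "cols", cols[0], "-", cols[-1])
```

Every output element belongs to exactly one tile, which is why the instances can run in parallel without coordinating.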

Stride

To correctly load the values of a matrix or tensor from memory, stride values are needed. The stride of a dimension indicates how many elements must be skipped in memory to move one position along that dimension (for example, from one row to the next).

The previous code lacks a fundamental part: the number of elements to load per row of AxD and per column of DxB:

offsets_D = tl.arange(0, BLOCK_SIZE_D)

BLOCK_SIZE_D can also be interpreted as the number of columns loaded from AxD (that is, the elements per row) and the number of rows loaded from DxB (that is, the elements per column).

Now, going back to the stride values, these are employed as follows:

stride_a_A = tensor_a.stride(0)
stride_a_D = tensor_a.stride(1)

stride_b_D = tensor_b.stride(0)
stride_b_B = tensor_b.stride(1)

For contiguous tensors of shape a=(8, 4) and b=(4, 12):

stride_a_A = 4, stride_a_D = 1. Meaning: to go from row 0 to row 1, 4 elements need to be skipped.

stride_b_D = 12, stride_b_B = 1. Meaning: to go from row 0 to row 1, 12 elements need to be skipped.

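For contiguous tensors, the stride of the last dimension is 1, meaning that consecutive elements along that dimension sit next to each other in memory. This is not guaranteed in general: a transposed view, for instance, has a last-dimension stride different from 1, which is one reason to pass the strides to the kernel instead of assuming them.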
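For a contiguous (row-major) tensor, the strides follow directly from the shape: each dimension's stride is the product of the sizes of all the dimensions after it. The helpers below are hypothetical illustrations of that rule (PyTorch reports the same numbers through `tensor.stride()`); the transposed case shows how a view changes the strides without moving any data.

```python
from math import prod


def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Row-major strides: stride[i] = product of shape[i+1:] (empty product is 1)."""
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))


def transposed_strides(strides: tuple[int, int]) -> tuple[int, int]:
    """A 2D transpose view just swaps the strides; the data stays where it was."""
    return (strides[1], strides[0])


print(contiguous_strides((8, 4)))    # (4, 1)
print(contiguous_strides((4, 12)))   # (12, 1)
print(transposed_strides(contiguous_strides((8, 4))))  # (1, 4): last-dim stride is no longer 1
```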

Therefore, to obtain the memory locations of the elements to load, one computes:

offsets_a = offsets_A[:, None] * stride_a_A + offsets_D[None, :] * stride_a_D
offsets_b = offsets_D[:, None] * stride_b_D + offsets_B[None, :] * stride_b_B

Where:

# [:, None] transforms a 1D tensor: [1, 2, 3] into
# [[1], [2], [3]]
# [None, :] transforms a 1D tensor: [1, 2, 3] into
# [[1, 2, 3]]

Multiplying offsets_A by stride_a_A converts row indices into memory positions: each row starts stride_a_A elements after the previous one. If:

offsets_A = tensor([0, 1, 2, 3])

By doing:

offsets_A[:, None] * stride_a_A

One obtains:

tensor([[ 0],
        [ 4],
        [ 8],
        [12]])

These are still 4 elements, but now they point to the memory location where each row starts. Then, by adding:

previous_result + offsets_D[None, :] * stride_a_D

This adds the desired number of elements (columns) to each row. If BLOCK_SIZE_D is 2:

tensor([[ 0,  1],
        [ 4,  5],
        [ 8,  9],
        [12, 13]])

In total 8 elements are loaded in one group or block, now with the correct memory locations.
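The same broadcasting arithmetic can be reproduced in plain Python to double-check the worked example. `block_offsets` is a hypothetical helper that computes offsets_A[:, None] * stride_a_A + offsets_D[None, :] * stride_a_D with nested loops instead of tensor broadcasting, using the (8, 4) matrix from above with BLOCK_SIZE_D = 2.

```python
def block_offsets(offsets_rows, offsets_cols, stride_row, stride_col):
    """Nested-loop equivalent of rows[:, None] * stride_row + cols[None, :] * stride_col."""
    return [[r * stride_row + c * stride_col for c in offsets_cols] for r in offsets_rows]


offsets_A = [0, 1, 2, 3]   # first 4 rows of an (8, 4) matrix
offsets_D = [0, 1]         # BLOCK_SIZE_D = 2 columns
stride_a_A, stride_a_D = 4, 1

offsets_a = block_offsets(offsets_A, offsets_D, stride_a_A, stride_a_D)
print(offsets_a)  # [[0, 1], [4, 5], [8, 9], [12, 13]]

# Flat row-major buffer where each value encodes its position as 10 * row + col.
flat_a = [10 * r + c for r in range(8) for c in range(4)]
print([[flat_a[o] for o in row] for row in offsets_a])  # [[0, 1], [10, 11], [20, 21], [30, 31]]
```

Reading the flat buffer at those offsets returns rows 0 to 3 and columns 0 to 1, confirming that the offsets address the intended 4x2 block.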

A Jupyter notebook can be found in this Link to better understand all these operations.

The grid is defined as (triton.cdiv(A, BLOCK_SIZE_A), triton.cdiv(B, BLOCK_SIZE_B)). Inside the kernel, BLOCK_SIZE_D is used in a similar way to iterate through groups of columns of AxD and groups of rows of DxB:

for d in range(0, tl.cdiv(D, BLOCK_SIZE_D)):
    # Load the d-th group of columns from *AxD* (using offsets_a)
    # Load the d-th group of rows from *DxB* (using offsets_b)
    # and accumulate their product
    ...

    # Move the pointers to the next group along the D dimension
    offsets_a += BLOCK_SIZE_D * stride_a_D
    offsets_b += BLOCK_SIZE_D * stride_b_D
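Putting the pieces together, the sketch below emulates a tiled matrix multiplication in plain Python, assuming row-major flat buffers for AxD and DxB. Each (pid_A, pid_B) pair plays the role of one kernel instance, the inner loop walks the D dimension BLOCK_SIZE_D elements at a time while advancing the base offsets by BLOCK_SIZE_D * stride, and out-of-range indices are masked. `matmul_tile` and `tiled_matmul` are hypothetical names, and the loops only mimic what tl.load and tl.dot would do on the GPU.

```python
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def matmul_tile(pid_A, pid_B, a, b, c, dims, strides, blocks):
    """One emulated kernel instance: computes the (pid_A, pid_B) tile of c = a @ b."""
    A, D, B = dims
    stride_a_A, stride_a_D, stride_b_D, stride_b_B = strides
    BLOCK_SIZE_A, BLOCK_SIZE_B, BLOCK_SIZE_D = blocks
    offsets_A = [pid_A * BLOCK_SIZE_A + i for i in range(BLOCK_SIZE_A)]
    offsets_B = [pid_B * BLOCK_SIZE_B + j for j in range(BLOCK_SIZE_B)]
    acc = [[0.0] * BLOCK_SIZE_B for _ in range(BLOCK_SIZE_A)]  # stays "on chip"
    a_base = b_base = 0  # advanced by BLOCK_SIZE_D * stride after each group
    for d in range(cdiv(D, BLOCK_SIZE_D)):
        for i, row in enumerate(offsets_A):
            for j, col in enumerate(offsets_B):
                for k in range(BLOCK_SIZE_D):
                    # Mask: skip rows, columns or D positions past the matrix edges.
                    if row < A and col < B and d * BLOCK_SIZE_D + k < D:
                        a_val = a[a_base + row * stride_a_A + k * stride_a_D]
                        b_val = b[b_base + k * stride_b_D + col * stride_b_B]
                        acc[i][j] += a_val * b_val
        a_base += BLOCK_SIZE_D * stride_a_D  # next group of columns of AxD
        b_base += BLOCK_SIZE_D * stride_b_D  # next group of rows of DxB
    for i, row in enumerate(offsets_A):
        for j, col in enumerate(offsets_B):
            if row < A and col < B:
                c[row * B + col] = acc[i][j]  # store the tile (row-major output)


def tiled_matmul(a, b, A, D, B, blocks=(4, 4, 2)):
    c = [0.0] * (A * B)
    strides = (D, 1, B, 1)  # contiguous (row-major) AxD and DxB
    grid = (cdiv(A, blocks[0]), cdiv(B, blocks[1]))
    for pid_A in range(grid[0]):        # on a GPU these instances run in parallel
        for pid_B in range(grid[1]):
            matmul_tile(pid_A, pid_B, a, b, c, (A, D, B), strides, blocks)
    return c


A, D, B = 5, 3, 6  # deliberately not multiples of the block sizes
a = [float(i) for i in range(A * D)]       # flat row-major AxD
b = [float(i % 4) for i in range(D * B)]   # flat row-major DxB
c = tiled_matmul(a, b, A, D, B)
```

The accumulator lives inside one instance for the whole D loop and is written out once at the end, which mirrors how a real kernel keeps partial sums in fast memory and touches HBM only for the final store.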

Flash Attention and Block-Wise Softmax

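The traditional attention implementation requires multiple loads and writes to HBM:

1. Load Q and K from HBM, compute S = QK^T, and write S to HBM.
2. Read S from HBM, compute P = softmax(S), and write P to HBM.
3. Load P and V from HBM, compute O = PV, and write O to HBM.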

The NxN tensors S and P are usually too large to be kept in SRAM at the same time, which leads to several accesses to HBM.
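As a reference point, the vanilla algorithm fits in a few lines of plain Python (a sketch using lists, not a GPU implementation). What matters is the shape of the intermediates: S and P are full N x M matrices that, on a GPU, would each be written to and read back from HBM. The 1/sqrt(d) scaling is the usual scaled dot-product convention and is an assumption here, since the text does not mention it.

```python
import math


def vanilla_attention(Q, K, V):
    """O = softmax(Q K^T / sqrt(d)) V, materializing the full S and P matrices."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # Step 1: S = Q K^T (N x M); on a GPU this is written to HBM.
    S = [[scale * sum(q[t] * k[t] for t in range(d)) for k in K] for q in Q]
    # Step 2: P = softmax(S) row by row (N x M); S is read back and P is written.
    P = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])
    # Step 3: O = P V (N x d_v); P and V are read back and O is written.
    O = [[sum(p[j] * V[j][t] for j in range(len(V))) for t in range(len(V[0]))] for p in P]
    return O, S, P


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
O, S, P = vanilla_attention(Q, K, V)
print(len(S), "x", len(S[0]))  # 3 x 2: the score matrix is N x M
```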

Flash Attention addresses this issue with tiling: the attention tensors Q, K and V are divided into smaller blocks that fit into SRAM. In addition, the softmax calculation is fused with the computation of S to further reduce the number of reads and writes to HBM.

The softmax activation needs a normalization factor computed by summing the exponentials of all the elements in a row, so in principle the entire row must be known. Since Flash Attention processes tensors block by block, it computes this normalization factor incrementally (online softmax): a running state, made of the running maximum and the running sum of exponentials, is updated after each block, rescaling the previous state to account for the new values.
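The incremental normalization can be sketched in plain Python. This version follows the usual online-softmax formulation: a running maximum `m` plus a running sum of exponentials `l` that is rescaled whenever the maximum grows. The block size and function name are illustrative assumptions.

```python
import math


def online_softmax_stats(row, block_size):
    """Process a row block by block, keeping only a running max m and running sum l."""
    m, l = float("-inf"), 0.0
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's contribution.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l


row = [0.5, 2.0, -1.0, 3.0, 1.5, 0.0, 4.0]
m, l = online_softmax_stats(row, block_size=3)
softmax = [math.exp(x - m) / l for x in row]
print(round(sum(softmax), 6))  # 1.0
```

After the last block, `m` is the true row maximum and `l` equals the sum of exp(x - m) over the whole row, even though no block ever saw the full row.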

In the diagram below, each arrow is a different program ID, and its direction indicates how the tensors are loaded block by block:

    attention_inner
           N of K & V

          ------------> (1 execution)
          ------------> (2 execution)
N of Q    ------------> (3 execution)
          ------------> (4 execution)
          ------------> (5 execution)

Inside the function attention_inner, a for loop iterates over blocks of the K and V tensors (left to right along each arrow). The function is called multiple times, each time processing a different block of Q (one arrow per block, top to bottom).

Visualized in a different way, for Q of shape NxD and K, V of shape MxD, each function execution ingests a subset of Q (Q[0:5], Q[5:10], …, Q[N-5:N]), while K and V are loaded block by block (the whole embedding dimension is loaded).

Instead of computing the full product QK^T at once (Nxd · dxM), each function execution computes a 5xd · dxM product, that is, the scores for 5 query rows, with K itself loaded block by block.
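A plain-Python sketch of the block-wise forward pass described above, assuming the standard Flash Attention update rule and the usual 1/sqrt(d) scale. Each call of this simplified `attention_inner` handles one block of queries, loops over K/V blocks, and keeps a running max `m`, a running sum `l` and an unnormalized output accumulator `acc` per query row. The signature is hypothetical; on a GPU each outer call would be a separate program ID.

```python
import math


def attention_inner(Q_block, K, V, block_kv, scale):
    """One 'program': attention for a block of queries, streaming over K/V blocks."""
    d_v = len(V[0])
    m = [float("-inf")] * len(Q_block)          # running row maximum
    l = [0.0] * len(Q_block)                    # running normalization factor
    acc = [[0.0] * d_v for _ in Q_block]        # unnormalized output rows
    for start in range(0, len(K), block_kv):    # left to right over K, V blocks
        K_blk, V_blk = K[start:start + block_kv], V[start:start + block_kv]
        for i, q in enumerate(Q_block):
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in K_blk]
            m_new = max(m[i], max(s))
            alpha = math.exp(m[i] - m_new)      # rescales everything seen so far
            p = [math.exp(x - m_new) for x in s]
            l[i] = l[i] * alpha + sum(p)
            acc[i] = [acc[i][t] * alpha + sum(p[j] * V_blk[j][t] for j in range(len(p)))
                      for t in range(d_v)]
            m[i] = m_new
    return [[a / l[i] for a in acc[i]] for i in range(len(Q_block))]


def flash_attention(Q, K, V, block_q=2, block_kv=2):
    scale = 1.0 / math.sqrt(len(Q[0]))
    O = []
    for start in range(0, len(Q), block_q):     # one call per block of queries
        O.extend(attention_inner(Q[start:start + block_q], K, V, block_kv, scale))
    return O


Q = [[0.1 * (i + j) for j in range(3)] for i in range(5)]
K = [[0.2 * (i - j) for j in range(3)] for i in range(7)]
V = [[float(i * j % 5) for j in range(2)] for i in range(7)]
O = flash_attention(Q, K, V)
```

The NxM score matrix never exists as a whole: at any time only one block of scores per query row is alive, and the final division by `l` applies the softmax normalization once at the end.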

Given the nature of matrix multiplication, each element of the output matrix can be computed independently. Here, each function execution computes the output rows for its block of queries, accumulating the contributions of each K and V block until the whole row is computed. The execution of one function does not affect the execution of the others. Because of this independence, the blocks hidden by the causal mask in the output matrix can also be skipped.

As a side note, these computations could in principle also be skipped with a traditional matmul in frameworks like PyTorch, but since matmul kernels are not designed to skip operations, doing so usually performs worse than computing the whole matrix and then applying the mask.
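To illustrate how the causal mask lets whole blocks be skipped, the hypothetical helper below lists which (query block, key block) pairs a causal kernel would still need to compute, assuming square blocks and queries and keys that both start at position 0. Blocks strictly above the diagonal contain only masked entries and are never loaded.

```python
def causal_block_pairs(seq_len: int, block: int) -> list[tuple[int, int]]:
    """(q_block, kv_block) pairs with at least one unmasked entry (key pos <= query pos)."""
    n_blocks = (seq_len + block - 1) // block
    pairs = []
    for qb in range(n_blocks):              # one program per query block
        last_query = min((qb + 1) * block, seq_len) - 1
        for kb in range(n_blocks):          # inner loop over K/V blocks
            first_key = kb * block
            if first_key <= last_query:     # otherwise the whole block is masked
                pairs.append((qb, kb))
    return pairs


pairs = causal_block_pairs(seq_len=8, block=2)
print(len(pairs), "of", 4 * 4, "blocks computed")  # 10 of 16
```

Blocks on the diagonal are only partially masked, so a real kernel still applies an element-wise mask there; the savings come from the blocks that are skipped outright.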

As stated, Flash Attention fuses the softmax calculation into the QK^T matrix multiplication. This fusion creates a small dependency between the computations of a query block, through the running normalization factor that is updated after each K, V block, yet the matrix multiplications themselves remain independent.

Back-propagation and Re-computation

While vanilla attention stores the intermediate tensors S and P in HBM and loads them back for the back-propagation computation, Flash Attention re-computes the block-wise operations during the backward pass. Even though this might seem worse, recomputing them is much more efficient than storing the full NxN tensors and loading them back to calculate the gradients. Besides the output O, the only information kept for the backward pass is a small set of per-row softmax statistics (the normalization factor and the row maximum, often combined into a single log-sum-exp value). With them, the backward pass can rebuild P directly, without the running rescaling needed in the forward pass.
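A minimal plain-Python illustration of why the small per-row statistics are enough, assuming they are stored as a single log-sum-exp value L = m + log(l) per query row (a common choice, though the exact form varies between implementations). In the backward pass, any block of P can be rebuilt from q, K and L alone as P_ij = exp(S_ij - L_i). The helper names are hypothetical.

```python
import math


def row_logsumexp(scores):
    """Forward-pass statistic L = m + log(l). Flash Attention builds it block by
    block; here it is computed directly for brevity."""
    m = max(scores)
    return m + math.log(sum(math.exp(x - m) for x in scores))


def scores_block(q, K_block, scale):
    """Recompute one block of S = scale * q K^T (this is what the backward pass redoes)."""
    return [scale * sum(a * b for a, b in zip(q, k)) for k in K_block]


def recompute_p_block(q, K_block, L, scale):
    """Rebuild one block of P from q, K and the stored L: P_ij = exp(S_ij - L_i)."""
    return [math.exp(s - L) for s in scores_block(q, K_block, scale)]


q = [0.3, -1.2, 0.8]
K = [[0.1 * (i + 1), -0.2 * i, 0.05 * i * i] for i in range(6)]
scale = 1.0 / math.sqrt(len(q))

L = row_logsumexp(scores_block(q, K, scale))  # the only per-row value kept from the forward pass
# Backward: P is never loaded from memory, it is recomputed one K block at a time.
p_row = recompute_p_block(q, K[0:4], L, scale) + recompute_p_block(q, K[4:6], L, scale)
print(round(sum(p_row), 6))  # 1.0
```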