Flop Calculations for Multihead Attention Layer

October 09, 2025
10 min read

Introduction

In my previous post, I talked about batch size and how it affects the attention layer and the MLP layers in different ways. I was going to introduce sequence length next, but realized there are some bigger questions to answer first. It makes sense to dive into FLOP calculations a little more deeply before we get to that post.

FLOP calculations let us understand the performance of different parts of our model. Combined with wall-clock time, they give us FLOPs/sec, which is the throughput of our model. Given that throughput and the peak throughput of the GPU we are working with, we get the Model FLOP Utilization (MFU): the fraction of the GPU's peak throughput that we are actually using. Similarly, given the FLOPs and the number of bytes moved to and from memory, we get the Arithmetic Intensity: the ratio of the amount of computation to the amount of memory traffic. We wrote about Arithmetic Intensity in great detail in our previous post, so please do check it out. The higher the Arithmetic Intensity, the more compute bound a workload is. Our aim should always be to take a model that might be memory bound and make it compute bound.

Once we make the model compute bound, we can start asking for better hardware. :D
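To make these definitions concrete, here is a minimal sketch. The function names are made up for this post, and every number in it (FLOP count, step time, peak throughput, bytes moved) is a hypothetical placeholder, not a measurement of any real GPU or model.

```python
def achieved_flops_per_sec(total_flops: float, step_time_s: float) -> float:
    """Throughput: FLOPs actually executed per second."""
    return total_flops / step_time_s


def model_flop_utilization(total_flops: float, step_time_s: float,
                           peak_flops_per_sec: float) -> float:
    """MFU: fraction of the hardware's peak FLOP/s that the model achieved."""
    return achieved_flops_per_sec(total_flops, step_time_s) / peak_flops_per_sec


def arithmetic_intensity(total_flops: float, bytes_moved: float) -> float:
    """FLOPs performed per byte moved to/from memory."""
    return total_flops / bytes_moved


# Hypothetical numbers, for illustration only.
flops = 2.0e13          # FLOPs in one forward pass
step_time = 0.25        # seconds for that forward pass
peak = 200.0e12         # assumed peak FLOP/s of the GPU
bytes_moved = 4.0e10    # assumed bytes read + written

print(achieved_flops_per_sec(flops, step_time))
print(model_flop_utilization(flops, step_time, peak))
print(arithmetic_intensity(flops, bytes_moved))
```

With these placeholder numbers the model runs at 8e13 FLOP/s, an MFU of 0.4, and an arithmetic intensity of 500 FLOPs per byte.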

FLOPs Calculation in Matrix Multiplication

A typical matrix multiplication looks like this -

C = A @ B

Where,

A has a size of MxK, B has a size of KxN, and C has a size of MxN.

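In Python, it would look like this (with C starting as an MxN matrix of zeros) -

C = [[0.0] * N for _ in range(M)]  # M x N matrix of zeros
for i in range(M):
    for j in range(N):
        for k in range(K):
            C[i][j] += A[i][k] * B[k][j]  # one multiply + one add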

Looking at the loop, it is clear that we do K multiplications and K additions to get the value of C[i][j]. Strictly, summing K products needs only K-1 additions, but we count K for simplicity (the loop above, which starts from zero, really does perform K).

There are MxN such cells, and therefore we do 2*M*N*K operations.

Each addition and multiplication is a single FLOP. So, the total number of FLOPs is 2*K*M*N.

For a Matrix Multiplication torch.matmul(...),

C = A @ B

where A has a size of MxK, B has a size of KxN, and C has a size of MxN, the total number of FLOPs is 2*K*M*N.
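The counting argument can be checked mechanically. This sketch instruments the same triple loop and compares its tally against 2*M*N*K; the helper names matmul_flops and counted_matmul are made up for illustration.

```python
def matmul_flops(M: int, K: int, N: int) -> int:
    """FLOPs for (M x K) @ (K x N): one multiply and one add per (i, j, k)."""
    return 2 * M * N * K


def counted_matmul(A, B):
    """Naive triple-loop matmul on nested lists that also counts its FLOPs."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    flops = 0
    for i in range(M):
        for j in range(N):
            for k in range(K):
                C[i][j] += A[i][k] * B[k][j]
                flops += 2  # one multiplication + one addition
    return C, flops


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]    # 2 x 3
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 x 2
C, flops = counted_matmul(A, B)
print(C)                             # [[4.0, 5.0], [10.0, 11.0]]
print(flops, matmul_flops(2, 3, 2))  # 24 24
```

Because C starts at zero, the loop performs exactly K additions per cell, so the tally matches the formula exactly.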

FLOPs Calculation in Linear Layer

A linear layer is a matrix multiplication followed by a bias addition.

y = x @ W + b

In PyTorch, it is written as -

import torch.nn as nn

linear_layer = nn.Linear(model_dimension, model_dimension * expansion_factor)
y = linear_layer(x)  # x has shape (batch_size, M, model_dimension)

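If x has a shape of (batch_size, M, N) (so model_dimension = N) and expansion_factor is 3, then y has a shape of (batch_size, M, N * 3). This means that W is a matrix of shape (N, N * 3). (PyTorch stores it transposed, as linear_layer.weight with shape (N * 3, N), and computes x @ W.T + b; the FLOP count is the same.)

For each batch element this is an (MxN) @ (Nx3N) matrix multiplication, so the number of FLOPs will be 2 * M * N * (N * 3) * batch_size. This ignores the bias addition, which adds only M * (N * 3) * batch_size more.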

For a Linear Layer nn.Linear(...),

y = x @ W + b

where W has a shape of Nx(N * expansion), x has a shape of (batch_size, M, N), and b has a shape of (N * expansion), the total number of FLOPs is 2 * M * N * (N * expansion) * batch_size (ignoring the bias addition).
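The formula above leaves out the bias addition. This sketch (linear_flops is a hypothetical helper, and the shapes in the example are arbitrary) computes both versions so you can see how small the bias term is.

```python
def linear_flops(batch_size: int, M: int, in_features: int, out_features: int,
                 include_bias: bool = False) -> int:
    """FLOPs for y = x @ W + b with x of shape (batch_size, M, in_features)."""
    flops = 2 * batch_size * M * in_features * out_features  # the matmul
    if include_bias:
        flops += batch_size * M * out_features  # one add per output element
    return flops


# Worked example with expansion_factor = 3 (assumed shapes for illustration).
batch_size, M, N, expansion = 8, 128, 512, 3
print(linear_flops(batch_size, M, N, N * expansion))                     # 1610612736
print(linear_flops(batch_size, M, N, N * expansion, include_bias=True))  # 1612185600
```

The bias costs 1 / (2 * in_features) of the matmul, which is under 0.1% here, so dropping it is a safe simplification.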

FLOPs Calculation in Softmax

Let us say you have a matrix of shape (batch_size, M, N). You want to apply softmax to the last dimension.

softmax_matrix = torch.softmax(x, dim=-1)

What does a softmax operation look like?

exp_x = torch.exp(x)
sum_exp_x = torch.sum(exp_x, dim=-1, keepdim=True) # keepdim is needed for broadcasting in the next step
softmax_x = exp_x / sum_exp_x

In practice, implementations first take the max along the last dimension, subtract it from every element, and only then apply softmax. This avoids overflow in exp and does not change the result:

let y <- x - max(x)
exp(x) = exp(y + max(x)) = exp(y) * exp(max(x))
torch.sum(exp(x)) = torch.sum(exp(y)) * exp(max(x))
softmax(x) = (exp(y) * exp(max(x))) / (torch.sum(exp(y)) * exp(max(x))) = exp(y) / torch.sum(exp(y))
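A minimal pure-Python sketch of the derivation, one row at a time and without PyTorch (the function names are made up for illustration). With inputs around 1000, math.exp overflows in the naive version, while the shifted version works and returns exactly what it would for [0, 1, 2].

```python
import math


def naive_softmax(row):
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def stable_softmax(row):
    m = max(row)                           # row-wise max
    exps = [math.exp(v - m) for v in row]  # every exponent is <= 0, so no overflow
    total = sum(exps)
    return [e / total for e in exps]


row = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(row)
except OverflowError:
    print("naive softmax overflowed")
print(stable_softmax(row))
```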

Going by the simplified version (without the max), the number of FLOPs will be -

  1. exp_x = torch.exp(x) - This is a single FLOP per element. So, M * N * batch_size FLOPs.
  2. sum_exp_x = torch.sum(exp_x, dim=-1, keepdim=True) - This sums each of the M * batch_size rows, taking N - 1 additions per row. So, roughly M * N * batch_size FLOPs.
  3. softmax_x = exp_x / sum_exp_x - This is a division. It is a single FLOP per element. So, M * N * batch_size FLOPs.

For every element, we have an exponentiation, a division, and (roughly) one addition for the sum. So, the total number of FLOPs is about 3 * M * N * batch_size. If we use the max version, it is about 5 * M * N * batch_size, to account for one comparison per element to find the max and one subtraction per element.

For a Softmax Operation torch.softmax(...),

softmax_x = exp(x) / torch.sum(exp(x), dim=-1, keepdim=True)

where x has a shape of (batch_size, M, N), the total number of FLOPs is (3 * M * N) * batch_size.
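To see how close the 3-FLOPs-per-element approximation is, this sketch counts every operation explicitly: one exp and one division per element, and N - 1 additions per row. It works on nested lists instead of tensors, and counted_softmax and softmax_flops are hypothetical helper names.

```python
import math


def counted_softmax(x):
    """Softmax over the last dim of a nested list x[batch][row][col], counting FLOPs."""
    flops = 0
    out = []
    for matrix in x:
        out_matrix = []
        for row in matrix:
            exps = [math.exp(v) for v in row]
            flops += len(row)        # one exp per element
            total = exps[0]
            for e in exps[1:]:
                total += e
            flops += len(row) - 1    # N - 1 additions per row
            out_matrix.append([e / total for e in exps])
            flops += len(row)        # one division per element
        out.append(out_matrix)
    return out, flops


def softmax_flops(batch_size: int, M: int, N: int) -> int:
    """The approximation used in this post: 3 FLOPs per element."""
    return 3 * M * N * batch_size


x = [[[0.1, 0.2, 0.3, 0.4], [1.0, 2.0, 3.0, 4.0]]]  # shape (1, 2, 4)
probs, exact = counted_softmax(x)
print(exact, softmax_flops(1, 2, 4))  # 22 24
```

The gap is exactly one addition per row (batch_size * M FLOPs), which becomes negligible as N grows.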

FLOPs Calculation in Attention Layer

We will be looking at the Multi-Head Attention layer.

A bare-bones version of a multi-head attention layer looks like this (it uses the einops and jaxtyping packages for the rearranges and shape annotations) -

import math

import einops
import torch
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor


class MultiHeadAttention(nn.Module):
    def __init__(self, model_dimension: int, num_heads: int):
        super().__init__()
        assert model_dimension % num_heads == 0, "model_dimension must be divisible by num_heads"
        self.d_model = model_dimension
        self.num_heads = num_heads
        self.head_dimension = model_dimension // num_heads
        # One projection produces Q, K and V together: md -> 3 * md
        self.qkv_proj = nn.Linear(model_dimension, model_dimension * 3)
        self.out_proj = nn.Linear(model_dimension, model_dimension)

    def forward(self, x: Float[Tensor, "b s md"]) -> Float[Tensor, "b s md"]:
        qkv_proj = self.qkv_proj(x)
        q, k, v = qkv_proj.chunk(3, dim=-1)  # each is (b, s, md)
        q = einops.rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = einops.rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
        v = einops.rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)

        qkt = torch.matmul(q, k.transpose(-2, -1))  # shape is (batch_size, num_heads, sequence_length, sequence_length)
        qkt = qkt / math.sqrt(self.head_dimension)

        softmax_qkt = torch.softmax(qkt, dim=-1)  # normalize over key positions (last dim)

        attn_output = torch.matmul(softmax_qkt, v)  # shape is (batch_size, num_heads, sequence_length, head_dimension)
        attn_output = einops.rearrange(attn_output, "b h s d -> b s (h d)")

        return self.out_proj(attn_output)

Let's set the stage before we do the calculations -

  1. Inputs are of shape (batch_size, sequence_length, model_dimension)
  2. In multihead attention, we split the model dimension into num_heads. head_dimension * num_heads = model_dimension
  3. We have a total of num_heads heads.
  4. We pass the inputs through a linear layer to get the query (Q), key (K), and value (V).
  5. We then rearrange the query, key, and value to get the shape (batch_size, num_heads, sequence_length, head_dimension). This is the shape of the query, key, and value for each head.
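  6. einops.rearrange only reshapes and permutes dimensions. It returns a view when it can and a copy when it must (for example, when merging the heads back together after a transpose), but either way it performs no arithmetic. Therefore, it doesn't contribute to the FLOP calculation, although any copies still cost memory bandwidth.
  7. The same goes for the torch.chunk operation, which returns views of the input.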

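Let's do the calculations, writing B for batch size, S for sequence length, mD for model dimension, H for the number of heads, and hD for the head dimension -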

  1. qkv_proj = self.qkv_proj(x)
    • This is a linear layer with expansion factor of 3. From our formula, the number of FLOPs is (2 * S * mD * mD * 3) * B.
  2. q, k, v = qkv_proj.chunk(3, dim=-1)
    • This is a chunk operation. It does not move data, it doesn't create a new tensor. Therefore, 0 FLOPs.
  3. einops.rearrange(...) for q, k, and v
    • These are view operations. They do not move data, they don't create a new tensor. Therefore, 0 FLOPs.
  4. qkt = torch.matmul(q, k.transpose(-2, -1))
    • This is a matrix multiplication. From our formula, the number of FLOPs is (2 * hD * S * S) * H * B.
  5. qkt = qkt / math.sqrt(self.head_dimension)
    • This is a division. It is a single FLOP per element. So, ((S * S) * H) * B FLOPs.
  6. softmax_qkt = torch.softmax(qkt, dim=-1)
    • This is a softmax operation. It is ((3 * S * S) * H) * B FLOPs.
  7. attn_output = torch.matmul(softmax_qkt, v)
    • This is a matrix multiplication. From our formula, the number of FLOPs is (2 * S * S * hD) * H * B.
  8. attn_output = einops.rearrange(attn_output, "b h s d -> b s (h d)")
    • This reshapes and permutes. Because the heads have to be transposed before they are merged, this one usually produces a copy, but it performs no arithmetic. Therefore, 0 FLOPs.
  9. self.out_proj(attn_output)
    • This is a linear layer with expansion factor of 1. From our formula, the number of FLOPs is 2 * S * mD * mD * B.
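The per-step counts above can be tabulated with a short sketch. mha_flops_by_step is a hypothetical helper, the shapes in the example are arbitrary, and bias additions are ignored just as in the list.

```python
def mha_flops_by_step(B: int, S: int, H: int, mD: int) -> dict:
    """Per-step forward FLOPs for the MultiHeadAttention module (bias adds ignored)."""
    assert mD % H == 0, "model_dimension must be divisible by num_heads"
    hD = mD // H
    return {
        "qkv_proj": 2 * S * mD * (3 * mD) * B,
        "q @ k^T": 2 * S * S * hD * H * B,
        "scale": S * S * H * B,
        "softmax": 3 * S * S * H * B,
        "softmax @ v": 2 * S * S * hD * H * B,
        "out_proj": 2 * S * mD * mD * B,
        # chunk and rearrange do no arithmetic: 0 FLOPs
    }


steps = mha_flops_by_step(B=2, S=256, H=8, mD=512)
for name, flops in steps.items():
    print(f"{name:>12}: {flops:,}")
print(f"{'total':>12}: {sum(steps.values()):,}")
```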

In total, we have

B * (
    (2 * S * mD * mD * 3) +
    ((2 * S * S * hD) * H) +
    ((S * S) * H) +
    ((3 * S * S) * H) +
    ((2 * S * S * hD) * H) +
    (2 * S * mD * mD)
)

We know that hD * H = mD. So, we can substitute hD * H with mD in the above equation.

B * (
    2 * S * mD * mD * 3 +
    4 * S * S * H +
    4 * S * S * mD +
    2 * S * mD * mD
)

or

B * (
    8 * S * mD^2 # linear layer FLOPs
    +
    4 * H * S^2 + 4 * mD * S^2 # attention layer FLOPs
)

For a Multi Head Attention Layer MultiHeadAttention(...), FLOPs is

B * (8 * S * mD^2 + 4 * H * S^2 + 4 * mD * S^2)

Where:

  • B = Batch size
  • S = Sequence length
  • H = Number of attention heads
  • mD = Model dimension (embedding size)
  • hD = Head dimension, the size of each head (hD = mD / H)

So, for a given input, you can plug in your actual values for B, S, H, and mD to calculate the total FLOPs for the multi-head attention layer.
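As an example of plugging in values, here is the closed form as a small function. mha_flops is a hypothetical helper, and the shapes (12 heads, mD = 768, S = 1024) are assumed, chosen to resemble GPT-2 small.

```python
def mha_flops(B: int, S: int, H: int, mD: int) -> int:
    """Closed-form forward FLOPs for one multi-head attention layer."""
    linear = 8 * S * mD**2                      # qkv_proj + out_proj
    attention = 4 * H * S**2 + 4 * mD * S**2    # scores, scale, softmax, weighted sum
    return B * (linear + attention)


# Assumed shapes, similar to GPT-2 small: 12 heads, mD = 768, S = 1024.
print(f"{mha_flops(B=1, S=1024, H=12, mD=768):,}")  # 8,103,395,328
```

That is roughly 8.1 GFLOPs for one layer's forward pass on a single 1024-token sequence.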

Some quick thoughts that we can draw from the above calculations -

  1. If we increase the sequence length, the FLOPs for the linear layers increase linearly.
  2. If we increase the sequence length, the FLOPs for the attention terms increase quadratically.
  3. If we increase the model dimension, the FLOPs for the linear layers increase quadratically.
  4. If we increase the model dimension (keeping the number of heads fixed), the FLOPs for the attention terms increase linearly.
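These trends imply a crossover point. Setting the attention terms equal to the linear terms, 4 * S^2 * (H + mD) = 8 * S * mD^2, gives S = 2 * mD^2 / (mD + H); beyond that length the S^2 terms dominate. This sketch (hypothetical helper names, assumed shapes H = 12 and mD = 768) computes the crossover and the share of FLOPs spent in the attention terms. Note that B cancels out.

```python
def attention_share(S: float, H: int, mD: int) -> float:
    """Fraction of the layer's FLOPs in the S^2 (attention) terms; B cancels out."""
    linear = 8 * S * mD**2
    attention = 4 * H * S**2 + 4 * mD * S**2
    return attention / (linear + attention)


def crossover_seq_len(H: int, mD: int) -> float:
    """Sequence length where attention FLOPs equal linear FLOPs: 2 * mD^2 / (mD + H)."""
    return 2 * mD**2 / (mD + H)


H, mD = 12, 768  # assumed shapes
print(crossover_seq_len(H, mD))
for S in (128, 512, 2048, 8192):
    print(S, round(attention_share(S, H, mD), 3))
```

For these shapes the crossover is at roughly S = 1512. The attention terms are about 8% of the FLOPs at S = 128 and over 80% at S = 8192.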

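Imagine you are on a GPU. While computing the attention layer, standard attention materializes the score tensor qkt, its scaled version, and softmax_qkt, each of shape (B, H, S, S), so memory grows as O(S^2). These tensors are far too large for the GPU's fast on-chip shared memory, so they are written to and read back from the much slower global memory (HBM), and as sequence lengths grow this traffic becomes a HUGE bottleneck. Flash Attention fixes this by computing attention in tiles that fit in on-chip memory, using an online softmax, so it never materializes the full S x S tensor and needs only O(S) extra memory. Note that it does not change the FLOP count: the S^2 terms above are still there.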
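To get a feel for the O(S^2) memory, this sketch computes the size of a single (B, H, S, S) score tensor. score_tensor_bytes is a hypothetical helper; 2 bytes per element (fp16/bf16), B = 1 and H = 12 are assumptions.

```python
def score_tensor_bytes(B: int, H: int, S: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to hold one (B, H, S, S) attention-score tensor (fp16/bf16 by default)."""
    return B * H * S * S * bytes_per_elem


for S in (1024, 4096, 16384):
    gib = score_tensor_bytes(B=1, H=12, S=S) / 2**30
    print(f"S={S:>6}: {gib:.3f} GiB per layer")
```

Going from S = 1024 to S = 16384 takes one such tensor from 24 MiB to 6 GiB, per layer, before counting the other copies the module holds at the same time.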

But even with Flash Attention, some models still get bottlenecked by long sequences, because the attention FLOPs still grow as S^2. Vision transformer models are one example: they operate on patches, and a large image produces a lot of patches (a long sequence). One way around this is to move to better hardware: from A100 to H100, from H100 to H200, from H200 to B200, etc.

Conclusion

I will close it here. Writing this blog is already giving me ideas on what I will write next -

  1. I will hopefully post a follow-up where we expand on the "quick thoughts" with concrete numbers, plots, and examples, motivated by real-world models.
  2. I would also like to write a follow-up post where I implement a FLOP-calculation method in every nn.Module, so that we can call these methods recursively and add them up to get the total FLOP count in a super neat manner. Probably in the GPT-2 model codebase.
  3. Using the above, we can devote a blog post to calculating the total FLOPs of the GPT-2 model with super high fidelity, rather than depending on the 6 * num_parameters (FLOPs per training token) formula that is used, abused and misused.