```bash
uv venv
source .venv/bin/activate
uv run src/find_optimal_group.py --heads={H}
```

Replace `{H}` with an integer (default: H = 8).
Transformers are widely used across AI domains, most notably as the architecture behind Large Language Models (LLMs).
However, their computational demands are massive, largely because they spend a significant amount of time and energy accessing the Key-Value (KV) cache of previously processed tokens.
This degrades inference throughput, especially when generating long, detailed responses.
Grouped-Query Attention (GQA) offers a smart solution to drastically cut down on memory usage and speed up inference.
In this article, I’ll examine a cost-effective method for optimizing the GQA configuration and compare its performance with counterparts such as Multi-Head Attention (MHA) and Multi-Query Attention (MQA).
The attention mechanism is a key component of the Transformer architecture, weighing the importance of each word (token) in the input sequence.
The diagram below shows how the attention mechanism works, leveraging the Q-K-V mechanism:
Figure A. The Q-K-V mechanism in the encoder (left) and Transformer architecture (right) (Created by Kuriko IWAI)
In the attention layer (red boxes in Figure A), the Query (Q, orange box), Key (K, grey box), and Value (V, blue box) vectors are generated by a linear transformation of the input embedding X, such that:
Q = XW_{Q}, K = XW_{K}, and V = XW_{V}
where W_{Q}, W_{K}, and W_{V} are learnable weight matrices.
These vectors hold specific information:

- Q: a query from the current token, asking what information in the input embedding it should look for to understand the full context,
- K: the labeled information of each token that Q will look for, and
- V: the actual content (information payload) of each token,

and they are used to compute the attention weights Z:

Z = softmax(QK^{T} / √d_k)V

where Q, K, and V are the Query, Key, and Value vectors and d_k is the dimension of the Key vector (the scaling factor).
This final result is passed on to the feed forward layer to generate a context-rich representation of the input.
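To make this concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention; the toy dimensions and random weights are illustrative only and are not taken from the article’s models.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def single_head_attention(X, W_q, W_k, W_v):
    """Scaled dot-product attention: Z = softmax(QK^T / sqrt(d_k)) V."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v        # linear projections of the input embedding X
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)            # similarity of every query with every key
    weights = softmax(scores, axis=-1)         # attention weights, each row sums to 1
    return weights @ V                         # weighted sum of the Value vectors

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 5, 16, 16              # toy sizes for illustration
X = rng.normal(size=(seq_len, d_model))        # input embedding X
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
Z = single_head_attention(X, W_q, W_k, W_v)
print(Z.shape)                                 # (5, 16)
```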
Grouped-Query Attention (GQA) is a variant of the attention mechanism designed to reduce memory bandwidth requirements and latency during the decoding phase.
The diagram below illustrates how GQA works:
Figure B. How GQA works (Created by Kuriko IWAI)
In Figure B, after receiving the input embedding X, the network creates eight heads and applies a linear transformation to generate eight corresponding Query vectors (Q(0) to Q(7)).
Then, it groups the Query vectors into four groups (Groups j=0 to j=3).
Each group computes the attention weights (Z(0) to Z(3)), while sharing the same Key and Value projections in the group to reduce the KV cache size.
Lastly, the network concatenates all attention weights from the groups and performs a linear transformation to generate the final output O.
Mathematically, the process is generalized for the i-th query head (i ∈ {0, 1, …, H-1}) in the j-th group (j ∈ {0, 1, …, G-1}):

Z_i^{(G)} = softmax(Q_i K_{G_j}^{T} / √d_k) V_{G_j}

where:

- Z_i^{(G)}: the attention output of the i-th query head in the j-th group,
- Q_i: the Query vector corresponding to the i-th query head,
- K_{G_j}: the Key vector of the j-th group,
- d_k: the dimension of the Key vector, and
- V_{G_j}: the Value vector of the j-th group.
The GQA layer concatenates the outputs Z_i^{(G)} of all individual query heads and performs a linear transformation:

O_{GQA} = Concat(Z_0^{(G)}, …, Z_{H-1}^{(G)}) W^{O}

where:

- O_{GQA}: the final output of the GQA layer (which is passed on to the feed forward layer), and
- W^{O}: the weight matrix of the output layer.
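The equations above map directly to code. Below is a minimal NumPy sketch of a GQA layer with H = 8 query heads and G = 4 groups as in Figure B, where head i uses the Key/Value projection of group j = i // (H // G); setting G = H recovers MHA and G = 1 recovers MQA. The shapes are illustrative, not the article’s models.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gqa_attention(X, Wq_heads, Wk_groups, Wv_groups, W_o):
    """Grouped-Query Attention: H query heads share G Key/Value projections."""
    H, G = len(Wq_heads), len(Wk_groups)
    heads_per_group = H // G
    outputs = []
    for i in range(H):
        j = i // heads_per_group                       # group index of head i
        Q_i = X @ Wq_heads[i]                          # per-head Query
        K_g, V_g = X @ Wk_groups[j], X @ Wv_groups[j]  # shared Key/Value of group j
        d_k = K_g.shape[-1]
        Z_i = softmax(Q_i @ K_g.T / np.sqrt(d_k)) @ V_g
        outputs.append(Z_i)
    return np.concatenate(outputs, axis=-1) @ W_o      # concat heads, then output projection

rng = np.random.default_rng(0)
H, G, seq_len, d_model, d_head = 8, 4, 5, 64, 8
X = rng.normal(size=(seq_len, d_model))
Wq = [rng.normal(size=(d_model, d_head)) for _ in range(H)]
Wk = [rng.normal(size=(d_model, d_head)) for _ in range(G)]
Wv = [rng.normal(size=(d_model, d_head)) for _ in range(G)]
W_o = rng.normal(size=(H * d_head, d_model))
print(gqa_attention(X, Wq, Wk, Wv, W_o).shape)         # (5, 64)
```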
GQA is a variant of the standard attention mechanism: Multi-Head Attention (MHA).
The diagram below compares GQA with MHA and another variant, Multi-Query Attention (MQA):
Figure C. Comparison of attention mechanisms - Left: MHA, Center: GQA, Right: MQA (Created by Kuriko IWAI)
MHA (left in Figure C) uses full sets of Key and Value vectors for each query head.
For example, in Figure C, MHA has eight query heads, so it has eight Key vectors and eight Value vectors.
As these query heads represent different subspaces (aspects) of the input sequence, MHA is the most expressive in enriching contextual understanding.
However, its primary challenge is its massive memory and compute cost:

- High computational cost from many matrix multiplications,
- Large memory footprint required for backpropagation, and
- High memory consumption during inference (decoding) due to its massive Key-Value (KV) cache size.
On the other hand, MQA (right in Figure C) is the opposite extreme: it saves memory and computation by sharing a single Key vector and a single Value vector across all eight query heads.
This significantly reduces the KV cache size, and with it the memory bandwidth requirements, making inference extremely fast.
But MQA can compromise efficiency due to contention, as all query heads compete to access the shared Key and Value data.
GQA balances the trade-off between performance and speed by grouping multiple query heads.
In Figure B, GQA forms four groups of two query heads each, halving the KV cache size compared to MHA.
This helps GQA achieve faster inference (lower latency when generating tokens) while retaining near-MHA attention quality.
The table below summarizes the key characteristics:

| | MHA | GQA | MQA |
|---|---|---|---|
| Number of Key/Value heads | H | G (1 < G < H) | 1 |
| Memory footprint | High | Mid | Low |
| Inference speed | Slow | Mid | Fast |
| Focus | Quality | Quality & Speed | Speed |

Table A. Attention mechanism comparison (Created by Kuriko IWAI)
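The memory column in Table A comes down to the KV cache formula. The sketch below estimates the decoding-time KV cache for each mechanism; the model shape (32 layers, head dimension 128, fp16 cache) and the batch/sequence settings are assumed for illustration and are not the article’s translation models.

```python
def kv_cache_bytes(batch, seq_len, num_layers, num_kv_heads, d_head, bytes_per_elem=2):
    """KV cache size: one Key and one Value tensor per layer, per KV head."""
    return 2 * batch * seq_len * num_layers * num_kv_heads * d_head * bytes_per_elem

# Illustrative setting: 32 layers, 32 query heads, head dim 128, fp16 cache,
# batch of 16 requests, 2,048-token context.
for name, kv_heads in [("MHA", 32), ("GQA (G=8)", 8), ("MQA", 1)]:
    gib = kv_cache_bytes(16, 2048, 32, kv_heads, 128) / 2**30
    print(f"{name:10s} KV cache: {gib:.1f} GiB")   # 16.0, 4.0, and 0.5 GiB respectively
```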
GQA offers a practical middle ground for Transformers, but a question remains: what is the optimal grouping strategy?
I’ll explore this in the next section.
Whether GQA achieves better performance and faster inference speed depends heavily on the grouping strategy.
In this section, I’ll optimize the query head grouping of the GQA Transformer and then compare its performance and inference speed with three other Transformers:
- Standard Transformer (as a baseline),
- MHA Transformer, and
- MQA Transformer.
All Transformers use an encoder-decoder architecture and are trained on the same English-French translation dataset.
Merging query heads can compromise GQA’s performance compared to its MHA counterpart.
To mitigate this, I’ll use Procrustes analysis to enhance the similarity among the heads in each group, following the research paper Align Attention Heads Before Merging Them (arXiv:2412.20677v2).
This method computes cosine similarity scores between the Value vectors of every pair of heads and finds the grouping with the highest total similarity score.
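As a rough illustration of the similarity-driven grouping step (the paper’s full method also aligns heads with an orthogonal Procrustes transform before merging, which this sketch omits), one can score every pair of heads by the cosine similarity of their Value projections and group the most similar ones greedily; the per-head Value matrices below are random placeholders.

```python
import numpy as np

def head_similarity(V_heads):
    """Pairwise cosine similarity between flattened per-head Value projections."""
    flat = np.stack([v.ravel() / np.linalg.norm(v) for v in V_heads])
    return flat @ flat.T

def greedy_grouping(V_heads, group_size=2):
    """Greedily pair the most similar heads (a simplification of the paper's grouping)."""
    sim = head_similarity(V_heads)
    unassigned, groups = set(range(len(V_heads))), []
    while unassigned:
        i = min(unassigned)
        # pick the most similar remaining partner(s) for head i
        partners = sorted((j for j in unassigned if j != i),
                          key=lambda j: sim[i, j], reverse=True)[:group_size - 1]
        group = [i, *partners]
        groups.append(group)
        unassigned -= set(group)
    return groups

rng = np.random.default_rng(0)
V_heads = [rng.normal(size=(64, 8)) for _ in range(8)]   # hypothetical Value weights, 8 heads
print(greedy_grouping(V_heads))                           # four groups of two head indices
```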
The optimal grouping strategy for inference speed depends on whether the system is:

- Memory-bound: waiting for data from the KV cache, common on budget GPUs or with long sequences, or
- Compute-bound: waiting for calculations to finish, common on high-end GPUs or with very large batches.

A rough way to check which regime applies is sketched below.
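This roofline-style sketch compares the time to stream one layer’s KV cache against the time to do that layer’s attention math for a single decode step; the hardware numbers (300 TFLOP/s, 1 TB/s of bandwidth) are hypothetical placeholders, not a specific GPU.

```python
def decode_regime(batch, seq_len, num_q_heads, num_kv_heads, d_head,
                  flops_per_s=3.0e14, bytes_per_s=1.0e12, bytes_per_elem=2):
    """Rough roofline check for one attention layer when decoding a single new token."""
    # Bytes moved: the layer's whole KV cache must be read once per decode step.
    kv_bytes = 2 * batch * seq_len * num_kv_heads * d_head * bytes_per_elem
    # FLOPs: each query head does QK^T and weights*V over seq_len positions (~2 ops each).
    flops = 4 * batch * num_q_heads * seq_len * d_head
    t_mem, t_compute = kv_bytes / bytes_per_s, flops / flops_per_s
    return "memory-bound" if t_mem > t_compute else "compute-bound"

# Long sequence, small batch: typically memory-bound, which is where GQA/MQA help the most.
print(decode_regime(batch=1, seq_len=2048, num_q_heads=32, num_kv_heads=8, d_head=128))
```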
In this analysis, I’ll systematically sweep across:

- Sequence length: N = 512 to 2,048 tokens, and
- Batch size: B = 1 to 16 requests,

while the memory bandwidth is held constant (no hardware diversity), to see which configurations maximize the benefits of the GQA Transformer. A sketch of the sweep loop is shown below.
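The sweep itself might look roughly like the loop below; `benchmark_gqa` is a hypothetical stand-in for the timing routine in src/find_optimal_group.py, not its actual API.

```python
from itertools import product

def sweep(benchmark_gqa, num_heads=8):
    """Sweep sequence length, batch size, and group count; return measured latencies."""
    group_options = [g for g in range(1, num_heads + 1) if num_heads % g == 0]  # 1 = MQA ... H = MHA
    results = {}
    for seq_len, batch in product([512, 1024, 2048], [1, 4, 16]):
        for g in group_options:
            results[(seq_len, batch, g)] = benchmark_gqa(seq_len, batch, g)
    return results

if __name__ == "__main__":
    # Placeholder latency model so the sketch runs end to end; real runs would time the model.
    results = sweep(lambda s, b, g: s * b * g * 1e-6)
    print(min(results, key=results.get))   # configuration with the lowest placeholder latency
```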
Align Attention Heads Before Merging Them: An Effective Way for Converting MHA to GQA (arXiv:2412.20677)

