LLM Optimizations
Flash Attention
Standard self-attention has time complexity O(n^2 d) and space complexity O(n^2), where n is the sequence length and d is the hidden dimension. Both grow quadratically with sequence length, which makes long sequences slow and memory-hungry.
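To make the O(n^2) memory term concrete, here is a small helper (the name `attention_scores_bytes` is hypothetical) that counts the bytes needed to materialize the full score matrix. It assumes fp16 scores (2 bytes per element) and, by default, one head of one sequence.

```python
def attention_scores_bytes(seq_len: int, n_heads: int = 1, batch: int = 1,
                           bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize the full seq_len x seq_len score matrix QK^T.

    Assumes fp16 scores (2 bytes per element) unless told otherwise.
    """
    return batch * n_heads * seq_len * seq_len * bytes_per_elem


for n in (1024, 8192, 32768):
    print(n, attention_scores_bytes(n) / 2**20, "MiB per head")
```

At n = 1024 this is 2 MiB per head, at n = 8192 it is 128 MiB, and at n = 32768 it is 2 GiB, before multiplying by the number of heads and the batch size.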
Tiling technique
1. Instead of materializing the entire QK^T matrix in memory, FlashAttention (FA) computes it tile by tile:
Q is split into blocks of 128 queries
K and V are split into blocks of 128 keys/values
2. Load a tile of keys and values into shared memory, then:
compute partial attention scores on the fly (online softmax)
multiply by the corresponding value tile
accumulate the partial results
3. Move on to the next key/value tile, reuse the shared memory, and repeat.
Online softmax
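If we applied softmax per tile, each tile would be normalized only over its own scores, giving the wrong normalization for the whole sequence. FlashAttention instead uses online softmax: it computes the global softmax across multiple tiles by keeping a running maximum and a running sum for each query row, without storing the intermediate scores.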
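A minimal NumPy sketch of the online softmax idea for a single row of scores (function name and tile size are illustrative). Scores are consumed one tile at a time, and only a running maximum `m` and a running normalizer `l` are carried between tiles. Here the row is normalized in a final pass; inside attention that pass is avoided by applying the same rescaling to the output accumulator.

```python
import numpy as np


def online_softmax(scores: np.ndarray, tile: int) -> np.ndarray:
    """Softmax over a 1-D score vector whose normalizer is built tile by tile."""
    m = -np.inf  # running max over the tiles seen so far
    l = 0.0      # running sum of exp(score - m) over the tiles seen so far
    for start in range(0, scores.size, tile):
        block = scores[start:start + tile]
        m_new = max(m, block.max())
        # rescale the old sum to the new max, then add this tile's contribution
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    # m and l are now the global max and normalizer for the whole row
    return np.exp(scores - m) / l
```

Subtracting the running max keeps every `exp` in range; the factor `exp(m - m_new)` corrects earlier tiles whenever a larger score shows up later.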

At the end you have the globally normalized attention result, while the extra memory is O(n), not O(n^2).
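Each tile is sized to fit in on-chip SRAM (e.g. ~96 KB of shared memory), so computation happens in registers and SRAM, avoiding slow reads/writes to GPU global memory (HBM/VRAM). In standard attention, most runtime is spent on memory I/O rather than math, so cutting memory traffic is where the speedup comes from.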
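A NumPy sketch of the tiling and online-softmax steps for one attention head without masking. It only models the arithmetic: the two loops stand in for what a CUDA kernel does in SRAM, and `block_q`/`block_kv` are illustrative parameters. The output accumulator is rescaled whenever the running max changes, so no tile's scores are kept after it has been processed.

```python
import numpy as np


def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])           # full n x n score matrix
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V


def tiled_attention(Q, K, V, block_q=128, block_kv=128):
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty((n, V.shape[1]))
    for qs in range(0, n, block_q):                  # outer loop: query tiles
        q = Q[qs:qs + block_q] * scale
        m = np.full(q.shape[0], -np.inf)             # running row max
        l = np.zeros(q.shape[0])                     # running softmax denominator
        acc = np.zeros((q.shape[0], V.shape[1]))     # unnormalized output rows
        for ks in range(0, K.shape[0], block_kv):    # inner loop: key/value tiles
            s = q @ K[ks:ks + block_kv].T            # scores for this tile only
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            corr = np.exp(m - m_new)                 # rescale earlier tiles to the new max
            l = l * corr + p.sum(axis=1)
            acc = acc * corr[:, None] + p @ V[ks:ks + block_kv]
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]        # normalize once at the end
    return O
```

The only per-row state carried across key/value tiles is `m`, `l`, and `acc`, which is why the extra memory is O(n) instead of O(n^2).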
Results
Up to 3x speedup and up to 10x less memory than standard attention for long sequences.
GPU Memory Architecture
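Streaming multiprocessor (SM): think of it as a mini processor cluster with many ALUs and registers, plus a block of on-chip shared memory (roughly 64 KB to a couple of hundred KB per SM, depending on the GPU generation).
A GPU launches thousands of threads grouped into thread blocks, and each block runs on a single SM. Threads within a block cooperate through shared memory; different blocks don't talk directly and communicate only through global memory.
Access time: the time it takes to read or write data. On-chip registers and shared memory are much faster to access than off-chip global memory (HBM/VRAM).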
Multi Query Attention
In standard multi-head attention, each attention head has its own keys and values (KV), all of which need to be stored in GPU memory.
The outputs of all heads are concatenated and projected back to the model dimension. During inference, you must store K and V for every head at every time step.
The idea of MQA: all heads share the same keys and values, but each head still has its own queries. During decoding, each token then needs only one set of K/V in memory instead of one per head. Because all heads share the same K/V pairs, there is less diversity in what each head can attend to; in large-scale LLMs, the quality drop is minimal.
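Back-of-the-envelope KV cache sizing shows why sharing K/V matters. The model shape below (32 layers, 32 heads of dimension 128, fp16 cache, 4096-token context, batch 1) is a hypothetical example, not a specific model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # factor 2: one K vector and one V vector per layer, per KV head, per token
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem


# hypothetical model shape (see lead-in); fp16 = 2 bytes per element
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)
mqa = kv_cache_bytes(n_layers=32, n_kv_heads=1, head_dim=128, seq_len=4096)
print(mha / 2**30, "GiB with MHA (one K/V head per query head)")  # 2.0
print(mqa / 2**20, "MiB with MQA (one shared K/V head)")          # 64.0
```

The cache shrinks by exactly the number of heads (32x here), which directly allows longer contexts or larger batches in the same memory.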
Grouped Query Attention
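Instead of sharing one K/V set across all heads, we group the query heads: with n heads per group, the heads in each group share one K/V head. The group size is a tunable trade-off: a single group recovers MQA, and one head per group recovers standard multi-head attention.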
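A NumPy sketch of grouped-query attention for a single sequence (no causal mask, no projections), assuming consecutive query heads form a group. It covers both extremes: `n_kv_heads == n_heads` is ordinary multi-head attention and `n_kv_heads == 1` is MQA.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def grouped_query_attention(q, k, v):
    """q: (n_heads, seq, d); k, v: (n_kv_heads, seq, d).

    Each group of n_heads // n_kv_heads consecutive query heads shares one K/V head.
    """
    n_heads, n_kv = q.shape[0], k.shape[0]
    assert n_heads % n_kv == 0, "query heads must divide evenly into groups"
    group = n_heads // n_kv
    # expand each shared K/V head to every query head in its group (no new parameters)
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v
```

Only `n_kv_heads` K/V heads need to live in the KV cache; the `np.repeat` expansion happens at compute time (optimized kernels can index the shared head instead of copying it).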
Activation checkpointing
When training a transformer, the biggest memory consumer is often not the model weights but the activations. During the forward pass, each layer produces activations that must be kept for the backward pass.
With activation checkpointing, we store only a few checkpoints during the forward pass (e.g. the inputs to some of the layers) and recompute the intermediate activations when the backward pass reaches them. This trades extra compute (roughly one additional forward pass) for much lower activation memory.
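A toy illustration of checkpointing on a chain of `tanh(w * x)` layers (the layer choice and helper names are hypothetical). The checkpointed version keeps only the input to every `every`-th layer and, during the backward pass, re-runs the forward pass of one segment at a time to get the activations it needs; its gradient matches the store-everything baseline.

```python
import numpy as np


def layer(x, w):
    return np.tanh(w * x)


def layer_backward(x, w, grad_out):
    # d/dx tanh(w * x) = w * (1 - tanh(w * x)^2): needs the layer's input x
    y = np.tanh(w * x)
    return grad_out * w * (1.0 - y ** 2)


def grad_input_full(x, weights, grad_out):
    """Baseline: keep every activation from the forward pass."""
    acts = [x]
    for w in weights:
        acts.append(layer(acts[-1], w))
    for i in reversed(range(len(weights))):
        grad_out = layer_backward(acts[i], weights[i], grad_out)
    return grad_out


def grad_input_checkpointed(x, weights, grad_out, every):
    """Keep only every `every`-th activation; recompute the rest one segment at a time."""
    n = len(weights)
    checkpoints = {0: x}
    for i, w in enumerate(weights):
        x = layer(x, w)
        if (i + 1) % every == 0:
            checkpoints[i + 1] = x
    for start in reversed(range(0, n, every)):
        end = min(start + every, n)
        # re-run the forward pass for this segment only, starting from its checkpoint
        acts = [checkpoints[start]]
        for i in range(start, end - 1):
            acts.append(layer(acts[-1], weights[i]))
        for i in reversed(range(start, end)):
            grad_out = layer_backward(acts[i - start], weights[i], grad_out)
    return grad_out
```

Peak storage is about `L / every` checkpoints plus `every` recomputed activations for L layers, so choosing `every` near sqrt(L) keeps it around O(sqrt(L)) instead of O(L).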
Sequence packing
Every input in a batch must have the same sequence length, but real datasets contain sequences of different lengths. To make all inputs equal length, we pad the shorter ones with dummy <pad> tokens.
If one sequence is very long, all the others are padded up to its length, which wastes compute (up to 45%) on meaningless pad tokens.
So instead of padding every sequence up to the maximum length, we concatenate multiple short sequences into one packed sequence of the same fixed length.
"The sky is blue." → 4 tokens
"Hello" → 1 token
"A long paragraph" → 3 tokens
We pack them into a single sequence of length 8:
Packed sequence: [The, sky, is, blue, Hello, A, long, paragraph]
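A minimal sketch of packing, assuming already-tokenized inputs and a greedy first-fit strategy (other bin-packing heuristics work too). Besides the tokens, it returns a segment id per position, because packed sequences should not attend to each other: `packed_causal_mask` builds a block-diagonal causal mask from those ids. Both function names are hypothetical.

```python
def pack_sequences(seqs, max_len, pad="<pad>"):
    """Greedy first-fit packing of token lists into rows of length max_len.

    Returns a list of (tokens, segment_ids) rows; padding gets segment id -1.
    """
    rows = []  # each row is a list of (seq_id, tokens)
    for sid, toks in enumerate(seqs):
        if len(toks) > max_len:
            raise ValueError(f"sequence {sid} is longer than max_len")
        for row in rows:
            if sum(len(t) for _, t in row) + len(toks) <= max_len:
                row.append((sid, toks))
                break
        else:  # no existing row has room: open a new one
            rows.append([(sid, toks)])
    packed = []
    for row in rows:
        tokens = [t for _, toks in row for t in toks]
        seg = [sid for sid, toks in row for _ in toks]
        n_pad = max_len - len(tokens)
        packed.append((tokens + [pad] * n_pad, seg + [-1] * n_pad))
    return packed


def packed_causal_mask(seg):
    """mask[i][j] is True if token i may attend to token j: same sequence and j <= i."""
    return [[seg[i] == seg[j] != -1 and j <= i for j in range(len(seg))]
            for i in range(len(seg))]
```

With the three example sentences and `max_len=8`, everything fits in one row with no padding, and the mask stops "A long paragraph" from attending to "Hello" or "The sky is blue".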
Inference optimization techniques
KV caching
A KV cache stores the previously computed key/value vectors for earlier tokens. When generating a new token, the model computes K and V only for that token and attends over the cached ones, instead of recomputing K and V for the entire history, which makes inference much faster.
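A single-head, single-layer NumPy sketch (random projection matrices `W_q`, `W_k`, `W_v` and toy inputs are assumptions) showing that caching K/V gives the same outputs as recomputing them for the whole prefix at every step, while projecting only one new token per step. In a real model the next input depends on the previous output; here the inputs are fixed for simplicity.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16
# hypothetical projection weights for one attention head
W_q, W_k, W_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))


def attend(q, K, V):
    s = K @ q / np.sqrt(d)
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V


def decode_without_cache(X):
    """Recompute K and V for the whole prefix at every step."""
    outs = []
    for t in range(len(X)):
        K, V = X[:t + 1] @ W_k, X[:t + 1] @ W_v
        outs.append(attend(X[t] @ W_q, K, V))
    return np.array(outs)


def decode_with_cache(X):
    """Project only the newest token each step and append it to the cache."""
    k_cache, v_cache, outs = [], [], []
    for x in X:
        k_cache.append(x @ W_k)
        v_cache.append(x @ W_v)
        outs.append(attend(x @ W_q, np.array(k_cache), np.array(v_cache)))
    return np.array(outs)
```

Per step, the uncached version does O(t) key/value projections while the cached one does O(1); the attention over t cached entries remains in both.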
KV caching optimizations
Stateful caching
Instead of discarding the KV cache after each request, you persist it across multiple requests/user interactions and reuse it whenever a new request shares a context prefix with earlier ones.
Compute a hash for every prefix of a query.
For a new user query, compute its prefix hashes, then look up the longest prefix that was previously cached; if one is found, reuse that prefix's KV cache and compute KV only for the remaining tokens.
The cached prefixes are stored in a prefix tree, and since cache memory is bounded, this is combined with LRU eviction to drop old or rarely used branches.
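A sketch of such a prefix cache, under several simplifying assumptions: the tree is keyed token by token (the dict lookups play the role of prefix hashes), each node's `kv` is an opaque placeholder for that token's cached K/V tensors, and capacity is counted in cached tokens. Eviction removes the least recently used leaf, so a shared prefix is only dropped after every branch below it is gone. Class and method names are hypothetical.

```python
import itertools


class _Node:
    def __init__(self, parent=None, token=None, kv=None):
        self.parent, self.token, self.kv = parent, token, kv
        self.children = {}
        self.last_used = 0


class PrefixKVCache:
    """Token-level prefix tree of cached KV entries with LRU eviction of leaves."""

    def __init__(self, capacity):
        self.root = _Node()
        self.capacity = capacity          # max number of cached tokens (nodes)
        self.size = 0
        self._clock = itertools.count(1)  # logical time for LRU bookkeeping

    def match(self, tokens):
        """Return the cached KV entries for the longest cached prefix of `tokens`."""
        now, node, kvs = next(self._clock), self.root, []
        for t in tokens:
            node = node.children.get(t)
            if node is None:
                break
            node.last_used = now
            kvs.append(node.kv)
        return kvs

    def insert(self, tokens, kvs):
        now, node = next(self._clock), self.root
        for t, kv in zip(tokens, kvs):
            child = node.children.get(t)
            if child is None:
                child = node.children[t] = _Node(node, t, kv)
                self.size += 1
            child.last_used = now
            node = child
        self._evict()

    def _leaves(self):
        stack = list(self.root.children.values())
        while stack:
            n = stack.pop()
            if n.children:
                stack.extend(n.children.values())
            else:
                yield n

    def _evict(self):
        # drop least-recently-used leaves until we fit; an interior node becomes
        # a leaf (and thus evictable) once all of its children are gone
        while self.size > self.capacity:
            lru = min(self._leaves(), key=lambda n: n.last_used)
            del lru.parent.children[lru.token]
            self.size -= 1
```

For example, with capacity 6, caching `[1, 2, 3, 4]` and `[1, 2, 5, 6]` shares the nodes for `1, 2`; a later request starting with `[1, 2, 3]` reuses three cached entries and only needs fresh KV for what follows.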
Speculative decoding
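A smaller draft model proposes several tokens ahead, and the larger target model verifies them in a single parallel forward pass, keeping the tokens it agrees with. This achieves 2-3x inference speedups without changing the target model's output distribution.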
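A minimal sketch of the greedy variant of speculative decoding (sampling-based variants use a rejection-sampling acceptance rule instead). `toy_target` and `toy_draft` are hypothetical stand-ins for the two models' greedy next-token choices. The draft proposes `k` tokens; the target checks all `k + 1` positions (one batched forward pass on real hardware, simulated here with a loop), keeps the agreeing prefix, and contributes one token of its own, so the output is identical to greedy decoding with the target alone.

```python
def greedy_decode(next_token, prompt, n_new):
    seq = list(prompt)
    for _ in range(n_new):
        seq.append(next_token(seq))
    return seq


def speculative_decode(target_next, draft_next, prompt, n_new, k=4):
    """Greedy speculative decoding; returns (tokens, number of target verification passes)."""
    seq, passes = list(prompt), 0
    while len(seq) - len(prompt) < n_new:
        draft = []
        for _ in range(k):                       # cheap model proposes k tokens
            draft.append(draft_next(seq + draft))
        passes += 1                              # one batched target pass in practice
        accepted = []
        for i in range(k + 1):
            t = target_next(seq + draft[:i])     # target's choice at position i
            accepted.append(t)
            if i == k or draft[i] != t:          # first mismatch (or bonus token): stop
                break
        seq.extend(accepted)
    return seq[:len(prompt) + n_new], passes


def toy_target(seq):  # hypothetical large model's greedy next token
    return (3 * seq[-1] + len(seq)) % 10


def toy_draft(seq):   # hypothetical draft model: disagrees at every 5th position
    t = toy_target(seq)
    return t if len(seq) % 5 else (t + 1) % 10
```

The speedup comes from `passes` being smaller than the number of generated tokens: each verification pass yields between 1 and `k + 1` tokens, depending on how often the draft agrees with the target.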
Quantization techniques
Compressing a model by representing weights and/or activations with fewer bits (e.g. int8 or int4) instead of standard fp32.
Quantization Types
Post Training Quantization
After training, weights are converted to lower precision without any further training. PTQ is very cheap to implement compared to retraining, but some accuracy drop afterwards is common.
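A minimal PTQ sketch for a weight matrix, assuming symmetric int8 quantization with one scale per output channel (row) and no calibration data; activation quantization, which typically needs calibration, is not shown. The function names are hypothetical.

```python
import numpy as np


def quantize_int8_per_channel(W):
    """Symmetric per-output-channel (per-row) int8 quantization of a weight matrix."""
    scale = np.abs(W).max(axis=1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)      # avoid dividing by zero for all-zero rows
    q = np.clip(np.round(W / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)


def dequantize(q, scale):
    return q.astype(np.float32) * scale
```

Storage drops 4x relative to fp32 (plus one scale per row), and each de-quantized weight is within half a quantization step (`scale / 2`) of the original.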
Mixed precision Quantization
Instead of forcing every weight and activation to the same low precision, mixed-precision quantization assigns different bit widths/precisions to different parts of the model. Deciding which layers to quantize, and how aggressively, requires profiling, and not all devices support mixed precision efficiently.
Quantization aware training
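QAT simulates quantization effects during training or fine-tuning. In the forward pass, weights and activations are fake-quantized: quantized and then immediately de-quantized back to float, so the loss sees the rounding error while gradients are still computed in floating point (the non-differentiable rounding is typically handled with a straight-through estimator). This way the model learns to adapt to quantization errors during training.
Then, at inference time, you apply real quantization; the model is already robust to the lower precision.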
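A toy QAT sketch on linear regression, assuming 4-bit symmetric per-tensor fake quantization (helper names are hypothetical). The forward pass uses the fake-quantized weights, and the gradient with respect to them is applied directly to the latent float weights, i.e. rounding is treated as identity in the backward pass (the straight-through estimator).

```python
import numpy as np


def fake_quant(w, n_bits=4):
    """Quantize to a symmetric n-bit grid and immediately de-quantize back to float."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = max(np.abs(w).max() / qmax, 1e-8)     # guard against an all-zero tensor
    return np.clip(np.round(w / scale), -qmax, qmax) * scale


def train_linear(X, y, qat, n_bits=4, lr=0.1, steps=500):
    w = np.zeros(X.shape[1])                       # latent float weights
    for _ in range(steps):
        w_used = fake_quant(w, n_bits) if qat else w   # QAT: forward sees quantized weights
        grad = 2 * X.T @ (X @ w_used - y) / len(y)     # d(MSE)/d(w_used)
        w -= lr * grad    # straight-through estimator: d(w_used)/dw treated as identity
    return w
```

After training, `fake_quant(w)` (or its integer form) is what gets deployed; with `qat=False` the same loop is ordinary float training, which you could follow with PTQ for comparison.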
Training optimization
Data parallelism
Instead of training a model on one GPU with one batch at a time, you replicate the entire model on multiple GPUs/processes and split each input batch into smaller micro-batches, assigning each micro-batch to a different GPU. Each GPU runs a forward and backward pass on its subset independently.
After computing gradients locally, the GPUs must synchronize: the per-GPU gradients are aggregated (summed or averaged) and used to update the model parameters. After the sync and update, every replica holds the same parameter values, and the next iteration proceeds.
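A synchronous data-parallel step simulated on one machine, for a linear model with mean-squared-error loss (a hypothetical setup). With equal-sized shards, averaging the per-worker gradients gives exactly the full-batch gradient, which is why every replica can apply the same update and stay in sync.

```python
import numpy as np


def mse_grad(w, X, y):
    # gradient of mean((X @ w - y)^2) with respect to w
    return 2 * X.T @ (X @ w - y) / len(y)


def data_parallel_step(w, X, y, n_workers, lr=0.1):
    """One synchronous data-parallel SGD step, simulated on one machine."""
    shards = zip(np.array_split(X, n_workers), np.array_split(y, n_workers))
    local = [mse_grad(w, Xs, ys) for Xs, ys in shards]  # each "GPU" on its own shard
    g = np.mean(local, axis=0)                            # the gradient sync (average)
    return w - lr * g                                     # identical update on every replica
```

If the shards had different sizes, a plain average would over-weight the smaller shards; a weighted average (by shard size) would be needed to match the full-batch gradient.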
The main bottleneck of this single-process design (e.g. PyTorch's nn.DataParallel) is that one multi-threaded process drives all GPUs, which makes inter-GPU communication inefficient.
Synchronization approaches
At the end of each mini-batch, workers need to synchronize gradients. In the single-process setup, Python's GIL (global interpreter lock) causes CPU bottlenecks, and the overhead gets worse as the number of GPUs increases.
Bulk synchronous Parallel (BSP)
All workers process their micro-batches independently. Once gradients are computed, they synchronize/share them via an all-reduce operation that computes the global average gradient. The overall iteration time is bounded by the slowest worker: all GPUs must wait until it finishes before syncing.
Asynchronous Parallel (ASP)
Each worker processes data and computes gradients independently. Rather than waiting for all workers, a worker sends its update as soon as it's done, with no waiting for slower workers.
A parameter server (aggregator) applies gradient updates as soon as they arrive from any worker.
Downside: stale gradients/stale weights. By the time a slow worker's update arrives, the model parameters may have already changed, so the update is based on an outdated state, which can degrade convergence stability.
Because BSP gives stable convergence but suffers from stragglers, while ASP is fast but converges less stably, researchers proposed hybrid schemes (e.g. Stale Synchronous Parallel) that combine the benefits by allowing some asynchrony while bounding staleness.
Distributed Data parallel (DDP)
Each GPU gets its own process (instead of one process driving multiple GPUs with threads).
The model is replicated in every process/GPU at the start, with parameters broadcast from a root process. Input data is sharded so each GPU gets a different mini-batch, and each GPU runs forward and backward locally on its data subset. After the backward pass, DDP averages gradients with an all-reduce, typically implemented as a ring all-reduce to avoid a central bottleneck.
Ring all reduce algorithm
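DDP uses a collective communication primitive, typically all-reduce, to sum gradients across all GPUs after the backward pass and give every GPU the result.
The ring algorithm arranges the N GPUs in a logical ring.
In the reduce-scatter phase, each GPU splits its gradient tensor into N chunks. Over N-1 steps, chunks are passed around the ring and summed cumulatively; at the end, each GPU holds the complete sum for one chunk. An all-gather phase (another N-1 steps) then circulates the reduced chunks so that every GPU ends up with the full summed gradient.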
Example with 4 GPUs, where each element below is one chunk of the gradient vector:
GPUA: [a1, a2, a3, a4]
GPUB: [b1, b2, b3, b4]
GPUC: [c1, c2, c3, c4]
GPUD: [d1, d2, d3, d4]
after 1 iteration
GPUA: [a1, a2, a3, a4 + d4]
GPUB: [a1 + b1, b2, b3, b4]
GPUC: [c1, c2 + b2, c3, c4]
GPUD: [d1, d2, d3 + c3, d4]
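after 3 iterations (end of reduce-scatter: each GPU holds the complete sum of one chunk; - = partial sums no longer needed)
GPUA: [ - , a2 + b2 + c2 + d2, - , - ]
GPUB: [ - , - , a3 + b3 + c3 + d3, - ]
GPUC: [ - , - , - , a4 + b4 + c4 + d4]
GPUD: [a1 + b1 + c1 + d1, - , - , - ]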
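A single-process simulation of ring all-reduce (the function name is hypothetical) with both phases: reduce-scatter, then all-gather. The send schedule (worker `r` sends chunk `(r - s) % n` to worker `r + 1` at step `s`) reproduces the step-1 table, and after reduce-scatter worker `r` holds the full sum of chunk `(r + 1) % n`.

```python
import numpy as np


def ring_all_reduce(grads):
    """Simulate ring all-reduce over N workers; grads[r] is worker r's gradient vector."""
    n = len(grads)
    chunks = [np.array_split(g.astype(float), n) for g in grads]  # chunks[r][c]
    # Phase 1, reduce-scatter: at step s, worker r sends chunk (r - s) % n to
    # worker (r + 1) % n, which adds it to its own copy. Sends within a step are
    # simultaneous, so every message is copied before any worker applies one.
    for s in range(n - 1):
        msgs = [(r, (r - s) % n, chunks[r][(r - s) % n].copy()) for r in range(n)]
        for r, c, data in msgs:
            chunks[(r + 1) % n][c] += data
    # now worker r holds the fully reduced chunk (r + 1) % n
    # Phase 2, all-gather: pass the reduced chunks around the ring, overwriting stale copies
    for s in range(n - 1):
        msgs = [(r, (r + 1 - s) % n, chunks[r][(r + 1 - s) % n].copy()) for r in range(n)]
        for r, c, data in msgs:
            chunks[(r + 1) % n][c] = data
    return [np.concatenate(c) for c in chunks]
```

Each worker sends 2(N-1) chunks, each 1/N of the tensor, so per-GPU traffic stays roughly constant as N grows, which is why the ring avoids a central bottleneck.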
was researching the templates classification pipeline integration
long meeting (over an hour) on system design approaches
edit document with bella agent responses and tweak chunk sizes