Model training and optimization
Understanding GPU Memory Anatomy (Hugging Face Guide). Reference: https://huggingface.co/docs/transformers/model_memory_anatomy
The background on training memory anatomy comes from the Hugging Face documentation. Training large language models is not only about compute; it is fundamentally a memory management problem.
Most training failures (OOM, instability, slow throughput) happen because we misunderstand what actually lives in GPU memory during training.
This post explains the anatomy of training memory and the practical optimization techniques built around it.
1. The Anatomy of GPU Memory During Training
According to the Hugging Face documentation, GPU memory during training is composed of the following components:
| Component | Purpose |
|---|---|
| Model weights | Parameters of the neural network |
| Optimizer states | Momentum / variance statistics (e.g., Adam) |
| Gradients | Stored during backpropagation |
| Forward activations | Saved tensors for gradient computation |
| Temporary buffers | Workspace for kernels |
| Functionality-specific memory | KV cache, attention masks, etc. |
Understanding how each scales with model size and batch size is the key to scaling training.
2. How Memory Scales
Model Weights
For a model with parameter count N:
- FP32: 4N bytes
- FP16/BF16: 2N bytes
- 8-bit: N bytes
- 4-bit: 0.5N bytes
Weights are static — they do not depend on batch size.
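As a quick illustration, here is a back-of-envelope sketch of weight memory at each precision. The 7B parameter count is just an example value, not tied to any specific model:

```python
# Back-of-envelope weight memory for a hypothetical 7B-parameter model.
BYTES_PER_PARAM = {"fp32": 4, "fp16/bf16": 2, "int8": 1, "int4": 0.5}

def weight_memory_gb(num_params: float, precision: str) -> float:
    """Weight memory in GB (1 GB = 1e9 bytes) for a given precision."""
    return num_params * BYTES_PER_PARAM[precision] / 1e9

N = 7e9  # assumed parameter count, e.g. a 7B model
for precision in BYTES_PER_PARAM:
    print(f"{precision:>9}: {weight_memory_gb(N, precision):6.1f} GB")
```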
Optimizer States
Adam keeps two statistics per parameter (first moment and second moment):
$$
\text{Optimizer memory} \approx 2 \times \text{Weight memory (fp32)}
$$
So standard Adam/AdamW training roughly costs:
weights + gradients + optimizer states ≈ 16 bytes per parameter in fp32 training, or about 18 bytes per parameter in mixed precision (fp16/bf16 weights plus an fp32 master copy)
This is usually the largest fixed cost.
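A minimal sketch of this fixed cost, following the per-parameter byte accounting used in the Hugging Face guide (the exact split can vary with the optimizer and mixed-precision recipe):

```python
# Per-parameter fixed cost (bytes): weights + gradients + AdamW states.
# Activation memory comes on top of this.
def fixed_cost_gb(num_params: float, mixed_precision: bool = True) -> float:
    weights = 6 if mixed_precision else 4   # fp16/bf16 copy + fp32 master vs. fp32 only
    gradients = 4                           # kept in fp32 for the update
    optimizer = 8                           # AdamW: two fp32 states per parameter
    return num_params * (weights + gradients + optimizer) / 1e9

print(f"7B model, mixed precision: {fixed_cost_gb(7e9):.0f} GB")        # ~126 GB
print(f"7B model, pure fp32:       {fixed_cost_gb(7e9, False):.0f} GB") # ~112 GB
```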
Gradients
Gradients have the same size as weights:
$$
\text{Gradient memory} \approx \text{Weight memory}
$$
They are created during the backward pass and held until the optimizer step, so they must fit in memory alongside everything else.
Forward Activations (The Hidden Giant)
Activations scale with:
$$
\text{Batch} \times \text{Sequence Length} \times \text{Hidden Size} \times \text{Layers}
$$
This is the only component that grows with batch size and often dominates memory usage for long context training.
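A rough estimator makes the scaling concrete. The per-element byte count and the "tensors per layer" multiplier below are illustrative assumptions (the true constant depends on the architecture and the attention implementation), but the proportionality to batch and context is the point:

```python
# Very rough activation-memory estimate for a transformer, ignoring the
# seq x seq attention matrix (which FlashAttention avoids materializing).
def activation_memory_gb(batch, seq_len, hidden, layers,
                         bytes_per_element=2, tensors_per_layer=16):
    elements = batch * seq_len * hidden * layers * tensors_per_layer
    return elements * bytes_per_element / 1e9

# Doubling the batch or the context doubles this number; the fixed costs don't move.
print(activation_memory_gb(batch=8, seq_len=4096, hidden=4096, layers=32))
```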
Temporary Buffers
Used by kernels such as matrix multiplication and attention.
Typically small, but usage can spike during attention computation.
Functionality-Specific Memory
Examples:
- attention masks
- KV cache
- intermediate logits
- layernorm stats
Often overlooked but important for long context training.
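The KV cache is mostly an inference-time cost, but the same arithmetic applies whenever cached keys and values are kept around. A sketch with a hypothetical 7B-style configuration (layer count, KV heads, and head dimension are example values):

```python
# KV-cache size: K and V tensors per layer, per token, per KV head.
def kv_cache_gb(batch, seq_len, layers, num_kv_heads, head_dim, bytes_per_element=2):
    return 2 * batch * seq_len * layers * num_kv_heads * head_dim * bytes_per_element / 1e9

# ~4 GB for a single 32k-token sequence under these assumed dimensions.
print(kv_cache_gb(batch=1, seq_len=32_000, layers=32, num_kv_heads=8, head_dim=128))
```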
3. Why Training Runs Out of Memory
Most people assume weights dominate memory. In practice:
- For long sequences or large batches, activations dominate.
- For large models trained with small batches, optimizer states and weights dominate.
So the right optimization strategy depends on which component is the bottleneck.
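A quick way to see which component dominates in your own run is to snapshot PyTorch's allocator counters around each phase. A self-contained sketch with a toy model (substitute your real model, batch, and optimizer):

```python
import torch
import torch.nn as nn

# Toy setup so the sketch runs on its own; requires a CUDA device.
device = "cuda"
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)]).to(device)
optimizer = torch.optim.AdamW(model.parameters())
batch = torch.randn(32, 4096, device=device)

def snapshot(label):
    print(f"{label:>8}: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated, "
          f"{torch.cuda.max_memory_allocated() / 1e9:.2f} GB peak")

snapshot("weights")            # weights only
loss = model(batch).pow(2).mean()
snapshot("forward")            # + activations
loss.backward()
snapshot("backward")           # + gradients (most activations freed)
optimizer.step()
snapshot("step")               # + optimizer states (allocated on first step)
```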
4. Optimization Techniques (Mapped to Memory Type)
Reduce Weight Memory
Low precision training
- FP16 / BF16
- 8-bit optimizer
- 4-bit quantization (QLoRA)
Parameter-efficient tuning
- LoRA
- adapters
These target the static memory cost.
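A minimal sketch combining both ideas with the transformers and peft libraries: load the base weights in bf16 and train only small LoRA adapters. The checkpoint name and the target module names are examples and depend on the model you actually use:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model  # PEFT library for LoRA

# Base weights in bf16 (half the memory of fp32); only adapters are trained.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",          # example checkpoint
    torch_dtype=torch.bfloat16,
)
lora = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora)
model.print_trainable_parameters()       # typically well under 1% of the base model
```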
Reduce Optimizer Memory
ZeRO Stage 1 / 2
- shard optimizer states
8-bit optimizers
- compress statistics
Best for multi-GPU setups.
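For the 8-bit route, bitsandbytes offers drop-in optimizer replacements; a minimal sketch, assuming bitsandbytes is installed and a GPU is available:

```python
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(4096, 4096).cuda()  # stand-in for your real model

# Drop-in replacement for torch.optim.AdamW that stores the two Adam
# statistics in 8-bit instead of fp32 (~2 bytes/param instead of ~8).
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-4)
```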
Reduce Gradient Memory
Gradient accumulation
- smaller microbatch
FSDP / ZeRO Stage 2
- shard gradients across devices
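Gradient accumulation keeps the micro-batch (and its activations) small while preserving the effective batch size. A minimal PyTorch sketch with a toy model and synthetic data:

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters())
accum_steps = 8  # effective batch = micro-batch size x accum_steps

for step in range(100):
    micro_batch = torch.randn(4, 1024, device="cuda")      # small micro-batch
    loss = model(micro_batch).pow(2).mean() / accum_steps   # scale so the sum averages correctly
    loss.backward()                                          # gradients accumulate in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()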
Reduce Activation Memory (Most Important)
Gradient checkpointing
- recompute forward during backward
Tradeoff:
$$
\text{Less memory} \leftrightarrow \text{More compute}
$$
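A minimal sketch of checkpointing with PyTorch's built-in utility on a toy stack of blocks (Hugging Face models wrap the same idea behind `model.gradient_checkpointing_enable()`):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

blocks = nn.ModuleList(
    [nn.Sequential(nn.Linear(2048, 2048), nn.GELU()) for _ in range(12)]
).cuda()
x = torch.randn(8, 2048, device="cuda", requires_grad=True)

# Each block's activations are discarded after the forward pass and
# recomputed during backward: less memory, roughly one extra forward of compute.
for block in blocks:
    x = checkpoint(block, x, use_reentrant=False)
x.sum().backward()
```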
Sequence packing
- better token utilization
Flash Attention
- compute attention without storing full matrix
This is the main technique enabling long-context training.
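In PyTorch, scaled dot-product attention (SDPA) dispatches to a FlashAttention-style kernel when one is available, so the full score matrix is never materialized. A small sketch with made-up tensor shapes:

```python
import torch
import torch.nn.functional as F

# Memory-efficient attention: the (seq x seq) score matrix is never stored.
q = torch.randn(2, 16, 8192, 64, device="cuda", dtype=torch.bfloat16)  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# With Hugging Face models, the equivalent switch is usually
# attn_implementation="flash_attention_2" (or "sdpa") in from_pretrained.
```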
Reduce Temporary Buffers
Fused kernels
- FlashAttention
- xFormers
Improve both memory and speed.
Offloading
Move tensors to CPU/NVMe:
- ZeRO Offload
- FSDP CPU offload
Good when GPU memory is limited but bandwidth is available.
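A sketch of FSDP parameter offloading; this assumes the script is launched with torchrun, that each rank has a GPU, and uses a toy model in place of a real one:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(16)])

# Parameters live on CPU and are streamed to the GPU shard-by-shard as each
# layer is needed: slower, but fits much larger models in limited GPU memory.
model = FSDP(model,
             cpu_offload=CPUOffload(offload_params=True),
             device_id=torch.cuda.current_device())
```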
5. Practical Mental Model
During training, memory roughly divides into:
| Category | Scales With |
|---|---|
| Weights | Model size |
| Optimizer states | Model size |
| Gradients | Model size |
| Activations | Batch × Context length |
Therefore:
- Small model, long context → optimize activations
- Huge model, short context → optimize optimizer states
6. Final Takeaway
Training large models is a balance between three resources:
$$
\text{Memory} \leftrightarrow \text{Compute} \leftrightarrow \text{Communication}
$$
Most modern techniques work by trading one for another:
- Checkpointing trades compute for memory
- Sharding trades communication for memory
- Quantization trades precision for memory
Understanding the anatomy of memory makes training predictable instead of trial-and-error.