Model training and optimization

Understanding GPU Memory Anatomy (Hugging Face Guide). Reference: https://huggingface.co/docs/transformers/model_memory_anatomy

The background on training memory anatomy comes from the Hugging Face documentation linked above. Training large language models is not only about compute: it is fundamentally a memory management problem.
Most training failures (out-of-memory errors, instability, slow throughput) happen because we misunderstand what actually lives in GPU memory during training.

This post explains the anatomy of training memory and the practical optimization techniques built around it.

1. The Anatomy of GPU Memory During Training

According to the Hugging Face documentation, GPU memory during training is composed of the following components:

| Component | Purpose |
| --- | --- |
| Model weights | Parameters of the neural network |
| Optimizer states | Momentum / variance statistics (e.g., Adam) |
| Gradients | Stored during backpropagation |
| Forward activations | Saved tensors for gradient computation |
| Temporary buffers | Workspace for kernels |
| Functionality-specific memory | KV cache, attention masks, etc. |

Understanding how each scales with model size and batch size is the key to scaling training.
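To make these components concrete, here is a minimal sketch (a toy model chosen for illustration, not code from the guide) that watches each category appear in PyTorch's CUDA memory counters as one training step proceeds:

```python
# Watch weights, activations, gradients, and optimizer states appear on the GPU.
import torch
import torch.nn as nn

def gb(num_bytes):  # bytes -> GiB, for readable printouts
    return num_bytes / 2**30

device = "cuda"
torch.cuda.reset_peak_memory_stats(device)

# 1) Weights: allocated as soon as the model moves to the GPU.
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)]).to(device)
print(f"after weights:    {gb(torch.cuda.memory_allocated(device)):.2f} GiB")

# 2) Activations: saved during the forward pass and grow with batch size.
x = torch.randn(32, 4096, device=device)
out = model(x).mean()
print(f"after forward:    {gb(torch.cuda.memory_allocated(device)):.2f} GiB")

# 3) Gradients: allocated during backward, same size as the weights.
out.backward()
print(f"after backward:   {gb(torch.cuda.memory_allocated(device)):.2f} GiB")

# 4) Optimizer states: AdamW allocates two extra tensors per parameter,
#    lazily, on the first call to step().
opt = torch.optim.AdamW(model.parameters())
opt.step()
print(f"after opt.step(): {gb(torch.cuda.memory_allocated(device)):.2f} GiB")
print(f"peak:             {gb(torch.cuda.max_memory_allocated(device)):.2f} GiB")
```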


2. How Memory Scales

Model Weights

For a model with parameter count N:

  • FP32: 4N bytes
  • FP16/BF16: 2N bytes
  • 8-bit: 1N bytes
  • 4-bit: 0.5N bytes

Weights are static — they do not depend on batch size.
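For intuition, a back-of-the-envelope calculation for a hypothetical 7B-parameter model (the parameter count is an assumption chosen only for illustration):

```python
# Weight memory for a hypothetical 7B-parameter model at different precisions.
N = 7e9  # parameter count (illustrative)
bytes_per_param = {"fp32": 4, "fp16/bf16": 2, "8-bit": 1, "4-bit": 0.5}

for precision, b in bytes_per_param.items():
    print(f"{precision:>10}: {N * b / 2**30:.1f} GiB")
# fp32 ≈ 26.1 GiB, fp16/bf16 ≈ 13.0 GiB, 8-bit ≈ 6.5 GiB, 4-bit ≈ 3.3 GiB
```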


Optimizer States

Adam keeps two statistics per parameter (the first moment and the second moment, i.e., the running variance):

$$
\text{Memory}_{\text{optimizer}} \approx 2 \times \text{Memory}_{\text{weights}}
$$

So Adam training roughly costs:

weights + gradients + optimizer states ≈ 8 bytes per parameter for pure FP16 training (about 4× the FP16 weight footprint), or 16 bytes per parameter for FP32 training (about 8×)

This is usually the largest fixed cost.
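A quick sketch of that arithmetic for the same hypothetical 7B-parameter model, ignoring activations and temporary buffers (real frameworks differ in which states they keep in which precision):

```python
# Rough fixed training footprint for a hypothetical 7B-parameter model with Adam.
N = 7e9  # parameter count (illustrative)

def total_gib(weight_bytes, grad_bytes, optim_bytes):
    # bytes per parameter for each component -> total GiB
    return N * (weight_bytes + grad_bytes + optim_bytes) / 2**30

# Everything in FP32: 4 + 4 + 8 = 16 bytes/param
print(f"FP32 Adam: {total_gib(4, 4, 8):.0f} GiB")   # ~104 GiB
# Everything in 16-bit: 2 + 2 + 4 = 8 bytes/param
# (uncommon in practice; mixed precision adds FP32 master copies and costs more)
print(f"FP16 Adam: {total_gib(2, 2, 4):.0f} GiB")   # ~52 GiB
```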


Gradients

Gradients have the same size as weights:

$$
\text{Memory}_{\text{gradients}} \approx \text{Memory}_{\text{weights}}
$$

They exist only during the backward pass, but they still have to fit in memory alongside the weights and optimizer states.


Forward Activations (The Hidden Giant)

Activations scale with:

$$
\text{Memory}_{\text{activations}} \propto \text{batch size} \times \text{sequence length} \times \text{hidden size} \times \text{layers}
$$

This is the only component that grows with batch size and often dominates memory usage for long context training.
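A rough estimator of this scaling, counting only the saved per-layer hidden states in 16-bit for a hypothetical 32-layer, 4096-hidden model; real activation footprints are several times larger because attention and MLP blocks save extra intermediates:

```python
# Very rough lower bound on activation memory: one hidden-state tensor per layer.
def activation_gib(batch, seq_len, hidden, layers, bytes_per_value=2):
    return batch * seq_len * hidden * layers * bytes_per_value / 2**30

# Hypothetical 32-layer, 4096-hidden transformer:
print(activation_gib(batch=8, seq_len=2048, hidden=4096, layers=32))  # ~4 GiB
print(activation_gib(batch=8, seq_len=8192, hidden=4096, layers=32))  # ~16 GiB (grows linearly with context)
```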


Temporary Buffers

Used by kernels such as matrix multiplication and attention.

Typically small, but can spike during attention computation.

Functionality-Specific Memory

Examples:

  • attention masks
  • KV cache
  • intermediate logits
  • layernorm stats

Often overlooked but important for long context training.


3. Why Training Runs Out of Memory

Most people assume weights dominate memory.

In practice:

  • For long sequences (and large batches), activations dominate.
  • For small batches and short sequences, the optimizer states, together with weights and gradients, dominate.

So optimization strategies depend on the bottleneck type.


4. Optimization Techniques (Mapped to Memory Type)

Reduce Weight Memory

Low precision training

  • FP16 / BF16
  • 8-bit optimizer
  • 4-bit quantization (QLoRA)

Parameter-efficient tuning

  • LoRA
  • adapters

These target the static memory cost.
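A minimal QLoRA-style sketch using transformers, peft, and bitsandbytes; the model name and LoRA hyperparameters are placeholders, and `target_modules` depends on the architecture:

```python
# Load a 4-bit base model and train only small LoRA adapters on top of it.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # 4-bit NF4 base weights
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bf16
)

model = AutoModelForCausalLM.from_pretrained(
    "your-base-model",                      # placeholder model name
    quantization_config=bnb_config,
    device_map="auto",
)

lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],    # depends on the architecture
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()          # only the LoRA adapters are trained
```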


Reduce Optimizer Memory

ZeRO Stage 1 / 2

  • shard optimizer states

8-bit optimizers

  • compress statistics

Best for multi-GPU setups.
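A sketch of swapping in an 8-bit optimizer from bitsandbytes; the tiny model is a stand-in and the learning rate is illustrative:

```python
# 8-bit Adam stores its statistics in 8 bits, shrinking optimizer-state memory
# roughly 4x versus keeping them in fp32.
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(1024, 1024).cuda()  # stand-in for a real model
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=2e-4)
# Used exactly like torch.optim.AdamW in the training loop:
#   loss.backward(); optimizer.step(); optimizer.zero_grad()
```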


Reduce Gradient Memory

Gradient accumulation

  • run several smaller micro-batches and step the optimizer once per effective batch (see the sketch below)
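A plain-PyTorch gradient accumulation sketch with a toy model and random data; the micro-batch size and number of accumulation steps are illustrative:

```python
# 4 micro-batches of 8 behave like one batch of 32, but only one micro-batch
# of activations is alive at a time.
import torch
import torch.nn as nn

model = nn.Linear(512, 512)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
accum_steps = 4

optimizer.zero_grad()
for step in range(16):
    inputs = torch.randn(8, 512)                          # one small micro-batch
    targets = torch.randn(8, 512)
    loss = loss_fn(model(inputs), targets) / accum_steps  # scale so the accumulated gradient is an average
    loss.backward()                                       # gradients add up in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()                                  # one optimizer step per effective batch
        optimizer.zero_grad()
```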

FSDP / ZeRO Stage 2

  • shard gradients across devices

Reduce Activation Memory (Most Important)

Gradient checkpointing

  • recompute forward during backward

Tradeoff:

$$
\text{Less Memory} \leftrightarrow \text{More compute}
$$
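A sketch of turning on activation recomputation for a transformers model; the model name is a placeholder:

```python
# Enable gradient (activation) checkpointing on a Hugging Face model.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("your-base-model")  # placeholder
model.gradient_checkpointing_enable()   # forward activations are recomputed during backward
model.config.use_cache = False          # KV caching conflicts with checkpointing during training
# With Trainer, the same switch is TrainingArguments(gradient_checkpointing=True).
```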

Sequence packing

  • better token utilization

Flash Attention

  • computes attention without materializing the full attention score matrix

This is the main technique enabling long-context training.
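A sketch of requesting the FlashAttention-2 kernels when loading a model with transformers; this assumes the flash-attn package is installed and the GPU supports it, and the model name is a placeholder:

```python
# Use FlashAttention-2 kernels instead of the default attention implementation.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "your-base-model",                        # placeholder
    torch_dtype=torch.bfloat16,               # FlashAttention-2 requires fp16/bf16
    attn_implementation="flash_attention_2",
)
```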


Reduce Temporary Buffers

Fused kernels

  • FlashAttention
  • xFormers

Improve both memory and speed.


Offloading

Move tensors to CPU/NVMe:

  • ZeRO Offload
  • FSDP CPU offload

Good when GPU memory is limited but bandwidth is available.
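A minimal DeepSpeed ZeRO-2 configuration with optimizer offload to CPU, written as a Python dict; the values are illustrative and the "auto" entries rely on the Hugging Face Trainer integration to fill them in:

```python
# DeepSpeed ZeRO stage 2 with optimizer states offloaded to CPU memory.
ds_config = {
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}
# Pass via TrainingArguments(deepspeed=ds_config), or save as ds_config.json
# and point the deepspeed launcher at it.
```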


5. Practical Mental Model

During training, memory roughly divides into:

| Category | Scales with |
| --- | --- |
| Weights | Model size |
| Optimizer states | Model size |
| Gradients | Model size |
| Activations | Batch size × context length |

Therefore:

  • Small model, long context → optimize activations
  • Huge model, short context → optimize optimizer states

6. Final Takeaway

Training large models is a balance between three resources:

$$
\text{Memory} \leftrightarrow \text{Compute} \leftrightarrow \text{Communication}
$$

Most modern techniques work by trading one for another:

  • Checkpointing trades compute for memory
  • Sharding trades communication for memory
  • Quantization trades precision for memory

Understanding the anatomy of memory makes training predictable instead of trial-and-error.