CS231n Lecture Note: Large Scale Distributed Training

GPUs

Modern AI training is structured as a hierarchy: individual GPUs (or TPUs) sit inside servers, servers are grouped into racks, racks into pods, and pods form a cluster.

Instead of training on a single machine, large neural networks are distributed across this entire cluster, which effectively acts as one coordinated system.

The key challenge is not just computation, but efficiently splitting the workload and synchronizing thousands of accelerators so they stay utilized, using techniques like data and model parallelism while minimizing communication overhead.

Training on GPUs

A model with L layers operates on tensors of shape (Batch, Sequence, Dim)

Splitting the computation along these axes (plus the layer index) gives four parallelism strategies:

  1. Data Parallelism (DP): Split on Batch dimension
  2. Context Parallelism (CP): Split on Sequence dimension
  3. Pipeline Parallelism (PP): Split on L dimension
  4. Tensor Parallelism (TP): Split on Dim dimension

Data Parallelism

In standard data parallelism, a minibatch of N samples is split across M GPUs, so each GPU processes roughly N/M samples. Every GPU holds a full copy of the model parameters.

During the forward and backward pass, each GPU computes gradients independently on its local subset of data. Because the loss is additive over the batch, gradients are linear, so the correct global gradient is obtained by averaging (or summing) gradients across all GPUs.

After gradient synchronization (typically via all-reduce), all GPUs update their local model copies identically, keeping them in sync.

The main limitation is memory: since each GPU must store the full model, the maximum model size is constrained by the memory of a single GPU, regardless of how many GPUs are used.
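The gradient-averaging argument above can be sketched in plain Python, with lists standing in for GPUs and a simple average standing in for the all-reduce. The one-parameter least-squares model here is purely illustrative:

```python
# Minimal sketch of data parallelism: each "GPU" computes a local gradient
# on its shard of the batch, then an all-reduce (here: a plain average)
# yields the same global gradient as full-batch training on one device.

def local_gradient(w, xs, ys):
    """Gradient of mean squared error 0.5*(w*x - y)^2 over a local shard."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def all_reduce_mean(grads):
    """Stand-in for an all-reduce: every worker ends up with the average."""
    g = sum(grads) / len(grads)
    return [g] * len(grads)

w = 0.5
batch_x = [1.0, 2.0, 3.0, 4.0]
batch_y = [2.0, 4.0, 6.0, 8.0]

# Split the batch across M = 2 workers (N/M = 2 samples each).
shards = [(batch_x[:2], batch_y[:2]), (batch_x[2:], batch_y[2:])]
local = [local_gradient(w, xs, ys) for xs, ys in shards]
synced = all_reduce_mean(local)

# Because the loss is additive, the averaged gradient matches the
# full-batch gradient, so all model copies stay in sync.
full = local_gradient(w, batch_x, batch_y)
```

Note the averaging is exact here because the shards are equal-sized; with uneven shards a weighted average would be needed.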

Fully Sharded Data Parallelism (FSDP)

Fully Sharded Data Parallelism (FSDP), as implemented in systems like PyTorch FSDP, shards model parameters, gradients, and optimizer states across GPUs so that each device only holds a fraction of the full model. Instead of replicating the entire model on every GPU, each worker owns a shard of each parameter tensor, which significantly reduces memory usage.

During the forward pass, parameters are reconstructed on demand using an all-gather operation across GPUs. The full weights are only materialized temporarily for computation and are freed immediately after use. To reduce communication overhead, FSDP overlaps this process with computation by prefetching parameters for upcoming layers.

In the backward pass, gradients are computed using the temporarily gathered parameters and then redistributed using a reduce-scatter operation. This ensures each GPU keeps only its corresponding shard of the gradients rather than the full gradient tensor.

Finally, the optimizer step is performed locally on each GPU using its shard of parameters, gradients, and optimizer states. As a result, no GPU ever needs to hold the full model persistently, reducing memory complexity from O(N) to approximately O(N / k) across k GPUs, at the cost of additional but carefully managed communication.
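The shard / all-gather / reduce-scatter flow can be sketched for one layer with plain Python lists in place of CUDA tensors; the collectives and the flat sharding scheme below are illustrative stand-ins, not PyTorch FSDP's actual API:

```python
# Toy sketch of the FSDP flow for one parameter tensor across k workers.

k = 4
full_weight = [float(i) for i in range(8)]   # one flat parameter tensor

# Each worker persistently owns only its 1/k shard.
shard_size = len(full_weight) // k
shards = [full_weight[r * shard_size:(r + 1) * shard_size] for r in range(k)]

def all_gather(shards):
    """Forward pass: temporarily materialize the full weight on every worker."""
    gathered = [w for s in shards for w in s]
    return [list(gathered) for _ in range(k)]

def reduce_scatter(full_grads):
    """Backward pass: sum gradients, then each worker keeps only its shard."""
    n = len(full_grads[0])
    summed = [sum(g[i] for g in full_grads) for i in range(n)]
    return [summed[r * shard_size:(r + 1) * shard_size] for r in range(k)]

# Forward: every worker briefly sees the full weight, then frees it.
per_worker_full = all_gather(shards)

# Backward: each worker computes a gradient over the full tensor locally...
local_full_grads = [[1.0] * len(full_weight) for _ in range(k)]
# ...but after reduce-scatter it stores only its shard of the summed gradient,
# on which it then runs its local optimizer step.
grad_shards = reduce_scatter(local_full_grads)
```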

Hybrid Sharded Data Parallel (HSDP)

In Hybrid Sharded Data Parallelism, the total number of GPUs is organized into a 2D structure such that $N = M \times K$. The GPUs are divided into M groups, each containing K GPUs.

Within each group of K GPUs, Fully Sharded Data Parallelism (FSDP) is applied. This means model parameters, gradients, and optimizer states are sharded across the K GPUs in the group. No single GPU holds the full model; instead, the group collectively represents one full model in a distributed form.

Across the M groups, standard data parallelism is used. Each group processes a different subset of the minibatch, computes gradients independently, and then synchronizes gradients across groups to ensure all replicas remain consistent.

HSDP is a concrete example of multidimensional parallelism, where different parallelization strategies are applied along different axes of a logical device grid. In this case, GPUs are arranged in a 2D grid: one dimension for sharding (FSDP) and one for replication (data parallelism).
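The 2D device grid can be made concrete with a small rank-mapping sketch. The row-major layout convention below (sharding within a group, replication across groups) is an assumption; real frameworks let you choose the mesh ordering:

```python
# Sketch of the HSDP device grid with N = M * K GPUs: rank r belongs to
# replica group r // K (data parallelism across groups) and holds shard
# r % K within that group (FSDP inside the group).

M, K = 2, 4          # 2 model replicas, each sharded over 4 GPUs
N = M * K

def grid_position(rank):
    replica_group = rank // K   # which model replica this GPU helps hold
    shard_index = rank % K      # which parameter shard it owns in the group
    return replica_group, shard_index

mesh = [grid_position(r) for r in range(N)]
# Gradient sync then happens in two stages: reduce-scatter within each
# group of K, then all-reduce across the M groups.
```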

In practice, large-scale training systems often extend this idea further by combining additional dimensions such as tensor parallelism (splitting computations within layers) and pipeline parallelism (splitting layers across stages), forming higher-dimensional parallel training schemes.

Activation Checkpointing

Activation checkpointing reduces memory usage by saving only a subset of intermediate activations during the forward pass and recomputing the missing ones during the backward pass.

Instead of storing activations for every layer, the model saves checkpoints every C layers and discards the rest; when gradients are needed, it recomputes activations starting from the nearest checkpoint.

This significantly lowers activation memory at the cost of extra computation, creating a trade-off where fewer checkpoints save more memory but require more recomputation.

It’s common to set $C = \sqrt{N}$, where N is the number of layers.
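The √N choice follows from a rough memory model: with checkpoints every C layers, we store about N/C checkpointed activations plus up to C recomputed activations at a time, so peak activation memory scales like N/C + C, which is minimized at C = √N. A quick numerical check (the layer count is arbitrary):

```python
# Verify that C = sqrt(N) minimizes the rough cost model N/C + C,
# measured in units of one layer's activations.

import math

def peak_activation_units(n_layers, c):
    """Stored checkpoints (N/C) plus one recomputation window (C)."""
    return n_layers / c + c

n_layers = 64
costs = {c: peak_activation_units(n_layers, c) for c in range(1, n_layers + 1)}
best_c = min(costs, key=costs.get)   # minimizer of N/C + C
```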

Context Parallelism (CP)

Commonly used for Transformer models

Transformers operate on sequences of length L (or S). Context Parallelism involves using multiple GPUs to process a single long sequence that would otherwise be too large for one device’s memory.

Normalization and residual connections carry few or no weights and act position-wise, so they are easy to parallelize along the sequence. The MLP and QKV projections behave much like DP: each GPU applies the full weights to its local slice of the sequence.

The Attention mechanism is the hardest part to parallelize because every element in the sequence needs to look at every other element. There are two primary options:

Option 1: Ring Attention

  • Approach: Divide the sequence into blocks and distribute them over GPUs.
  • Execution: Uses an inner loop over keys/values and an outer loop over queries.
  • Verdict: Complex to implement but highly scalable for extremely long sequences.
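The communication pattern behind the two loops can be sketched as a ring schedule: each GPU keeps its own query block fixed and, over P steps, receives every key/value block as KV blocks are passed around the ring. This shows the schedule only, with no attention math:

```python
# Ring Attention schedule sketch: after P steps, every query block has
# seen all P KV blocks exactly once.

P = 4  # GPUs, each holding one block of the sequence

def ring_schedule(p):
    """Row g lists the order in which GPU g sees the KV blocks."""
    return [[(g - step) % p for step in range(p)] for g in range(p)]

schedule = ring_schedule(P)
# Each GPU starts with its own KV block, then receives its left
# neighbor's block at every step.
```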

Option 2: Ulysses

  • Approach: Instead of distributing the attention matrix itself, it parallelizes over the heads in Multi-Head Attention.
  • Verdict: Simpler to implement than Ring Attention.
  • Constraint: Maximum parallelism is limited by the number of attention heads (Max Parallelism = $N_{\text{heads}}$).
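Head-parallelism amounts to a simple placement: each GPU runs full-sequence attention for its own subset of heads. The head and GPU counts below are arbitrary:

```python
# Ulysses-style sketch: divide attention heads across GPUs. Parallelism
# cannot exceed the head count, and heads must divide evenly.

n_heads = 16
n_gpus = 4
assert n_heads % n_gpus == 0, "head count must divide evenly across GPUs"

heads_per_gpu = n_heads // n_gpus
placement = {g: list(range(g * heads_per_gpu, (g + 1) * heads_per_gpu))
             for g in range(n_gpus)}
# Each GPU attends over the full sequence, but only for its assigned heads,
# so no communication is needed inside the attention computation itself.
```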

Pipeline Parallelism (PP)

Split the layers of the model across GPUs. Copy activations between layers at GPU boundaries.

To avoid “Pipeline Bubbles” (where GPUs sit idle waiting for data), the model runs multiple micro-batches simultaneously.
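The effect of micro-batching on bubbles can be quantified: in a simple GPipe-style schedule with P stages and m micro-batches, a step takes m + P − 1 pipeline slots end to end, so the idle fraction is (P − 1) / (m + P − 1). A quick check with hypothetical numbers:

```python
# Pipeline bubble fraction under a simple fill-and-drain schedule:
# bubble = (P - 1) / (m + P - 1). More micro-batches shrink the bubble.

def bubble_fraction(p_stages, m_microbatches):
    return (p_stages - 1) / (m_microbatches + p_stages - 1)

few = bubble_fraction(4, 4)    # 4 stages, 4 micro-batches: 3/7 idle
many = bubble_fraction(4, 32)  # 32 micro-batches: 3/35 idle
```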

Tensor Parallelism (TP)

Split the weights of each linear layer across GPUs, use block matrix multiply.

With 2 consecutive TP layers, shard the first layer’s weights along the output dimension and the second along the input dimension; the intermediate activations then stay local, and only a single all-reduce is needed after the second layer.
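The block-matmul idea for a single layer can be checked directly: split the weight matrix column-wise across 2 "GPUs", let each compute its slice of the output, and concatenate. Plain Python matmul stands in for the sharded GPU kernels:

```python
# Tensor parallelism sketch for y = x @ W with column-sharded W.

def matmul(a, b):
    rows, inner, cols = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]

x = [[1.0, 2.0]]                   # (batch=1, dim=2)
W = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]         # (2, 4)

# Column sharding: GPU 0 gets columns 0-1, GPU 1 gets columns 2-3.
W0 = [row[:2] for row in W]
W1 = [row[2:] for row in W]
y0, y1 = matmul(x, W0), matmul(x, W1)

# Concatenating the per-GPU slices reproduces the full result.
y_sharded = [y0[0] + y1[0]]
y_full = matmul(x, W)
```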

Benchmarking Parallelism

Hardware FLOPs Utilization (HFU): The fraction of theoretical matmul performance we actually achieve.

HFU is typically benchmarked under a best-case matmul scenario, so it doesn’t account for other work such as recomputation from activation checkpointing or data preprocessing.

Model FLOPs Utilization (MFU): the fraction of the GPU’s theoretical peak FLOPs used for “useful” model computation.

  1. Compute $FLOP_{\text{theoretical}}$

    • This is the total number of matrix multiply FLOPs in the forward and backward pass.
    • Heuristic: You can approximate the backward pass as 2x the forward pass.
    • Note: Ignore non-linearities, normalization, and elementwise operations (like residuals). These typically run on FP32 cores and do not significantly contribute to the primary matrix math calculation.
  2. Look up $FLOP/\text{sec}_{\text{theoretical}}$

    • Find the theoretical maximum throughput of your specific hardware.
  3. Compute $t_{\text{theoretical}}$

    • Calculate the ideal time a pass should take if the GPU were running at 100% efficiency:

$$t_{\text{theoretical}} = \frac{FLOP_{\text{theoretical}}}{FLOP/\text{sec}_{\text{theoretical}}}$$

  4. Measure $t_{\text{actual}}$

    • This is the real-world time measured for a full iteration.
    • Includes: Data loading, forward pass, backward pass, and the optimizer step.
  5. Calculate MFU

    • The final utilization ratio:

$$\text{MFU} = \frac{t_{\text{theoretical}}}{t_{\text{actual}}}$$

MFU > 30% is good, >40% is excellent.
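The five steps can be worked through with made-up numbers (the model size, token count, peak throughput, and measured step time below are all illustrative, not a specific GPU or model). This uses the common approximation that forward matmul FLOPs for a dense model are about 2 × params × tokens, with the backward pass as 2x the forward:

```python
# Worked MFU example with hypothetical numbers.

params = 1e9            # 1B-parameter model (assumption)
tokens = 4e6            # tokens processed in one iteration (assumption)
peak_flops = 1e15       # hypothetical 1 PFLOP/s accelerator

# Step 1: forward (2 * params * tokens) plus backward (2x forward).
flop_theoretical = 3 * (2 * params * tokens)

# Steps 2-3: ideal step time at 100% efficiency.
t_theoretical = flop_theoretical / peak_flops     # seconds

# Step 4: measured wall-clock time per iteration (assumption).
t_actual = 60.0

# Step 5: final utilization ratio.
mfu = t_theoretical / t_actual
```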

ND Parallelism

In practice, we use TP, CP, PP, and DP all at the same time. GPUs are arranged in a 4D grid, and the setup is tuned to maximize MFU.

About this Post

This post is written by Louis C Deng, licensed under CC BY-NC 4.0.

#CS231n #Deep Learning