Choosing a Distributed Training Strategy When One GPU Isn't Enough

March 20, 2028 · 16 min read

ML Engineer · MLA-C01 · part of The Exam Room

The situation

A computer-vision team has trained a Vision Transformer on product imagery for eighteen months. Until recently, training meant one p4d.24xlarge instance, 8 × A100 40GB, running 12 hours overnight and producing a new checkpoint by morning. Three things have changed.

  • The training set has grown from 10M images to 120M. A single-instance epoch that used to take 90 minutes now takes 18 hours.
  • The model has grown from 600M parameters to 3.2B. At 3.2B parameters in float32, the model weights alone are ~13GB. Gradients add another ~13GB, and AdamW keeps two moment estimates per parameter (~26GB), all before activations. That is roughly 52GB of model state against a 40GB A100, so the full training state cannot fit on one GPU.
  • The team has budget for up to 64 A100s across 8 instances, but the existing training script is a vanilla single-process PyTorch loop and doesn’t know what to do with them.
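
As a rough check on the second bullet, the memory arithmetic can be sketched in a few lines. This assumes float32 everywhere and AdamW’s two per-parameter moment buffers; `training_state_gb` is a hypothetical helper, and activations are deliberately left out.

```python
# Back-of-envelope memory budget for training state in float32.
# Assumptions: 4 bytes per value, AdamW keeps two moment buffers per
# parameter, decimal gigabytes, activations not counted.
BYTES_FP32 = 4
GB = 1e9


def training_state_gb(num_params: float) -> dict:
    """Return the size in GB of each piece of per-replica training state."""
    weights = num_params * BYTES_FP32 / GB
    gradients = weights          # one gradient value per parameter
    adamw_state = 2 * weights    # first and second moment estimates
    return {
        "weights": weights,
        "gradients": gradients,
        "adamw_state": adamw_state,
        "total": weights + gradients + adamw_state,
    }


budget = training_state_gb(3.2e9)
for name, size in budget.items():
    print(f"{name:>12}: {size:5.1f} GB")
print("fits in 40GB before activations:", budget["total"] <= 40)
```

Under these assumptions the state totals about 51GB, which is already over the 40GB budget before any activations are counted.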

The question is which distributed-training strategy fits this problem, and what the code and infrastructure changes actually look like. “Just use more GPUs” isn’t a plan; there are at least four different ways to use more GPUs and they solve different problems.

What actually matters

Before reaching for a library, it’s worth naming the two pressures that force distribution in the first place.

The first is dataset size. If one epoch over the full dataset takes too long, the fix is to process different shards of the data on different GPUs in parallel, each holding a full copy of the model, each computing gradients on its own batch, and then synchronising gradients across ranks so that every replica converges to the same weights. This is data parallelism. It’s the easiest form of distribution: the model is unchanged, each GPU handles only its share of the global batch, and total throughput scales roughly linearly up to the point where gradient synchronisation becomes the bottleneck. It works beautifully when the model fits on one GPU.
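
The gradient-synchronisation step can be illustrated without any GPU at all. The sketch below is a toy in plain Python, not a DDP implementation: the one-weight model, the data, and the `all_reduce_mean` stand-in are all made up for illustration, and it assumes every rank gets an equal-sized shard.

```python
# Toy data parallelism: every "rank" holds the same weight, computes a
# gradient on its own shard, and the gradients are averaged across ranks.

def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Stand-in for a gradient all-reduce: every rank receives the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]    # target function y = 2x
w = 0.5                       # identical starting weight on every rank
world_size = 4

# Rank r takes every world_size-th sample, so shards are equal-sized.
shards = [(xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
local_grads = [grad_mse(w, sx, sy) for sx, sy in shards]
synced_grads = all_reduce_mean(local_grads)

full_batch_grad = grad_mse(w, xs, ys)
print("per-rank gradient after sync:", synced_grads[0])
print("single-process gradient     :", full_batch_grad)
```

With equal shards, the averaged gradient matches what a single process would compute on the whole batch, which is why every replica can take the same optimizer step.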

The second is model size. If the model’s weights (or weights + optimizer state + activations) don’t fit on one GPU, no amount of data parallelism helps: you can’t replicate a model that doesn’t fit. The fix is to split the model itself across multiple GPUs. There are two orthogonal ways to split:

  • Pipeline parallelism: different layers of the model live on different GPUs. A mini-batch flows forward through layers 1-8 on GPU 0, then 9-16 on GPU 1, and so on. To keep all GPUs busy, the batch is split into micro-batches and pipelined, so while GPU 1 is pushing micro-batch 2 through layers 9-16, GPU 0 is already pushing micro-batch 3 through layers 1-8.
  • Tensor parallelism: individual operations are split across GPUs. A matrix multiplication that’s too big for one GPU has its rows or columns sharded, each GPU computes a partial result, and the partial results are combined (all-reduce or all-gather) before the next operation. Commonly done within a single node where GPU-to-GPU bandwidth is very high; not as effective across nodes.
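
The tensor-parallel idea in the second bullet can be sketched with a column-split matrix multiply. This is a plain-Python toy in which nested lists stand in for tensors and a loop stands in for separate GPUs; the helper names are hypothetical.

```python
# Column-parallel matmul: each "GPU" holds a slice of the weight columns,
# computes its slice of the output, and an all-gather stitches them together.

def matmul(a, b):
    """Naive dense matrix multiply on nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(matrix, parts):
    """Shard a matrix into `parts` equal column blocks."""
    width = len(matrix[0]) // parts
    return [[row[p * width:(p + 1) * width] for row in matrix]
            for p in range(parts)]


def all_gather_columns(partials):
    """Concatenate each rank's output slice along the column axis."""
    return [sum((partial[i] for partial in partials), [])
            for i in range(len(partials[0]))]


x = [[1, 2], [3, 4]]                  # activations, replicated on every GPU
w = [[1, 0, 2, 1], [0, 1, 1, 3]]      # 2 x 4 weight matrix

weight_shards = split_columns(w, parts=2)            # one shard per GPU
partial_outputs = [matmul(x, shard) for shard in weight_shards]
y = all_gather_columns(partial_outputs)

print(y)
```

Each shard only ever sees half of the weight columns, yet the gathered result equals the unsharded `matmul(x, w)`. The gather after every such operation is the communication cost that makes this practical mainly inside a node.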

A middle path between these two pressures is sharded data parallelism: conceptually data-parallel, but instead of every replica holding a full copy of the weights, gradients, and optimizer state, that state is sharded across the replicas. Each rank holds 1/N of the parameters at rest; when a layer needs to execute, its full parameters are gathered from the other ranks, the layer runs, and the gathered copy is released. You get the throughput shape of data parallelism and the memory shape of model parallelism, at the cost of extra communication.
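
The gather, run, release cycle can be mimicked in a few lines. This is a single-process toy, not FSDP: `ShardedLayer` is a hypothetical class, the "layer" is just a weighted sum, and the list of shards stands in for memory on separate ranks.

```python
class ShardedLayer:
    """Toy layer whose flat parameter list is split evenly across ranks."""

    def __init__(self, params, world_size):
        if len(params) % world_size != 0:
            raise ValueError("toy example assumes an even split")
        per_rank = len(params) // world_size
        # shards[r] is the only slice that rank r keeps resident at rest.
        self.shards = [params[r * per_rank:(r + 1) * per_rank]
                       for r in range(world_size)]

    def all_gather(self):
        """Rebuild the full parameter list from every rank's shard."""
        full = []
        for shard in self.shards:
            full.extend(shard)
        return full

    def forward(self, x):
        full = self.all_gather()          # materialise the whole layer
        y = sum(w * x for w in full)      # stand-in for the layer's compute
        del full                          # drop the gathered copy again
        return y


layer = ShardedLayer([1.0, 2.0, 3.0, 4.0], world_size=4)
print("values resident per rank at rest:", len(layer.shards[0]))
print("forward output:", layer.forward(2.0))
```

At rest each rank keeps one value out of four; the full layer exists only for the duration of `forward`.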

Which pressure binds here? Both. The model plus optimizer state doesn’t fit on one GPU, and the dataset is far too large for a single replica to grind through in a useful window. Strategies that solve only one of those leave the other unsolved.

What we’ll filter on

Four filters that separate the strategies:

  1. What’s split: the data, the model’s layers, the model’s tensors, or the model’s parameters-and-optimizer-state?
  2. Does the full model fit on one GPU? If yes, data parallelism is sufficient; if no, something model-parallel is required.
  3. Communication cost: gradient all-reduce (data-parallel), pipeline bubbles (pipeline-parallel), all-gather per layer (sharded), all-reduce within op (tensor-parallel).
  4. Code changes required: annotation, wrapping, partitioning, or re-architecture.

The distributed-training landscape

1. Single-GPU / single-instance. One process, one GPU, one model, one data loader. Fine up to whatever fits. The baseline every other strategy is measured against.

2. Data parallelism (DDP). N processes, each holding a full copy of the model on its own GPU. Each process gets a different shard of the batch. Forward and backward run independently; after the backward pass, gradients are all-reduced across ranks so every replica ends the step with the same (averaged) gradient and steps the optimizer identically. Scales throughput roughly linearly with N, up to the point gradient sync dominates, typically tens to low hundreds of GPUs for vision and NLP models. PyTorch ships torch.nn.parallel.DistributedDataParallel; SageMaker ships SMDDP (SageMaker Distributed Data Parallel), an AWS-optimised all-reduce using EFA and the instance’s NVLink topology. Code change: wrap the model in DDP, configure a sampler that shards the dataset.

3. Sharded data parallelism (FSDP / ZeRO). Data-parallel in intent, but the model weights + gradients + optimizer state are sharded across ranks. At forward time, each layer’s full parameters are gathered from peer ranks, the layer runs, and the memory for that shard is freed. At backward, the same gathering happens in reverse. The memory savings are large: with N ranks, per-rank memory for model state is ~1/N of the full-replica cost. Communication cost is higher than plain DDP because of the per-layer all-gather. PyTorch ships torch.distributed.fsdp.FullyShardedDataParallel; SageMaker Model Parallel Library wraps it with instance-topology-aware scheduling. Code change: wrap modules in FSDP with a sharding policy.

4. Pipeline parallelism. Model is partitioned into stages of layers; each stage lives on a different device. A mini-batch is chopped into micro-batches and pipelined: stage 0 processes micro-batch 1 while stage 1 processes the output of stage 0’s previous micro-batch. At steady state all stages are busy; startup and drain create “bubbles” of idle GPU time. Scales to very large models when paired with other parallelisms. Code change: explicit partition annotations (where to cut the model) and a pipeline scheduler.

5. Tensor parallelism. Individual operations (matmuls, attention heads) split across GPUs; each GPU holds a slice of the weight matrix and computes a slice of the output, with all-reduce or all-gather to combine. Communication is tight, generally only practical within a single node’s GPUs connected by NVLink, where bandwidth is high enough to keep up with the per-operation sync. Common for very large language models (Megatron-LM partitions attention and MLP blocks this way). Code change: custom layer implementations or a library that provides tensor-parallel versions.

6. Hybrid / 3D parallelism. Combine data parallelism (across instances), pipeline parallelism (across nodes within a replica), and tensor parallelism (within a node). Standard at the 100B+ parameter scale. Configuration is a small grid: how many instances for data parallelism, how many pipeline stages per replica, how many GPUs per tensor-parallel group.

7. SageMaker Training Compiler. Orthogonal to the above. Compiles the training graph ahead of time (XLA-based) so that individual GPUs spend less time in framework overhead. Multiplies with any distribution strategy. Useful when the bottleneck is per-GPU throughput rather than cross-GPU communication.
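
The "small grid" behind hybrid 3D parallelism (item 6) can be enumerated directly. The sketch below assumes only that the three degrees multiply to the total GPU count and that a tensor-parallel group must stay inside one node; `parallel_grids` is a hypothetical helper, not a library call.

```python
def parallel_grids(world_size, gpus_per_node):
    """List every (data, pipeline, tensor) degree triple for a cluster.

    Constraints assumed here: dp * pp * tp == world_size, and tp divides
    gpus_per_node so a tensor-parallel group never spans nodes.
    """
    grids = []
    for tp in range(1, gpus_per_node + 1):
        if gpus_per_node % tp or world_size % tp:
            continue
        remaining = world_size // tp
        for pp in range(1, remaining + 1):
            if remaining % pp == 0:
                grids.append((remaining // pp, pp, tp))
    return grids


grids = parallel_grids(world_size=64, gpus_per_node=8)
print(len(grids), "valid grids for 64 GPUs")
for dp, pp, tp in grids[:5]:
    print(f"data={dp:2d}  pipeline={pp:2d}  tensor={tp}")
```

Picking one point on this grid is the configuration step; which point is best depends on model size and interconnect, which the enumeration does not try to judge.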

Side by side

| Strategy | What’s split | Fits model > GPU | Comm pattern | Code change |
| --- | --- | --- | --- | --- |
| Single-GPU | Nothing | ✗ | None | None (baseline) |
| Data parallel (DDP / SMDDP) | Data batch | ✗ (model must fit) | Gradient all-reduce | Wrap in DDP, shard sampler |
| Sharded data parallel (FSDP / ZeRO-3) | Parameters + grads + optimizer state | ✓ | Per-layer all-gather + grad all-reduce | Wrap in FSDP, sharding policy |
| Pipeline parallel | Layers | ✓ | Activations forward, grads backward between stages | Partition annotations, micro-batching |
| Tensor parallel | Individual operations | ✓ (weight-shape sense) | All-reduce inside ops | Library-provided layers |
| Hybrid 3D | Data + layers + ops | ✓ (any size) | All of the above | Combination of all of the above |

Reading the table against the 3.2B ViT on 120M images scenario:

  • The model plus AdamW optimizer state and activations does not fit on a single 40GB A100. DDP alone won’t work, because you can’t replicate what doesn’t fit.
  • Pipeline parallelism alone is complex to set up correctly for a ViT. Transformers are easier to pipeline than CNNs because their blocks are uniform, but the stage partitioning and micro-batch scheduling still need tuning, so it is not the first reach.
  • Tensor parallelism is overkill for 3.2B parameters on a single-node scale; it earns its keep at 20B+.
  • Sharded data parallelism is the fit. Parameters sharded across, say, 8 ranks (one full instance) means per-rank model-state memory is ~1/8 of the full-replica cost, which comfortably fits. Data parallelism across the remaining dimension scales throughput.

Picking the shape

Decision flow (figure). Training slow / OOM? Start here.

  • Does the model fit on one GPU (with optimizer state)?
    • Yes → Dataset too big for one GPU to cover?
      • Yes → Data parallel (DDP / SMDDP): replicate model, shard data, gradient all-reduce each step, scales linearly to ~100s of GPUs. Example: ResNet-50 on ImageNet; 600M-param model, 10M images; 8 × A100 DDP → 5× faster.
    • No → Does it fit sharded across one node?
      • Yes → Sharded data parallel (FSDP / ZeRO-3): weights + opt state sharded, ~1/N memory per rank. Example (this case): 3.2B ViT, 120M images; 8 nodes × 8 GPUs, FSDP sharded in-node, DP across.
      • No → Pipeline parallel: layers across nodes, micro-batches to hide bubbles, plus sharded DP on top. Example: 100B+ LLM; hybrid 3D parallel (DP + PP + TP).

Two gates decide the strategy: does the model fit on one GPU, and does the dataset justify parallelism at all. For the ViT case, the answer is "no, and yes": sharded data parallel within each node, with data parallelism across nodes, is the fit.
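
The same gates can be written as a small function. This is only a restatement of the flow above under one added assumption: when the model fits and the dataset is not a problem, the answer falls back to the single-GPU baseline. `pick_strategy` and its argument names are hypothetical.

```python
def pick_strategy(fits_one_gpu: bool,
                  dataset_too_big: bool,
                  fits_sharded_in_node: bool) -> str:
    """Map the decision-flow gates to a distribution strategy."""
    if fits_one_gpu:
        if dataset_too_big:
            return "data parallel (DDP / SMDDP)"
        # Assumption: no pressure at all means no distribution is needed.
        return "single GPU"
    if fits_sharded_in_node:
        return "sharded data parallel (FSDP / ZeRO-3)"
    return "pipeline parallel + sharded DP"


# The 3.2B ViT on 120M images: does not fit on one GPU, dataset is large,
# and the sharded state fits inside one 8-GPU node.
vit_choice = pick_strategy(fits_one_gpu=False,
                           dataset_too_big=True,
                           fits_sharded_in_node=True)
print(vit_choice)
```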

The pick in depth

FSDP across 8 × p4d.24xlarge. Each instance hosts 8 ranks (one A100 per rank); 8 instances gives 64 ranks total. With a sharding degree of 8 (ZeRO-3-style sharding within each node), each rank holds 1/8 of the parameters, gradients, and optimizer state at rest, and the 8 nodes act as 8 data-parallel replicas. When a given layer needs to run forward, the rank all-gathers the layer’s parameters from its peers in the node, executes the layer, and drops the gathered parameters again. On the backward pass, gradients for that layer are reduce-scattered back to the parameter-owning rank.

Per-rank memory for the 3.2B ViT works out to roughly (13GB weights + 26GB AdamW state) / 8 ≈ 4.9GB for model state, with activations and gradient buffers on top. That’s well inside the 40GB A100 and leaves room for a reasonable local micro-batch. Sharding across all 64 ranks instead would bring model state down to ≈ 0.6GB per rank, at the cost of all-gathers that cross node boundaries.
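
The per-rank figure is easy to recompute for any sharding degree. The sketch assumes float32 weights, AdamW with two moment buffers, and an even split; `per_rank_state_gb` is a hypothetical helper, and activations are excluded.

```python
def per_rank_state_gb(num_params, shard_degree, bytes_per_value=4):
    """Per-rank model state in GB when state is split evenly over ranks."""
    weights = num_params * bytes_per_value / 1e9
    adamw_state = 2 * weights            # two moment buffers per parameter
    gradients = weights
    return {
        "weights_plus_adamw": (weights + adamw_state) / shard_degree,
        "with_gradients": (weights + adamw_state + gradients) / shard_degree,
    }


in_node = per_rank_state_gb(3.2e9, shard_degree=8)     # one node's 8 GPUs
all_ranks = per_rank_state_gb(3.2e9, shard_degree=64)  # every GPU in the job

print(f"degree  8: {in_node['weights_plus_adamw']:.1f} GB "
      f"({in_node['with_gradients']:.1f} GB with gradients)")
print(f"degree 64: {all_ranks['weights_plus_adamw']:.1f} GB "
      f"({all_ranks['with_gradients']:.1f} GB with gradients)")
```

Either degree leaves most of a 40GB A100 free for activations; the choice between them is about communication locality rather than whether the model fits.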

The SageMaker training-job shape:

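import sagemaker
from sagemaker.pytorch import PyTorch

# IAM role the training job runs as. get_execution_role() assumes this
# code runs inside SageMaker (notebook or Studio); pass a role ARN otherwise.
role = sagemaker.get_execution_role()

estimator = PyTorch(
    entry_point="train.py",
    source_dir="src",
    framework_version="2.3",
    py_version="py310",
    instance_type="ml.p4d.24xlarge",   # 8 x A100 40GB per instance
    instance_count=8,                  # 64 GPUs in total
    distribution={
        "torch_distributed": {"enabled": True},
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                # Parameter names follow the SMP sharded-data-parallel API;
                # confirm them (and the launcher key above) against the SMP
                # version bundled with this framework_version.
                "parameters": {
                    "tensor_parallel_degree": 1,        # no tensor parallelism
                    "sharded_data_parallel_degree": 8,  # shard within a node
                    "ddp": True,
                },
            }
        },
    },
    hyperparameters={
        "batch-size": 64,
        "learning-rate": 1e-4,
        "epochs": 30,
    },
    role=role,
)
# Each key becomes an input channel available to train.py.
estimator.fit({"train": "s3://dataset/train/", "val": "s3://dataset/val/"})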

The smdistributed.modelparallel block configures SMP to use sharded data parallelism with degree 8 (within each node, the 8 GPUs collectively hold one sharded replica of the model). With 8 nodes, SageMaker automatically spans data parallelism across the 8 replicas. SMP wires up NCCL over EFA for the sharded all-gathers and gradient reductions.
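
How the 64 ranks fall into sharding groups and replicas can be sketched with integer arithmetic. The mapping below assumes ranks are numbered contiguously node by node and that each sharding group is a run of consecutive ranks; that is an illustrative assumption, not confirmed SMP behaviour, and all names are hypothetical.

```python
GPUS_PER_NODE = 8
NUM_NODES = 8
SHARD_DEGREE = 8   # matches sharded_data_parallel_degree in the job config


def rank_layout(global_rank):
    """Where a global rank sits under contiguous, node-major numbering."""
    return {
        "node": global_rank // GPUS_PER_NODE,
        "local_rank": global_rank % GPUS_PER_NODE,
        "replica": global_rank // SHARD_DEGREE,      # which model replica
        "shard_index": global_rank % SHARD_DEGREE,   # which parameter slice
    }


def shard_group(global_rank):
    """Ranks that together hold one full sharded copy of the model."""
    start = (global_rank // SHARD_DEGREE) * SHARD_DEGREE
    return list(range(start, start + SHARD_DEGREE))


def replica_peers(global_rank):
    """Ranks in other replicas that hold the same parameter slice."""
    world_size = GPUS_PER_NODE * NUM_NODES
    return list(range(global_rank % SHARD_DEGREE, world_size, SHARD_DEGREE))


print(rank_layout(10))
print("all-gather peers  :", shard_group(10))
print("grad-sync peers   :", replica_peers(10))
```

Under this layout the per-layer all-gathers stay on NVLink inside a node, and only the gradient synchronisation between matching shards crosses the EFA network.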

Inside train.py, the PyTorch code uses the SageMaker SMP wrappers:

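import torch
import torch.nn.functional as F
import smdistributed.modelparallel.torch as smp

smp.init()
torch.cuda.set_device(smp.local_rank())  # pin this process to its own GPU

# VisionTransformer, dataloader, and epochs are defined elsewhere in train.py.
model = smp.DistributedModel(VisionTransformer(...))
optimizer = smp.DistributedOptimizer(torch.optim.AdamW(model.parameters(), lr=1e-4))

@smp.step
def train_step(inputs, labels):
    outputs = model(inputs)
    loss = F.cross_entropy(outputs, labels)
    model.backward(loss)  # SMP drives backward; do not call loss.backward()
    return loss

for epoch in range(epochs):
    for inputs, labels in dataloader:
        inputs, labels = inputs.to("cuda"), labels.to("cuda")
        optimizer.zero_grad()
        # smp.step returns one result per micro-batch; average them for logging.
        loss = train_step(inputs, labels).reduce_mean()
        optimizer.step()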

Two wrappers: smp.DistributedModel shards the model’s parameters across the sharded-DP group; smp.DistributedOptimizer shards the optimizer state correspondingly. The @smp.step decoration lets the library control micro-batching and activation checkpointing.

When DDP would still be the answer. A ResNet-50 (~25M parameters) on ImageNet (~1.28M training images) fits on one A100 with room to spare. The team wants faster training. DDP with SMDDP across 8 GPUs on one p4d.24xlarge gets close to an 8× speedup over a single GPU on the epoch time; no need for sharding. Adding sharding when the model fits pays the extra all-gather communication for no benefit.
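
The "sampler that shards the dataset" half of the DDP change amounts to strided indexing. The sketch below is written in the spirit of PyTorch’s `DistributedSampler` with shuffling turned off; `shard_indices` is a hypothetical stand-in, not the library’s code, and it pads by wrapping around so every rank gets the same number of samples.

```python
import math


def shard_indices(dataset_len, world_size, rank):
    """Indices one rank reads in an epoch (no shuffling, padded to equal size)."""
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    indices = list(range(dataset_len))
    # Pad by repeating the first few samples so the split is even.
    indices += indices[: total - dataset_len]
    # Rank r takes every world_size-th index starting at r.
    return indices[rank:total:world_size]


world_size = 4
per_rank_indices = [shard_indices(10, world_size, r) for r in range(world_size)]
for rank, idx in enumerate(per_rank_indices):
    print(f"rank {rank}: {idx}")
```

Every rank sees the same number of samples per epoch, which keeps the ranks in lockstep for the gradient all-reduce; the price is that a couple of samples are read twice when the dataset size is not a multiple of the rank count.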

When pipeline or tensor would be the answer. A 70B-parameter LLM does not fit sharded across 8 GPUs in a single node (the sharded per-rank memory is still larger than 40GB). The shape becomes pipeline across nodes, tensor within nodes, data across replicas. Hybrid 3D. Worth knowing the shape exists; for the vision case, sharded DP is the correct reach.
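
How much the pipeline part of that shape costs depends on the bubble. As a minimal sketch, assume a simple GPipe-style schedule in which every stage takes equal time per micro-batch; the idle fraction then depends only on the number of stages and micro-batches. The schedule is an assumption (the post does not name one) and `bubble_fraction` is a hypothetical helper.

```python
def bubble_fraction(num_stages, num_microbatches):
    """Idle share of a GPipe-style pipeline with equal-cost stages.

    Each stage is busy for num_microbatches slots out of
    num_microbatches + num_stages - 1 total slots per pass.
    """
    if num_stages < 1 or num_microbatches < 1:
        raise ValueError("need at least one stage and one micro-batch")
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


for microbatches in (1, 4, 12, 60):
    idle = bubble_fraction(num_stages=4, num_microbatches=microbatches)
    print(f"4 stages, {microbatches:2d} micro-batches: {idle:.0%} idle")
```

With a single micro-batch three of four stages sit idle at any moment; raising the micro-batch count is what shrinks the bubble, which is why micro-batching and pipeline partitioning are configured together.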

A worked scaling trace

Measuring the actual speedup is the real work. For the 3.2B ViT on 120M images:

  • 1 × p4d.24xlarge, single-GPU: cannot run. OOM in the first forward pass with AdamW optimizer state.
  • 1 × p4d.24xlarge, sharded across 8 GPUs: runs. Throughput ~1200 images/sec. Epoch time ~27h. Too slow for overnight.
  • 4 × p4d.24xlarge, sharded in-node + DDP across nodes: throughput ~4500 images/sec (3.75× the single-node rate, ~94% scaling efficiency). Epoch time ~7.5h.
  • 8 × p4d.24xlarge, sharded in-node + DDP across nodes: throughput ~8200 images/sec (~6.8× the single-node rate, ~85% scaling efficiency; some comm overhead begins showing). Epoch time ~4h. With 30 epochs, a full training run is ~5 days. Acceptable.
  • Adding SageMaker Training Compiler: further ~15% throughput at the same scale. Epoch ~3.5h.
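
The speedup, efficiency, and epoch-time figures in the trace follow from the throughput numbers alone. The sketch below recomputes them from the values quoted above; `scaling_report` is a hypothetical helper and the single-node run is taken as the baseline.

```python
DATASET_IMAGES = 120_000_000
# Nodes -> measured images/sec, taken from the trace above.
measured = {1: 1200, 4: 4500, 8: 8200}


def scaling_report(throughputs, dataset_size):
    """Speedup, scaling efficiency, and epoch time per node count."""
    base_nodes = min(throughputs)
    base_rate = throughputs[base_nodes]
    report = {}
    for nodes in sorted(throughputs):
        rate = throughputs[nodes]
        speedup = rate / base_rate
        report[nodes] = {
            "speedup": speedup,
            "efficiency": speedup / (nodes / base_nodes),
            "epoch_hours": dataset_size / rate / 3600,
        }
    return report


report = scaling_report(measured, DATASET_IMAGES)
for nodes, row in report.items():
    print(f"{nodes} node(s): {row['speedup']:.2f}x speedup, "
          f"{row['efficiency']:.0%} efficiency, "
          f"{row['epoch_hours']:.1f} h/epoch")
```

Efficiency here is speedup divided by the increase in node count, so 3.75× on four nodes is about 94% and 6.8× on eight nodes is about 85%.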

EFA matters enormously at the 8-node scale; without it the inter-node all-reduce dominates. p4d.24xlarge includes 4 × 100Gbps EFA adapters (400Gbps in total). p4de.24xlarge keeps the same 400Gbps network but doubles GPU memory to 80GB per A100, while p5.48xlarge raises EFA bandwidth to 3,200Gbps. The instance choice is part of the distribution strategy.

What’s worth remembering

  1. Two pressures force distribution. Dataset too big (data parallelism), model too big (model parallelism: sharded, pipeline, or tensor).
  2. DDP is the default when the model fits. Replicate the model on each GPU, shard the batch, all-reduce gradients. SMDDP is AWS’s EFA-optimised variant.
  3. FSDP / sharded data parallel is the default when the model doesn’t fit. Shard the parameters + gradients + optimizer state across ranks; gather per-layer on forward, reduce on backward. ~1/N memory per rank.
  4. Pipeline and tensor parallel handle very large models. Pipeline splits layers across devices with micro-batching; tensor splits individual ops within a node. Combine them for 100B+ models.
  5. Hybrid 3D is the 100B+ shape. Data across replicas, pipeline across nodes within a replica, tensor within a node. Configuration is a grid choice.
  6. SageMaker wraps the libraries. The distribution argument on the SDK’s estimator configures DDP, SMDDP, SMP, FSDP, or MPI-based alternatives. The training script uses the corresponding library’s wrappers.
  7. Instance and network choice are part of the strategy. p4d.24xlarge with EFA is the vision-transformer baseline; p5.48xlarge for bigger models and newer workloads; gradient all-reduce is bandwidth-bound at scale so the network matters.
  8. Training Compiler multiplies with distribution. Compile-time graph optimisation gives per-GPU speedup orthogonal to the distribution shape; worth enabling once the distribution is correct.

“Training is too slow” is not a single problem; it’s at least two (data too big, model too big). The strategy picks are different, the code changes are different, and the instance count is different. Getting the correct answer is a question of which pressure is the binding constraint, and the ViT case is one where both pressures bind, which is exactly what sharded data parallelism was designed for.

These posts are LLM-aided. Backbone, original writing, and structure by Craig. Research and editing by Craig + LLM. Proof-reading by Craig.