The situation
A small research team trains an LLM using JAX, Flax, and Optax. Today’s setup:
- Four `p4de.24xlarge` instances running a home-grown Docker image with JAX, Flax, Optax, and a few bespoke dependencies.
- Data in S3, pulled at training start via `s5cmd`.
- Checkpoints written back to S3 via JAX’s `orbax` library.
- Hyperparameter tuning done manually: a shell loop over `docker run` invocations with different flags.
- Inference done by the same container running a small HTTP server, with TLS terminated by an ALB in front.
They want to move to SageMaker for three reasons: the team is small enough that managed training and hyperparameter tuning would free real engineering time; SageMaker Model Registry would replace a set of S3 paths and a spreadsheet; and deployment to endpoints with autoscaling would replace the hand-rolled server.
The blocker: SageMaker’s list of AWS Deep Learning Containers doesn’t include JAX. The options are to use a container for a framework they don’t use (wrong answer), switch frameworks (very wrong answer), or bring their own container. The question is how to build that container so it plays correctly with SageMaker’s training jobs, endpoints, and tuner.
What actually matters
SageMaker’s managed services (training jobs, Processing jobs, endpoints, batch transform, the Model Registry, the hyperparameter tuner) are all built on the same underlying primitive: containers that follow a specific contract. When AWS ships a first-party DLC, it’s because AWS wrote the contract implementation for that framework. Bringing your own container means writing that same contract implementation for a framework AWS hasn’t packaged.
The first thing worth being clear about is which contract we’re implementing. There isn’t just one: there are several, all served by the same Docker image but with different expectations depending on how SageMaker invokes it.
For training jobs, the contract is:
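- SageMaker starts the container as `docker run <image> train`, so the image must either provide a `train` executable on its PATH (`sagemaker-training-toolkit` installs one) or use an exec-form ENTRYPOINT that accepts `train` as its argument.
- Hyperparameters are written to `/opt/ml/input/config/hyperparameters.json`, with every value stored as a string. The training toolkit can forward them to the user script as command-line arguments.
- Input data channels are delivered to `/opt/ml/input/data/<channel-name>/` (e.g. `/opt/ml/input/data/train/`). In the default File input mode, the data is downloaded there before the script starts.
- The training script writes model artefacts to `/opt/ml/model/`. SageMaker tars up whatever is there at the end and uploads it to S3.
- On failure, the script can write a description to `/opt/ml/output/failure`, which SageMaker reports as the job’s failure reason. A non-zero exit code also marks the job failed.
- stdout/stderr stream to CloudWatch Logs automatically (SageMaker’s doing, not the container’s).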
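The file-and-exit-code half of the contract fits in a few lines of standard-library Python. This is a minimal sketch that assumes the default `/opt/ml` layout. `read_hyperparameters` and `run_training` are illustrative names, not part of any SageMaker package. The sketch does two things:

- It coerces the string-valued hyperparameters back to numbers and booleans where possible.
- On an exception, it writes the traceback to `/opt/ml/output/failure` and returns a non-zero status.

```python
import json
import sys
import traceback
from pathlib import Path
from typing import Any, Callable, Dict

SM_PREFIX = Path("/opt/ml")  # default SageMaker training layout


def read_hyperparameters(prefix: Path = SM_PREFIX) -> Dict[str, Any]:
    """Load hyperparameters.json and turn its string values back into JSON types."""
    path = prefix / "input" / "config" / "hyperparameters.json"
    raw = json.loads(path.read_text()) if path.exists() else {}
    parsed: Dict[str, Any] = {}
    for key, value in raw.items():
        try:
            # "0.0003" -> 0.0003, "512" -> 512, "true" -> True
            parsed[key] = json.loads(value)
        except (TypeError, ValueError):
            # Plain strings such as "adam" are not valid JSON, so keep them as-is.
            parsed[key] = value
    return parsed


def run_training(train_fn: Callable[[Dict[str, Any]], None],
                 prefix: Path = SM_PREFIX) -> int:
    """Call train_fn(hyperparameters); on any exception, record it and return 1."""
    try:
        train_fn(read_hyperparameters(prefix))
        return 0
    except Exception:
        message = traceback.format_exc()
        print(message, file=sys.stderr)  # stderr is shipped to CloudWatch Logs
        failure = prefix / "output" / "failure"
        failure.parent.mkdir(parents=True, exist_ok=True)
        failure.write_text(message)  # surfaced as the job's failure reason
        return 1


# In train.py, with train_fn being your own (hypothetical) training function:
#     if __name__ == "__main__":
#         sys.exit(run_training(train_fn))
```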
For inference, the contract is:
- The image must run an HTTP server listening on port 8080 (or 8081 for multi-model).
- Two endpoints are required:
  - `GET /ping` returns 200 when the container is healthy. SageMaker uses this to decide when the endpoint is ready.
  - `POST /invocations` accepts the payload, runs inference, and returns the response.
- The container must start and serve a successful `/ping` within a timeout (60s default for deploy-time health check; configurable).
- The model artefact is available at `/opt/ml/model/`: the tar uploaded from training is automatically unpacked there at container start.
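A minimal sketch makes the inference contract concrete: a server that satisfies it using only Python’s standard library. It is not production-grade; a real container would use a hardened server. It rests on three assumptions:

- The request body is JSON of the form `{"inputs": [...]}`.
- `predict` is a placeholder for the real model call.
- A `threading.Event` signals that the artefact in `/opt/ml/model` has been loaded.

```python
import json
import threading
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any, Callable


def make_handler(predict: Callable[[Any], Any], ready: threading.Event):
    """Build a request handler implementing GET /ping and POST /invocations."""

    class Handler(BaseHTTPRequestHandler):
        def do_GET(self):
            if self.path == "/ping":
                # Healthy only once the model has loaded; 503 until then.
                self._send(200 if ready.is_set() else 503, b"")
            else:
                self._send(404, b"")

        def do_POST(self):
            if self.path != "/invocations":
                self._send(404, b"")
                return
            length = int(self.headers.get("Content-Length", 0))
            try:
                inputs = json.loads(self.rfile.read(length))["inputs"]
            except (ValueError, KeyError, TypeError):
                self._send(400, b'{"error": "bad request"}', "application/json")
                return
            body = json.dumps({"predictions": predict(inputs)}).encode()
            self._send(200, body, "application/json")

        def _send(self, status: int, body: bytes, content_type: str = "text/plain"):
            self.send_response(status)
            self.send_header("Content-Type", content_type)
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def log_message(self, format, *args):
            pass  # silence per-request logging in this sketch

    return Handler


def build_server(predict, ready: threading.Event, port: int = 8080) -> ThreadingHTTPServer:
    """SageMaker sends requests to port 8080 inside the container."""
    return ThreadingHTTPServer(("0.0.0.0", port), make_handler(predict, ready))
```

To start it, call `serve_forever()` on the returned server and call `ready.set()` once the artefact has loaded, for example from a background loading thread. Until then, `/ping` answers 503 and SageMaker keeps waiting.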
For Processing jobs, the contract is simpler:
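- Input channels are placed at the local paths the job specifies, conventionally `/opt/ml/processing/input/<channel-name>/`.
- Outputs are written to paths such as `/opt/ml/processing/output/<channel-name>/`. In the default EndOfJob upload mode, they are uploaded to S3 when the job ends.
- The container’s entrypoint runs to completion; the exit code signals success or failure.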
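A Processing container is just a run-to-completion script over directories. Here is a minimal sketch with three assumptions:

- The channel names `raw` (input) and `stats` (output) are hypothetical.
- The toy transformation counts lines per file.
- The real paths are whatever the job’s input and output definitions specify.

```python
import json
from pathlib import Path
from typing import Dict

INPUT_DIR = Path("/opt/ml/processing/input/raw")      # hypothetical input channel
OUTPUT_DIR = Path("/opt/ml/processing/output/stats")  # hypothetical output channel


def process(input_dir: Path, output_dir: Path) -> Dict[str, int]:
    """Count lines in every file under input_dir; write a JSON summary to output_dir."""
    counts: Dict[str, int] = {}
    for path in sorted(p for p in input_dir.rglob("*") if p.is_file()):
        with path.open("rb") as f:
            counts[path.relative_to(input_dir).as_posix()] = sum(1 for _ in f)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "line_counts.json").write_text(json.dumps(counts, indent=2))
    return counts


if __name__ == "__main__":
    # An uncaught exception exits non-zero, which marks the Processing job failed;
    # a clean return exits 0 and whatever is in OUTPUT_DIR gets uploaded.
    process(INPUT_DIR, OUTPUT_DIR)
```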
A container can support all three contracts simultaneously, which is the usual pattern for “one image for training and inference.” Or it can support just one; a training-only container plus a separate inference container is also common.
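For the one-image pattern, the container has to tell training from serving. SageMaker starts training jobs as `docker run <image> train` and endpoints as `docker run <image> serve`. So one option, instead of putting separate `train` and `serve` executables on PATH, is an exec-form ENTRYPOINT script that dispatches on its first argument. A minimal sketch follows. `start_training` and `start_serving` are hypothetical placeholders for the team’s own code. Any other command is executed as given, which keeps `docker run <image> bash` usable for debugging.

```python
import os
import sys
from typing import Callable, Dict, List


def start_training() -> int:
    print("training...")  # hypothetical: call the training main() here
    return 0


def start_serving() -> int:
    print("serving on :8080...")  # hypothetical: start the HTTP server here
    return 0


HANDLERS: Dict[str, Callable[[], int]] = {"train": start_training, "serve": start_serving}


def dispatch(argv: List[str]) -> int:
    """Route SageMaker's `train` / `serve` argument to the matching handler."""
    if argv and argv[0] in HANDLERS:
        return HANDLERS[argv[0]]()
    if argv:
        # Anything else (e.g. `bash` while debugging) replaces this process.
        os.execvp(argv[0], argv)
    print("usage: entrypoint.py train|serve|<command> ...", file=sys.stderr)
    return 2


if __name__ == "__main__":
    sys.exit(dispatch(sys.argv[1:]))
```

It would be wired up with an exec-form `ENTRYPOINT ["python3", "/opt/program/entrypoint.py"]`, where the path is an assumption. The exec form matters: SageMaker’s `train` or `serve` argument is appended to it.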
The second thing worth naming is where the image goes. SageMaker pulls images from ECR (Amazon Elastic Container Registry) in the same account and Region as the job. Cross-account or cross-region is possible with extra IAM work but adds latency and complication. The first migration task is picking an ECR repository, pushing the image there, and granting the SageMaker execution role permission to pull.
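The pull permission is a small IAM policy on the SageMaker execution role. The sketch below builds the policy document in Python. It assumes the account ID, Region, and repository name used later in this article. The four ECR actions are the ones an image pull needs. `ecr:GetAuthorizationToken` is a registry-level call that cannot be scoped to a repository, so it takes `*`.

```python
import json


def ecr_pull_policy(account_id: str, region: str, repository: str) -> dict:
    """IAM policy letting a role pull images from one ECR repository."""
    repo_arn = f"arn:aws:ecr:{region}:{account_id}:repository/{repository}"
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Registry-level call; cannot be limited to a repository ARN.
                "Effect": "Allow",
                "Action": "ecr:GetAuthorizationToken",
                "Resource": "*",
            },
            {
                "Effect": "Allow",
                "Action": [
                    "ecr:BatchCheckLayerAvailability",
                    "ecr:GetDownloadUrlForLayer",
                    "ecr:BatchGetImage",
                ],
                "Resource": repo_arn,
            },
        ],
    }


if __name__ == "__main__":
    # Placeholder account, Region and repository, matching the examples later on.
    print(json.dumps(ecr_pull_policy("111122223333", "eu-west-1", "jax-flax-byoc"), indent=2))
```

The printed JSON could then be attached with `aws iam put-role-policy --role-name <execution-role> --policy-name ecr-pull --policy-document file://policy.json`. The role and policy names are placeholders.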
The third is what AWS already gives you. The sagemaker-training-toolkit Python package, installed into a custom image, handles the plumbing of the training contract:

- reading hyperparameters from /opt/ml/input/config/hyperparameters.json
- finding the user’s training script and invoking it
- capturing stdout/stderr

The sagemaker-inference-toolkit does the same for inference. It wraps the model_fn, input_fn, predict_fn, output_fn pattern in a serving stack built on Multi Model Server (MMS), a Java-based model server, so the image also needs a Java runtime. You can either use these toolkits inside a custom base image and focus on dependencies, or implement the contracts yourself. The toolkits are relatively thin layers, so rolling your own is reasonable when the dependencies don’t tolerate their conventions.
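When the training toolkit launches the entry point, it passes each hyperparameter as a command-line flag (for example `--learning-rate 0.0003`). It also exports environment variables such as `SM_MODEL_DIR` and `SM_CHANNEL_<NAME>`. A script written for the toolkit therefore often parses arguments instead of reading the JSON file. The sketch below assumes the hyperparameter and channel names used later in this article (`learning-rate`, `batch-size`, `epochs`, `train`, `val`). It falls back to SageMaker’s default paths when those variables are absent.

```python
import argparse
import os
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Hyperparameters arrive as --name value pairs; argparse exposes
    # --learning-rate as args.learning_rate.
    parser.add_argument("--learning-rate", type=float, default=3e-4)
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--epochs", type=int, default=20)
    # Paths come from toolkit-set environment variables, with SageMaker's
    # default locations as fallbacks for runs outside the toolkit.
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--train-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument("--val-dir", default=os.environ.get("SM_CHANNEL_VAL", "/opt/ml/input/data/val"))
    # parse_known_args tolerates flags this script does not declare.
    args, _unknown = parser.parse_known_args(argv)
    return args
```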
What we’ll filter on
Four filters worth applying to any bring-your-own-container choice:
- Which contracts: training, inference, processing, or some combination?
- Toolkit or hand-roll: use `sagemaker-*-toolkit` for the plumbing, or write the contract implementation directly?
- Base image: start from an AWS DLC (get CUDA, drivers, EFA bindings for free) or from scratch (full control, more work)?
- How the script is loaded: baked into the image at build time, or pulled from S3 at runtime via `source_dir`?
The bring-your-own landscape
1. Use a pre-built AWS DLC. The default. PyTorch, TensorFlow, MXNet, HuggingFace, XGBoost, scikit-learn images are maintained by AWS, pre-configured for CUDA/cuDNN/EFA on the correct instance types, and updated regularly. Custom code goes in via source_dir (an S3 path or local folder that SageMaker uploads); no image build required for most tuning. The easiest path; the correct answer when the framework is on the supported list.
2. DLC as base image + FROM. Start from 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.3.0-gpu-py310-cu121-ubuntu22.04-sagemaker (or whatever framework is closest), FROM that in a custom Dockerfile, layer additional dependencies on top. Keeps AWS’s CUDA/driver/EFA configuration; adds whatever the team needs. Most bring-your-own work for teams using “supported framework + extra packages” lands here.
3. Full bring-your-own container with sagemaker-*-toolkit. Start from a base (Ubuntu, nvidia/cuda, whatever), install the framework (JAX, say), install sagemaker-training-toolkit and sagemaker-inference-toolkit, and write a training script plus model_fn/predict_fn/input_fn/output_fn. Toolkits handle the contract plumbing; the team handles the framework-specific code. The path the JAX team is on.
4. Full bring-your-own container, hand-rolled contracts. No toolkits; implement the contracts from scratch. Read hyperparameters from /opt/ml/input/config/hyperparameters.json, iterate through /opt/ml/input/data/<channel>/, write to /opt/ml/model/; for inference, write a Flask/FastAPI server with /ping and /invocations routes. More code, full control. Appropriate when the framework’s process model doesn’t play well with the inference toolkit’s MMS-based serving stack, e.g. for gRPC-first frameworks or for multi-process inference setups.
5. AWS Marketplace model packages. Pre-built containers sold or shared through the Marketplace. Useful for consuming someone else’s model; not the path for bringing your own.
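For option 4, channel iteration without the toolkit comes down to reading the config files SageMaker writes under `/opt/ml/input/config/`. This minimal sketch assumes that `inputdataconfig.json` has one top-level key per channel, named as in the `fit()` call. The helper names are illustrative.

```python
import json
from pathlib import Path
from typing import Dict, List


def channel_dirs(prefix: Path = Path("/opt/ml")) -> Dict[str, Path]:
    """Map each configured input channel to its local data directory."""
    config_path = prefix / "input" / "config" / "inputdataconfig.json"
    config = json.loads(config_path.read_text())
    # Each top-level key is a channel name, e.g. "train" or "val".
    return {name: prefix / "input" / "data" / name for name in config}


def list_files(directory: Path) -> List[Path]:
    """All regular files under a channel directory, in a stable order."""
    return sorted(p for p in directory.rglob("*") if p.is_file())
```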
Side by side
| Approach | Contracts covered | Toolkit | Base image | Script location |
|---|---|---|---|---|
| Pre-built DLC | All | ✓ (built in) | DLC | source_dir (S3 at runtime) |
| DLC + layered | All | ✓ (inherited) | DLC | source_dir or baked |
| BYO + toolkit | Training and/or inference | ✓ (installed) | Any (Ubuntu, nvidia/cuda) | source_dir or baked |
| BYO + hand-rolled | Any (you implement) | ✗ | Any | Baked (usually) |
| Marketplace | Consumer-side | n/a | n/a | n/a |
Reading the table against the JAX scenario: BYO with sagemaker-training-toolkit and sagemaker-inference-toolkit is the correct fit. The toolkits handle the SageMaker-specific parts (hyperparameter parsing, channel iteration, /ping and /invocations routing) while the team writes ordinary JAX/Flax training and inference code. Hand-rolling would be more work for no benefit.
The container contract
The pick in depth
BYOC with both toolkits, on an NVIDIA CUDA base image rather than a DLC. The Dockerfile for the JAX team looks roughly like this:
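```dockerfile
# Start from an NVIDIA CUDA image matching the target instance
FROM nvcr.io/nvidia/cuda:12.3.1-cudnn9-devel-ubuntu22.04

# System deps. The inference toolkit serves through Multi Model Server,
# which needs a Java runtime.
RUN apt-get update && apt-get install -y --no-install-recommends \
        python3.11 python3.11-venv \
        build-essential git openjdk-8-jdk-headless && \
    rm -rf /var/lib/apt/lists/*

# A Python 3.11 virtualenv keeps `pip` and `python` on the same interpreter
# (Ubuntu 22.04's python3-pip targets the system Python 3.10).
RUN python3.11 -m venv /opt/venv
ENV PATH=/opt/venv/bin:$PATH

# Python deps -- JAX on GPU, Flax, Optax, plus the SageMaker toolkits
RUN pip install --upgrade pip && \
    pip install \
        "jax[cuda12]>=0.4.30" \
        "flax>=0.8" \
        "optax>=0.2" \
        "orbax-checkpoint>=0.5" \
        sagemaker-training \
        sagemaker-inference \
        multi-model-server \
        "boto3>=1.34"

# Default script for training; the SDK's entry_point overrides it per job.
# At deploy time, Model(entry_point="inference.py") sets SAGEMAKER_PROGRAM
# for the endpoint, so no separate serving variable is needed.
ENV PYTHONUNBUFFERED=1 \
    SAGEMAKER_PROGRAM=train.py

# SageMaker runs `docker run <image> train` for training jobs and
# `docker run <image> serve` for endpoints, so both names must resolve to
# executables on PATH. sagemaker-training installs `train`; `serve` is a short
# script of our own (name and contents assumed here) that calls
# sagemaker_inference.model_server.start_model_server().
COPY serve /usr/local/bin/serve
RUN chmod +x /usr/local/bin/serve

# No ENTRYPOINT: ["python3"] would turn SageMaker's `train` argument into
# `python3 train`, which fails.
```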
The training script (train.py) is ordinary JAX:
```python
# train.py
import json
import os

import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp
from flax import linen as nn


def read_hyperparameters():
    # SageMaker stores every hyperparameter value as a string.
    with open("/opt/ml/input/config/hyperparameters.json") as f:
        return json.load(f)


def main():
    hp = read_hyperparameters()
    learning_rate = float(hp["learning-rate"])
    batch_size = int(hp["batch-size"])
    epochs = int(hp["epochs"])

    train_data_dir = "/opt/ml/input/data/train"
    val_data_dir = "/opt/ml/input/data/val"
    model_dir = "/opt/ml/model"

    # ... ordinary JAX training loop using the above paths, producing `params` ...

    ckpt = ocp.PyTreeCheckpointer()
    ckpt.save(os.path.join(model_dir, "final"), params)


if __name__ == "__main__":
    main()
```
Three pieces worth noticing:
- Paths are hard-coded to SageMaker’s conventions: `/opt/ml/input/data/train`, `/opt/ml/input/data/val`, `/opt/ml/model`, `/opt/ml/input/config/hyperparameters.json`. Values in the JSON file are strings, so numeric hyperparameters need casting. The training toolkit can also pass hyperparameters to the script as `sys.argv`-style arguments if preferred; reading the JSON directly is fine too.
- Output goes to `/opt/ml/model`. When the script exits cleanly, SageMaker tars the directory and writes it to S3 at the job’s configured output path.
- No HTTP server in the training path. Training is a run-to-completion contract: the process exits, SageMaker takes over.
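One thing the training toolkit does not do for JAX is multi-host setup. With `instance_count=4`, every instance runs the script, and the four processes must find each other before training starts. The sketch below derives the arguments for `jax.distributed.initialize` from SageMaker’s `/opt/ml/input/config/resourceconfig.json`, which lists `current_host` and `hosts`. It also assumes:

- one JAX process per instance,
- the first host (after sorting) as coordinator,
- an arbitrary free port (12345 is a placeholder).

```python
import json
from pathlib import Path
from typing import Dict, Tuple

RESOURCE_CONFIG = Path("/opt/ml/input/config/resourceconfig.json")


def distributed_args(config: Dict, port: int = 12345) -> Tuple[str, int, int]:
    """Return (coordinator_address, num_processes, process_id), one process per host."""
    hosts = sorted(config["hosts"])  # e.g. ["algo-1", "algo-2", "algo-3", "algo-4"]
    coordinator = f"{hosts[0]}:{port}"  # every host computes the same coordinator
    return coordinator, len(hosts), hosts.index(config["current_host"])


def init_jax_distributed(path: Path = RESOURCE_CONFIG) -> None:
    import jax  # imported lazily so distributed_args is usable without JAX

    config = json.loads(path.read_text())
    coordinator, num_processes, process_id = distributed_args(config)
    if num_processes > 1:
        jax.distributed.initialize(
            coordinator_address=coordinator,
            num_processes=num_processes,
            process_id=process_id,
        )
```

`init_jax_distributed()` would be called at the top of `main()`, before any other JAX operation. After that, the 32 A100s across the four `p4de.24xlarge` instances appear as global devices.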
The inference script (inference.py) uses the inference toolkit’s four-function pattern:
```python
# inference.py
import json
import os

import jax.numpy as jnp
import orbax.checkpoint as ocp

from model import MyFlaxModel  # hypothetical: the team's Flax module, shipped in source_dir


def model_fn(model_dir):
    """Called when a model-server worker starts. Loads the model artefact."""
    model = MyFlaxModel(...)  # constructor arguments elided
    ckpt = ocp.PyTreeCheckpointer()
    params = ckpt.restore(os.path.join(model_dir, "final"))
    return model, params


def input_fn(request_body, content_type):
    """Deserialise the request payload."""
    if content_type.startswith("application/json"):
        payload = json.loads(request_body)
        return jnp.asarray(payload["inputs"])
    raise ValueError(f"Unsupported content type: {content_type}")


def predict_fn(inputs, model):
    model_, params_ = model
    return model_.apply({"params": params_}, inputs)


def output_fn(prediction, accept):
    # Returning (body, content_type) tells the toolkit what the response is.
    return json.dumps({"predictions": prediction.tolist()}), "application/json"
```
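The inference toolkit loads these four functions into its model server, which is built on Multi Model Server rather than Flask. It routes /invocations through input_fn → predict_fn → output_fn and answers /ping itself. The team writes no HTTP handling code. The image only needs a little glue: a handler-service module plus a `serve` script that calls `sagemaker_inference.model_server.start_model_server()`.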
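Because the four functions are plain Python, they can be exercised without a container. Below is a sketch of a small test harness that calls them in the same order the toolkit does: `model_fn` once, then `input_fn` → `predict_fn` → `output_fn` per request. It is not the toolkit’s own code. `LocalInvoker` is a hypothetical name, and `handlers` can be the imported `inference` module or any object exposing the four functions.

```python
from typing import Any, Tuple


class LocalInvoker:
    """Mimic the toolkit's call order for an object exposing the four functions."""

    def __init__(self, handlers: Any, model_dir: str = "/opt/ml/model"):
        self.handlers = handlers
        self.model = handlers.model_fn(model_dir)  # loaded once, reused per request

    def invoke(self, body: bytes, content_type: str = "application/json",
               accept: str = "application/json") -> Tuple[Any, str]:
        data = self.handlers.input_fn(body, content_type)
        prediction = self.handlers.predict_fn(data, self.model)
        result = self.handlers.output_fn(prediction, accept)
        # output_fn may return just a body or a (body, content_type) pair.
        return result if isinstance(result, tuple) else (result, accept)
```

With a local copy of the trained artefact, `LocalInvoker(inference, "./model").invoke(b'{"inputs": [[1.0, 2.0]]}')` runs the real functions end to end. The `./model` path is an assumption.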
Push to ECR, register a model. Once the image builds:
```bash
docker build -t jax-flax-byoc:latest .
aws ecr create-repository --repository-name jax-flax-byoc --region eu-west-1
docker tag jax-flax-byoc:latest \
    111122223333.dkr.ecr.eu-west-1.amazonaws.com/jax-flax-byoc:2027-09
aws ecr get-login-password --region eu-west-1 | docker login --username AWS --password-stdin \
    111122223333.dkr.ecr.eu-west-1.amazonaws.com
docker push 111122223333.dkr.ecr.eu-west-1.amazonaws.com/jax-flax-byoc:2027-09
```
Then training becomes a normal SageMaker call:
```python
from sagemaker import get_execution_role
from sagemaker.estimator import Estimator

role = get_execution_role()  # inside SageMaker Studio/notebooks; elsewhere pass an IAM role ARN

estimator = Estimator(
    image_uri="111122223333.dkr.ecr.eu-west-1.amazonaws.com/jax-flax-byoc:2027-09",
    role=role,
    instance_type="ml.p4de.24xlarge",
    instance_count=4,
    hyperparameters={"learning-rate": 3e-4, "batch-size": 512, "epochs": 20},
    entry_point="train.py",
    source_dir="src/",  # bundles train.py + utils; uploaded to S3 and unpacked into the job at runtime
)
estimator.fit({"train": "s3://data/train", "val": "s3://data/val"})
```
And inference:
```python
from sagemaker.model import Model

model = Model(
    image_uri="111122223333.dkr.ecr.eu-west-1.amazonaws.com/jax-flax-byoc:2027-09",
    model_data=estimator.model_data,  # S3 URI of the model.tar.gz written by training
    role=role,  # same execution role as the training job
    entry_point="inference.py",
    source_dir="src/",
)
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.xlarge",
    endpoint_name="jax-flax-endpoint",
)
```
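Once the endpoint is in service, any client can call it through the SageMaker runtime API, not just the SDK’s predictor. The sketch below uses boto3’s `sagemaker-runtime` client and its `invoke_endpoint` call. It assumes the `{"inputs": ...}` payload shape that `input_fn` expects, plus the endpoint name and Region used in this article.

```python
import json
from typing import Any, List


def build_request(inputs: List[Any]) -> bytes:
    """JSON body in the {"inputs": ...} shape input_fn expects."""
    return json.dumps({"inputs": inputs}).encode("utf-8")


def invoke(inputs: List[Any], endpoint_name: str = "jax-flax-endpoint",
           region: str = "eu-west-1") -> Any:
    import boto3  # lazy import so build_request works without boto3 installed

    runtime = boto3.client("sagemaker-runtime", region_name=region)
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=build_request(inputs),
    )
    # Body is a streaming object: read it fully, then decode the JSON reply.
    return json.loads(response["Body"].read())["predictions"]
```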
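Hyperparameter tuning works the same way. You construct a HyperparameterTuner with the estimator, a search space, and an objective metric. SageMaker then launches training jobs against the BYOC image, up to the configured number in parallel, and reports the best one. Because a custom image has no built-in metrics, the objective has to be scraped from the training job’s logs, using a regex supplied as metric_definitions.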
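A sketch of the tuner side follows. It assumes `train.py` prints a line such as `val_loss=0.412300` after each validation pass. That log format, the `val:loss` metric name, the search ranges, and the job counts are all assumptions. The construction uses the SageMaker SDK’s `HyperparameterTuner`, `ContinuousParameter` and `IntegerParameter` classes.

```python
# The training script must print lines matching this regex; the capture
# group is the metric value SageMaker records for the objective.
METRIC_DEFINITIONS = [{"Name": "val:loss", "Regex": r"val_loss=([0-9.eE+-]+)"}]


def build_tuner(estimator):
    # Imported lazily so METRIC_DEFINITIONS can be checked without the SDK.
    from sagemaker.tuner import (
        ContinuousParameter,
        HyperparameterTuner,
        IntegerParameter,
    )

    return HyperparameterTuner(
        estimator=estimator,
        objective_metric_name="val:loss",
        objective_type="Minimize",
        metric_definitions=METRIC_DEFINITIONS,
        hyperparameter_ranges={
            "learning-rate": ContinuousParameter(1e-5, 1e-3, scaling_type="Logarithmic"),
            "batch-size": IntegerParameter(128, 1024),
        },
        max_jobs=20,
        max_parallel_jobs=2,
    )


# In train.py, after each validation pass (hypothetical log line):
#     print(f"val_loss={val_loss:.6f}", flush=True)
# Then: build_tuner(estimator).fit({"train": "s3://data/train", "val": "s3://data/val"})
```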
A worked migration
The research team’s migration takes roughly a week, broken into:
- Day 1: Dockerfile and base image. Decide on CUDA 12.3, Ubuntu 22.04, Python 3.11; install JAX, Flax, Optax, toolkits; confirm the image runs training locally with the SageMaker paths mounted as volumes.
- Day 2: training contract. Refactor the training script to read from `/opt/ml/input/...` and write to `/opt/ml/model`. Add the `hyperparameters.json` read path. Local-mode test with the SageMaker SDK’s `local` instance type.
- Day 3: ECR push + first cloud training run. Push the image, launch a small training job on one `p4de.24xlarge`, and confirm the model artefact ends up in S3.
- Day 4: inference contract. Implement `model_fn`/`input_fn`/`predict_fn`/`output_fn`. Test the container locally by running `serve` and hitting `/ping` and `/invocations`.
- Day 5: endpoint deployment + first inference test. Deploy an endpoint, smoke-test it.
- Day 6: hyperparameter tuner + CloudWatch wiring. Define tuner, launch a small search, confirm Studio shows results.
- Day 7: Model Registry, CI for the image build, documentation.
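For the Day 1 and Day 2 checks, a fake `/opt/ml` tree on the host lets the image exercise its training contract without SageMaker. The sketch below:

- writes string-valued hyperparameters (as SageMaker does), an `inputdataconfig.json`, a single-host `resourceconfig.json`, and empty channel directories under a local folder;
- builds a `docker run` command that mounts the folder at `/opt/ml` and the source folder at `/opt/ml/code`.

The folder names, the `--gpus all` flag, and the choice of config fields are assumptions. `/opt/ml/code` is assumed to be the training toolkit’s default code directory. The toolkit may read more than this writes; SDK local mode (`instance_type="local"`) generates the full set automatically.

```python
import json
import shlex
from pathlib import Path
from typing import Dict, Iterable, List


def scaffold_opt_ml(root: Path, hyperparameters: Dict[str, object],
                    channels: Iterable[str], host: str = "algo-1") -> None:
    """Create a local stand-in for the /opt/ml tree a training job sees."""
    config = root / "input" / "config"
    config.mkdir(parents=True, exist_ok=True)
    channel_list = list(channels)
    # SageMaker hands the container string values, so mimic that here.
    hp = {key: str(value) for key, value in hyperparameters.items()}
    (config / "hyperparameters.json").write_text(json.dumps(hp))
    (config / "inputdataconfig.json").write_text(
        json.dumps({c: {"TrainingInputMode": "File"} for c in channel_list}))
    (config / "resourceconfig.json").write_text(
        json.dumps({"current_host": host, "hosts": [host], "network_interface_name": "eth0"}))
    for c in channel_list:
        (root / "input" / "data" / c).mkdir(parents=True, exist_ok=True)
    (root / "model").mkdir(exist_ok=True)
    (root / "output").mkdir(exist_ok=True)


def docker_train_command(root: Path, code_dir: Path, image: str) -> List[str]:
    """A `docker run` that starts the image the way SageMaker does (with `train`)."""
    return [
        "docker", "run", "--rm", "--gpus", "all",
        "-v", f"{root.resolve()}:/opt/ml",
        "-v", f"{code_dir.resolve()}:/opt/ml/code",
        "-e", "SAGEMAKER_PROGRAM=train.py",
        image, "train",
    ]


if __name__ == "__main__":
    local = Path("local_opt_ml")
    scaffold_opt_ml(local, {"learning-rate": 3e-4, "batch-size": 8, "epochs": 1},
                    ["train", "val"])
    print(shlex.join(docker_train_command(local, Path("src"), "jax-flax-byoc:latest")))
```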
One week to replace “four EC2 instances and a shell loop” with “SageMaker training, tuning, registry, endpoint.” The BYOC image is an investment paid once and amortised over every future run.
What’s worth remembering
- SageMaker’s managed services all run containers. Training jobs, Processing jobs, endpoints, batch transform: different contracts, same Docker image shape.
- Training contract: files and exit code. Read hyperparameters and channels from fixed paths; write artefacts to `/opt/ml/model`; exit 0 for success. No HTTP server.
- Inference contract: HTTP on :8080. `GET /ping` returns 200 when ready; `POST /invocations` accepts the payload and returns the prediction. Model artefact unpacked at `/opt/ml/model`.
- The toolkits save the plumbing. `sagemaker-training-toolkit` and `sagemaker-inference-toolkit` implement the contract shells; you write framework code. Model-fn/input-fn/predict-fn/output-fn is the inference pattern.
- ECR is where images live. One-time setup; SageMaker pulls from ECR in the same account and Region as the job.
- DLC as a base image is the easy win for supported-framework-plus-dependencies. `FROM` the DLC; layer extras on top; inherit the CUDA/driver/EFA configuration.
- Use local mode to debug before paying for cloud. `instance_type="local"` in the SDK runs the container on the laptop: same contract, tight iteration loop.
- The BYOC image is a durable artefact. Build once, push to ECR, use in training jobs, tuners, endpoints, batch transform, Processing jobs. It is versioned via image tag, and the team treats it as a first-class deliverable.
BYOC is how you put a framework AWS hasn’t packaged onto SageMaker’s rails. The work is writing to the contract, not reinventing the service, and once the image is built, JAX, Flax, or any other unsupported framework behaves like a first-party one.