Bringing Your Own Container to SageMaker

March 27, 2028 · 17 min read

ML Engineer · MLA-C01 · part of The Exam Room

The situation

A small research team trains an LLM (a neural network trained to predict the next token in a sequence, large enough that it generalises to tasks it wasn’t explicitly trained for) using JAX, Flax, and Optax. Today’s setup:

  • Four p4de.24xlarge instances running a home-grown Docker image with JAX, Flax, Optax, and a few bespoke dependencies.
  • Data in S3, pulled at training start via s5cmd.
  • Checkpoints written back to S3 via JAX’s orbax library.
  • Hyperparameter tuning done manually: a shell loop over docker run invocations with different flags.
  • Inference done by the same container running a small HTTP server, with TLS terminated by an ALB in front.

They want to move to SageMaker for three reasons: the team is small enough that managed training and hyperparameter tuning would free real engineering time; SageMaker Model Registry would replace a set of S3 paths and a spreadsheet; and deployment to endpoints with autoscaling would replace the hand-rolled server.

The blocker: SageMaker’s list of AWS Deep Learning Containers doesn’t include JAX. The options are to use a framework they don’t use (wrong answer), switch frameworks (very wrong answer), or bring their own container. The question is how to build that container so it plays correctly with SageMaker’s training jobs, endpoints, and tuner.

What actually matters

SageMaker’s managed services (training jobs, Processing jobs, endpoints, batch transform, the Model Registry, the hyperparameter tuner) are all built on the same underlying primitive: containers that follow a specific contract. When AWS ships a first-party DLC, it’s because AWS wrote the contract implementation for that framework. Bringing your own container is, concretely, writing the same contract implementation for a framework AWS hasn’t packaged.

The first thing worth being clear about is which contract we’re implementing. There isn’t one. There are several, all using the same Docker image but with different expectations depending on how SageMaker invokes it.

For training jobs, the contract is:

  • SageMaker starts the container with the argument train and writes the job’s hyperparameters to /opt/ml/input/config/hyperparameters.json. The container either reads that file directly or lets sagemaker-training-toolkit turn it into command-line arguments for the training script.
  • Input data channels are mounted read-only at /opt/ml/input/data/<channel-name>/ (e.g. /opt/ml/input/data/train/).
  • The training script writes model artefacts to /opt/ml/model/; SageMaker tars up whatever is there at the end and uploads it to S3.
  • Failures write to /opt/ml/output/failure; a non-zero exit code also marks the job failed.
  • stdout/stderr stream to CloudWatch Logs automatically (SageMaker’s doing, not the container’s).
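
As a minimal sketch of the training contract above, hand-rolled and without the toolkit: the helper below reads the hyperparameters file, discovers channel directories, and writes the failure file when training raises. The function names and the root parameter are illustrative assumptions (root exists so the sketch can be exercised outside a container); only the paths come from the contract.

```python
import json
import traceback
from pathlib import Path


def load_job_inputs(root: Path = Path("/opt/ml")) -> tuple[dict, dict]:
    """Return (hyperparameters, channels) as SageMaker lays them out under root."""
    hp_file = root / "input" / "config" / "hyperparameters.json"
    hyperparameters = json.loads(hp_file.read_text()) if hp_file.exists() else {}
    data_root = root / "input" / "data"
    channels = {}
    if data_root.is_dir():
        # One sub-directory per channel, e.g. train/ and val/.
        channels = {p.name: p for p in sorted(data_root.iterdir()) if p.is_dir()}
    return hyperparameters, channels


def run_job(train_fn, root: Path = Path("/opt/ml")) -> int:
    """Run train_fn under the training contract and return a process exit code."""
    try:
        hyperparameters, channels = load_job_inputs(root)
        model_dir = root / "model"
        model_dir.mkdir(parents=True, exist_ok=True)
        train_fn(hyperparameters, channels, model_dir)
        return 0
    except Exception:
        # The failure file's contents are surfaced as the job's failure reason.
        output_dir = root / "output"
        output_dir.mkdir(parents=True, exist_ok=True)
        (output_dir / "failure").write_text(traceback.format_exc())
        return 1
```

In a real entrypoint this would end with something like sys.exit(run_job(my_train_fn)), so the exit code and the failure file agree with each other.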

For inference, the contract is:

  • The image must run an HTTP server listening on port 8080 (multi-model endpoints use the same port and add model load/unload APIs).
  • Two endpoints are required: GET /ping returns 200 when the container is healthy (SageMaker uses this to decide when the endpoint is ready), and POST /invocations accepts the payload, runs inference, and returns the response.
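  • The container must start answering /ping with 200 within the startup health-check window (configurable per production variant via ContainerStartupHealthCheckTimeoutInSeconds), and each /invocations response must come back within 60 seconds.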
  • The model artefact is available at /opt/ml/model/; the tar uploaded from training is automatically unpacked here at container start.
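
To make the inference contract concrete, here is a minimal hand-rolled sketch using only Python’s standard library. The predict function is a placeholder and the JSON payload shape is an assumption borrowed from the inference example later in this post; a production server would load the model from /opt/ml/model/ before reporting healthy.

```python
import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer


def predict(payload: dict) -> dict:
    """Placeholder for real inference: reports how many inputs arrived."""
    return {"n_inputs": len(payload.get("inputs", []))}


class ContractHandler(BaseHTTPRequestHandler):
    def _send(self, status: int, body: bytes = b""):
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        # Health check: a 200 with an empty body is enough.
        self._send(200 if self.path == "/ping" else 404)

    def do_POST(self):
        if self.path != "/invocations":
            self._send(404)
            return
        length = int(self.headers.get("Content-Length", 0))
        try:
            payload = json.loads(self.rfile.read(length))
            self._send(200, json.dumps(predict(payload)).encode())
        except (ValueError, AttributeError, TypeError):
            self._send(400, b'{"error": "body must be a JSON object"}')

    def log_message(self, fmt, *args):
        pass  # keep the sketch quiet; a real server would log requests


def make_server(port: int = 8080) -> ThreadingHTTPServer:
    """Build the server; call serve_forever() on the result to run it."""
    return ThreadingHTTPServer(("0.0.0.0", port), ContractHandler)
```

Calling make_server().serve_forever() from the container’s serve command would satisfy the two required routes; everything else (batching, GPU placement, logging) is left out on purpose.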

For Processing jobs, the contract is simpler:

  • Input channels mount at /opt/ml/processing/input/<channel-name>/ (read-only); outputs go to /opt/ml/processing/output/<channel-name>/ (write, auto-uploaded on exit).
  • The container’s entrypoint runs to completion; exit code signals success/failure.
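
A Processing script under this contract can be very small. The sketch below assumes one input channel named raw and one output channel named clean (both names are hypothetical) and simply copies CSV files across while dropping blank rows; the root parameter exists only so the function can be tried outside a container.

```python
import csv
from pathlib import Path


def run_processing(root: Path = Path("/opt/ml/processing")) -> int:
    """Copy each input CSV to the output channel, dropping empty rows."""
    in_dir = root / "input" / "raw"       # hypothetical input channel name
    out_dir = root / "output" / "clean"   # hypothetical output channel name
    out_dir.mkdir(parents=True, exist_ok=True)
    rows_written = 0
    for src in sorted(in_dir.glob("*.csv")):
        with src.open(newline="") as fin, (out_dir / src.name).open("w", newline="") as fout:
            writer = csv.writer(fout)
            for row in csv.reader(fin):
                # Keep a row only if at least one cell has content.
                if any(cell.strip() for cell in row):
                    writer.writerow(row)
                    rows_written += 1
    return rows_written
```

Anything left in the output directory when the process exits is what SageMaker uploads, so the script needs no S3 code of its own.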

A container can support all three contracts simultaneously, which is the usual pattern for “one image for training and inference.” Or it can support just one: a training-only container plus a separate inference container is also common.

The second thing worth naming is where the image goes. SageMaker pulls images from ECR (Amazon Elastic Container Registry) in the same account and Region as the job. Cross-account or cross-region is possible with extra IAM work but adds latency and complication. The first migration task is picking an ECR repository, pushing the image there, and granting the SageMaker execution role permission to pull.

The third is what AWS already gives you. The sagemaker-training-toolkit Python package, installed into a custom image, handles the “read hyperparameters from /opt/ml/input/config/hyperparameters.json, find the user’s training script, invoke it, capture stdout/stderr” pieces of the training contract. The sagemaker-inference-toolkit does similar for inference: it wraps a model_fn, input_fn, predict_fn, output_fn pattern behind a model server (Multi Model Server, not Flask). You can either use these toolkits inside a custom base image and focus on dependencies, or implement the contracts yourself. The toolkits themselves are thin (a few hundred lines of Python), so rolling your own is reasonable when the dependencies don’t tolerate their conventions.

What we’ll filter on

Four filters worth applying to any bring-your-own-container choice:

  1. Which contracts: training, inference, processing, or some combination?
  2. Toolkit or hand-roll: use sagemaker-*-toolkit for the plumbing, or write the contract implementation directly?
  3. Base image: start from an AWS DLC (get CUDA, drivers, EFA bindings for free) or from scratch (full control, more work)?
  4. How the script is loaded: baked into the image at build time, or pulled from S3 at runtime via source_dir?

The bring-your-own landscape

1. Use a pre-built AWS DLC. The default. PyTorch, TensorFlow, MXNet, HuggingFace, XGBoost, scikit-learn images are maintained by AWS, pre-configured for CUDA/cuDNN/EFA on the correct instance types, and updated regularly. Custom code goes in via source_dir (an S3 path or local folder that SageMaker uploads); no image build required for most tuning. The easiest path; the correct answer when the framework is on the supported list.

2. DLC as base image + FROM. Start from 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.3.0-gpu-py310-cu121-ubuntu22.04-sagemaker (or whatever framework is closest), FROM that in a custom Dockerfile, layer additional dependencies on top. Keeps AWS’s CUDA/driver/EFA configuration; adds whatever the team needs. Most bring-your-own work for teams using “supported framework + extra packages” lands here.

3. Full bring-your-own container with sagemaker-*-toolkit. Start from a base (Ubuntu, nvidia/cuda, whatever), install the framework (JAX, say), install sagemaker-training-toolkit and sagemaker-inference-toolkit, and write a training script plus model_fn/predict_fn/input_fn/output_fn. Toolkits handle the contract plumbing; the team handles the framework-specific code. The path the JAX team is on.

4. Full bring-your-own container, hand-rolled contracts. No toolkits; implement the contracts from scratch. Read hyperparameters from /opt/ml/input/config/hyperparameters.json, iterate through /opt/ml/input/data/<channel>/, write to /opt/ml/model/; for inference, write a Flask/FastAPI server with /ping and /invocations routes. More code, full control. Appropriate when the framework’s process model doesn’t play well with the inference toolkit’s model server, e.g. for gRPC-first frameworks or for multi-process inference setups.

5. AWS Marketplace model packages. Pre-built containers sold or shared through the Marketplace. Useful for consuming someone else’s model; not the path for bringing your own.

Side by side

Approach          | Contracts covered         | Toolkit       | Base image                | Script location
Pre-built DLC     | All                       | ✓ (built in)  | DLC                       | source_dir (S3 at runtime)
DLC + layered     | All                       | ✓ (inherited) | DLC                       | source_dir or baked
BYO + toolkit     | Training and/or inference | ✓ (installed) | Any (Ubuntu, nvidia/cuda) | source_dir or baked
BYO + hand-rolled | Any (you implement)       | None          | Any                       | Baked (usually)
Marketplace       | Consumer-side             | n/a           | n/a                       | n/a

Reading the table against the JAX scenario: BYO with sagemaker-training-toolkit and sagemaker-inference-toolkit is the correct fit. The toolkits handle the SageMaker-specific parts (hyperparameter parsing, channel iteration, /ping and /invocations routing) while the team writes ordinary JAX/Flax training and inference code. Hand-rolling would be more work for no benefit.

The container contract

[Figure: the container contract, in two panels.]

Training contract: CreateTrainingJob launches the image; the image runs to completion.

  • SageMaker Training Job: mounts channels, writes hyperparameters, starts the container.
  • Your container: the default CMD or ENTRYPOINT runs training.
  • Read: /opt/ml/input/config/hyperparameters.json, /opt/ml/input/data/train/, /opt/ml/input/data/val/.
  • Write: /opt/ml/model/ (becomes model.tar.gz), /opt/ml/output/ (captured artefacts).
  • Exit 0 means success; non-zero means failure.
  • SageMaker tars /opt/ml/model/ to S3; stdout/stderr go to CloudWatch Logs.

Inference contract: CreateEndpoint starts the image; the image stays up serving requests.

  • SageMaker Endpoint: unpacks model.tar.gz into /opt/ml/model/, starts the container.
  • Your container: a long-running HTTP server on :8080, reading /opt/ml/model/ (unpacked by SageMaker).
  • GET /ping returns 200 OK when ready; SageMaker uses it to decide endpoint readiness.
  • POST /invocations: the body is the payload, the content type is whatever the client sent, and the response is the inference result.
  • Multi-model endpoints use port 8080 plus MMS hooks.
  • The client calls InvokeEndpoint and SageMaker forwards to /invocations.
  • CloudWatch metrics: ModelLatency, Invocations, Invocation4XXErrors.

Training and inference share an image but implement different contracts against SageMaker: training is run-to-completion with files at fixed paths; inference is a long-running HTTP server on :8080 serving /ping and /invocations.

The pick in depth

BYOC with both toolkits on an NVIDIA CUDA base image. The Dockerfile for the JAX team looks roughly like this:

# Start from an NVIDIA CUDA image matching the target instance
FROM nvcr.io/nvidia/cuda:12.3.1-cudnn9-devel-ubuntu22.04

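# System deps. python3.11-dev provides headers for packages that compile C extensions.
RUN apt-get update && apt-get install -y --no-install-recommends \
      python3.11 python3.11-venv python3.11-dev \
      build-essential git && \
    rm -rf /var/lib/apt/lists/*

# Ubuntu 22.04's python3 and pip default to Python 3.10, so put a 3.11
# virtualenv first on PATH; pip and python3 then both mean Python 3.11.
RUN python3.11 -m venv /opt/venv
ENV PATH="/opt/venv/bin:${PATH}"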

# Python deps -- JAX on GPU, Flax, Optax, plus the SageMaker toolkits
RUN pip install --upgrade pip && \
    pip install \
      "jax[cuda12]>=0.4.30" \
      "flax>=0.8" \
      "optax>=0.2" \
      "orbax-checkpoint>=0.5" \
      sagemaker-training \
      sagemaker-inference \
      "boto3>=1.34"

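# Training: SageMaker starts the container as `docker run <image> train`.
# The sagemaker-training package installs a `train` executable, which runs
# SAGEMAKER_PROGRAM from the code directory the toolkit unpacks source_dir into.
ENV PYTHONUNBUFFERED=1 \
    SAGEMAKER_PROGRAM=train.py

# Inference: SageMaker starts the container as `docker run <image> serve`.
# The inference toolkit does not install a `serve` executable, so the image
# supplies one: a short script (hypothetical, not shown) that starts the
# toolkit's model server, which loads model_fn, input_fn, predict_fn and
# output_fn from inference.py. That server is Multi Model Server, so the image
# also needs the multi-model-server package and a Java runtime (omitted here).
COPY serve /usr/local/bin/serve
RUN chmod +x /usr/local/bin/serve

# No ENTRYPOINT: with `ENTRYPOINT ["python3"]`, SageMaker's `train` and `serve`
# arguments would run as `python3 train` and `python3 serve`, and both would fail.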

The training script (train.py) is ordinary JAX:

# train.py
import os, json
import jax, jax.numpy as jnp
from flax import linen as nn
import optax
import orbax.checkpoint as ocp

def read_hyperparameters():
    with open("/opt/ml/input/config/hyperparameters.json") as f:
        return json.load(f)

def main():
    hp = read_hyperparameters()
    train_data_dir = "/opt/ml/input/data/train"
    val_data_dir   = "/opt/ml/input/data/val"
    model_dir      = "/opt/ml/model"

    # ... ordinary JAX training loop using the above paths ...

    ckpt = ocp.PyTreeCheckpointer()
    ckpt.save(os.path.join(model_dir, "final"), params)

if __name__ == "__main__":
    main()

Three pieces worth noticing:

  1. Paths are hard-coded to SageMaker’s conventions. /opt/ml/input/data/train, /opt/ml/input/data/val, /opt/ml/model, /opt/ml/input/config/hyperparameters.json. The training toolkit can parse hyperparameters into sys.argv style args if preferred; direct JSON read is also fine.
  2. Output goes to /opt/ml/model. When the script exits cleanly, SageMaker tars the directory and writes it to S3 at the job’s configured output path.
  3. No HTTP server in the training path. Training is a run-to-completion contract; process exits, SageMaker takes over.
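
One detail the direct JSON read glosses over: SageMaker’s API stores hyperparameter values as strings, so the dictionary train.py gets back needs casting before use. A minimal sketch follows, assuming values are either JSON-encoded or plain strings, and that keys starting with sagemaker_ are SDK bookkeeping rather than user settings; the function name is illustrative.

```python
import json


def parse_hyperparameters(raw: dict) -> dict:
    """Decode string-valued hyperparameters into Python values where possible."""
    parsed = {}
    for key, value in raw.items():
        if key.startswith("sagemaker_"):
            continue  # assumed: SDK bookkeeping keys, not user hyperparameters
        try:
            parsed[key] = json.loads(value)
        except (TypeError, ValueError):
            parsed[key] = value  # leave plain strings such as "adam" untouched
    return parsed
```

With this in place, hp = parse_hyperparameters(read_hyperparameters()) gives a float learning rate and integer batch size instead of strings.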

The inference script (inference.py) uses the inference toolkit’s four-function pattern:

# inference.py
import os, json
import jax, jax.numpy as jnp
from flax import linen as nn
import orbax.checkpoint as ocp

_model = None
_params = None

def model_fn(model_dir):
    """Called once at container startup. Loads the model artefact."""
    global _model, _params
    _model = MyFlaxModel(...)
    ckpt = ocp.PyTreeCheckpointer()
    _params = ckpt.restore(os.path.join(model_dir, "final"))
    return (_model, _params)

def input_fn(request_body, content_type):
    """Deserialise the request payload."""
    if content_type == "application/json":
        payload = json.loads(request_body)
        return jnp.asarray(payload["inputs"])
    raise ValueError(f"Unsupported content type: {content_type}")

def predict_fn(inputs, model):
    model_, params_ = model
    return model_.apply({"params": params_}, inputs)

def output_fn(prediction, accept):
    return json.dumps({"predictions": prediction.tolist()}), accept

The inference toolkit wires these four together behind its model server, routes /invocations through input_fn → predict_fn → output_fn, and answers /ping with a 200 once model_fn has returned successfully. The team doesn’t write any HTTP server code.
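
The four functions can be exercised without a container or a server. The harness below is a sketch of the call order described above, not the toolkit’s own code; handle_invocation and StubHandlers are hypothetical names, and the stub stands in for inference.py so the example runs without JAX.

```python
import json


def handle_invocation(handlers, model, body, content_type, accept):
    """Run one request through input_fn -> predict_fn -> output_fn."""
    data = handlers.input_fn(body, content_type)
    prediction = handlers.predict_fn(data, model)
    result = handlers.output_fn(prediction, accept)
    # output_fn may return the body alone or a (body, content_type) pair.
    return result if isinstance(result, tuple) else (result, accept)


class StubHandlers:
    """Stand-in with inference.py's function names; plain Python, no JAX."""

    @staticmethod
    def input_fn(request_body, content_type):
        if content_type != "application/json":
            raise ValueError(f"Unsupported content type: {content_type}")
        return json.loads(request_body)["inputs"]

    @staticmethod
    def predict_fn(inputs, model):
        return [model["scale"] * x for x in inputs]

    @staticmethod
    def output_fn(prediction, accept):
        return json.dumps({"predictions": prediction}), accept
```

Swapping StubHandlers for the imported inference module, and the stub model for the result of inference.model_fn(model_dir), turns this into a quick pre-build check of the real handlers.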

Push to ECR, register a model. Once the image builds:

aws ecr create-repository --repository-name jax-flax-byoc --region eu-west-1
docker tag jax-flax-byoc:latest \
  111122223333.dkr.ecr.eu-west-1.amazonaws.com/jax-flax-byoc:2027-09
aws ecr get-login-password --region eu-west-1 | \
  docker login --username AWS --password-stdin \
  111122223333.dkr.ecr.eu-west-1.amazonaws.com
docker push 111122223333.dkr.ecr.eu-west-1.amazonaws.com/jax-flax-byoc:2027-09

Then training becomes a normal SageMaker call:

from sagemaker.estimator import Estimator

estimator = Estimator(
    image_uri="111122223333.dkr.ecr.eu-west-1.amazonaws.com/jax-flax-byoc:2027-09",
    role=role,
    instance_type="ml.p4de.24xlarge",
    instance_count=4,
    hyperparameters={"learning-rate": 3e-4, "batch-size": 512, "epochs": 20},
    entry_point="train.py",
    source_dir="src/",  # bundles train.py + utils; uploaded to the job at runtime
)
estimator.fit({"train": "s3://data/train", "val": "s3://data/val"})
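
With instance_count=4, each container also has to find its peers. SageMaker writes the host list to /opt/ml/input/config/resourceconfig.json, and a sketch like the one below can turn it into arguments for jax.distributed.initialize. The port number and the choice of the first sorted host as coordinator are assumptions for illustration, not anything SageMaker prescribes.

```python
import json
from pathlib import Path


def distributed_args(
    config_path: str = "/opt/ml/input/config/resourceconfig.json",
    port: int = 12345,  # hypothetical port; any port free on the coordinator works
) -> dict:
    """Derive jax.distributed.initialize arguments from SageMaker's host list."""
    cfg = json.loads(Path(config_path).read_text())
    # Every host sorts the same list, so all agree on coordinator and ranks.
    hosts = sorted(cfg["hosts"])
    return {
        "coordinator_address": f"{hosts[0]}:{port}",
        "num_processes": len(hosts),
        "process_id": hosts.index(cfg["current_host"]),
    }
```

In train.py this would be used as jax.distributed.initialize(**distributed_args()) before any other JAX call.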

And inference:

from sagemaker.model import Model

model = Model(
    image_uri="111122223333.dkr.ecr.eu-west-1.amazonaws.com/jax-flax-byoc:2027-09",
    model_data=estimator.model_data,
    role=role,
    entry_point="inference.py",
    source_dir="src/",
)
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.xlarge",
    endpoint_name="jax-flax-endpoint",
)
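
Once the endpoint is up, a client call might look like the sketch below. It takes the runtime client as a parameter so it can be tested with a stand-in; the payload and response keys ("inputs", "predictions") are the ones inference.py above uses, and the function name is illustrative.

```python
import json


def invoke(runtime_client, endpoint_name: str, inputs: list) -> list:
    """Send one JSON request to the endpoint and return its predictions."""
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps({"inputs": inputs}),
    )
    # The response body is a stream; read it once and decode.
    return json.loads(response["Body"].read())["predictions"]
```

Against AWS, runtime_client would be boto3.client("sagemaker-runtime") and endpoint_name would be "jax-flax-endpoint".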

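Hyperparameter tuning works the same way: construct a HyperparameterTuner with the estimator, a search space, and an objective metric; SageMaker launches the training jobs against the BYOC image, up to the configured parallelism, and reports the best. With a custom image the objective metric has to be declared through metric_definitions, a regex that SageMaker applies to each job’s log output.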
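
For a custom image, the tuner can only see a metric that the training script prints and that a regex can pick out of the logs. The sketch below pairs a log line with a matching definition; the metric name val:loss and the log format are assumptions, and extract_metric only imitates locally what SageMaker does with the regex.

```python
import re

# Hypothetical metric name and log format; the same list would be passed as
# metric_definitions, with "val:loss" as the tuner's objective metric name.
METRIC_DEFINITIONS = [
    {"Name": "val:loss", "Regex": r"val_loss=([0-9.eE+-]+)"},
]


def log_validation_loss(epoch: int, loss: float) -> str:
    """Print the metric in the format the regex expects; return the line."""
    line = f"epoch={epoch} val_loss={loss:.6f}"
    print(line, flush=True)
    return line


def extract_metric(line: str, definition: dict):
    """Apply a metric definition's regex to one log line."""
    match = re.search(definition["Regex"], line)
    return float(match.group(1)) if match else None
```

Checking the regex against a real log line before launching a search is cheap insurance: a pattern that never matches leaves the tuner with no objective to compare.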

A worked migration

The research team’s migration takes roughly a week, broken into:

  • Day 1: Dockerfile and base image. Decide on CUDA 12.3, Ubuntu 22.04, Python 3.11; install JAX, Flax, Optax, toolkits; confirm the image runs training locally with the SageMaker paths mounted as volumes.
  • Day 2: training contract. Refactor the training script to read from /opt/ml/input/... and write to /opt/ml/model. Add the hyperparameters.json read path. Local-mode test with SageMaker SDK’s local instance type.
  • Day 3: ECR push + first cloud training run. Push the image, launch a small training job on one p4de.24xlarge, confirm model artefact ends up in S3.
  • Day 4: inference contract. Implement model_fn/input_fn/predict_fn/output_fn. Test the container locally by running serve and hitting /ping and /invocations.
  • Day 5: endpoint deployment + first inference test. Deploy an endpoint, smoke-test it.
  • Day 6: hyperparameter tuner + CloudWatch wiring. Define tuner, launch a small search, confirm Studio shows results.
  • Day 7: Model Registry, CI for the image build, documentation.

One week to replace “four EC2 instances and a shell loop” with “SageMaker training, tuning, registry, endpoint.” The BYOC image is an investment paid once and amortised over every future run.

What’s worth remembering

  1. SageMaker’s managed services all run containers. Training jobs, Processing jobs, endpoints, batch transform: different contracts, same Docker image shape.
  2. Training contract: files and exit code. Read hyperparameters and channels from fixed paths; write artefacts to /opt/ml/model; exit 0 for success. No HTTP server.
  3. Inference contract: HTTP on :8080. GET /ping returns 200 when ready; POST /invocations accepts the payload and returns inference. Model artefact unpacked at /opt/ml/model.
  4. The toolkits save the plumbing. sagemaker-training-toolkit and sagemaker-inference-toolkit implement the contract shells; you write framework code. model_fn/input_fn/predict_fn/output_fn is the inference pattern.
  5. ECR is where images live. One-time setup; SageMaker pulls from ECR in the same account and Region as the job.
  6. DLC as a base image is the easy win for supported-framework-plus-dependencies. FROM the DLC; layer extras on top; inherit the CUDA/driver/EFA configuration.
  7. Use local mode to debug before paying for cloud. instance_type="local" in the SDK runs the container on the laptop: same contract, tight iteration loop.
  8. The BYOC image is a durable artefact. Build once, push to ECR, use in training jobs, tuners, endpoints, batch transform, Processing jobs. Versioned via image tag; the team treats it as a first-class deliverable.

BYOC is how you put a framework AWS hasn’t packaged onto SageMaker’s rails. The work is writing to the contract, not reinventing the service, and once the image is built, JAX, Flax, or any other unsupported framework behaves like a first-party one.

These posts are LLM-aided. Backbone, original writing, and structure by Craig. Research and editing by Craig + LLM. Proof-reading by Craig.