June 3, 2025
AI/ML Infrastructure Training Cost Optimization

Training on Spot/Preemptible Instances: Checkpoint and Recovery

You've spent three days tuning hyperparameters. Your model is finally converging. Then AWS reclaims your p3.16xlarge instance with two minutes of notice, and you're watching 48 hours of training progress disappear. Sound familiar?

Spot instances are the dark horse of ML infrastructure - they cost 60-90% less than on-demand equivalents, which is intoxicating when you're scaling. But that savings comes with a price: interruption. The question isn't whether your instances will be terminated; it's when. And how well you've prepared for it.

This article walks you through a battle-tested approach to training on spot and preemptible instances without losing sleep (or training runs). We'll build a fault-tolerant training harness that handles interruptions gracefully, survives topology changes in distributed training, and resumes from the last checkpoint as if nothing happened.

Table of Contents
  1. The Economics: Why Spot Instances Are Worth the Complexity
  2. Why This Matters in Production
  3. Checkpoint Strategy: The Science of Frequency
  4. Why Async Upload Is Non-Negotiable
  5. Catching the Interruption Signal: Beat AWS to the Punch
  6. How Interruption Signals Work Across Clouds
  7. Elastic Training: Dynamic Topology with torch.distributed.elastic
  8. Scaling Considerations with Elastic Training
  9. Fault-Tolerant Pipeline Storage: The Foundation
  10. Common Pitfalls and How to Avoid Them
  11. Architecture Decisions and Patterns
  12. Putting It Together: The Complete Harness
  13. Diagrams: The Architecture
  14. The Hidden Costs of Spot Infrastructure
  15. Summary

The Economics: Why Spot Instances Are Worth the Complexity

Let's start with cold numbers, because if the math doesn't work, elegance doesn't matter. Understanding whether spot instances make sense for your workload requires honest accounting. It's easy to look at the 70% price difference and get excited. It's harder to account for the hidden costs of building checkpoint infrastructure, handling failures gracefully, and actually recovering from interruptions.

But when you do the math carefully, spot instances almost always win for long-running ML workloads. The key insight is that the cost advantage (70% cheaper) massively outweighs the worst-case interruption penalty (lose at most 15 minutes of work if you checkpoint every 15 minutes). Even with a high interruption rate, you're still ahead.

There's a psychological barrier though. Using on-demand instances feels safe. You pay the full price, but you "know" your job will finish. Using spot instances feels risky. You're gambling that AWS won't reclaim your instance. But this framing is misleading. With spot instances, you're not gambling more risk - you're accepting a different risk profile in exchange for substantial savings. And critically, if you've built proper checkpoint infrastructure, that risk is actually manageable and bounded.

Consider the mental model: an on-demand instance costs X per hour, and you need 100 hours, so you pay 100X. A spot instance costs 0.3X per hour, but you expect 4 interruptions, losing 15 minutes each (1 hour total), so you need 101 hours of compute, paying 30.3X. You're paying 70% less for a 1% increase in total compute time. At a 70% discount, spot only stops winning once interruption rework more than triples your total compute - and you're nowhere near that.
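That arithmetic is easy to sanity-check in a few lines. Here's a small sketch (the function name and figures are illustrative, mirroring the mental model above) that computes both totals and the rework level at which spot stops winning:

```python
def spot_vs_ondemand(
    ondemand_rate: float,            # $/hour on-demand
    spot_discount: float,            # 0.70 means spot costs 30% of on-demand
    base_hours: float,               # compute needed with zero interruptions
    interruptions: int,              # expected interruption count
    loss_per_interruption_h: float,  # rework per interruption, in hours
) -> dict:
    """Compare on-demand vs spot cost, including interruption rework."""
    spot_rate = ondemand_rate * (1 - spot_discount)
    spot_hours = base_hours + interruptions * loss_per_interruption_h
    # Rework fraction at which spot cost equals on-demand cost:
    # spot_rate * base_hours * (1 + x) = ondemand_rate * base_hours
    break_even_overhead = ondemand_rate / spot_rate - 1
    return {
        "ondemand_cost": ondemand_rate * base_hours,
        "spot_cost": spot_rate * spot_hours,
        "break_even_overhead": break_even_overhead,
    }

# The article's mental model: X=$1/h, 70% discount, 100h, 4 x 15-min losses
result = spot_vs_ondemand(1.0, 0.70, 100, 4, 0.25)
print(result)  # spot_cost ~30.3 vs ondemand_cost 100.0
```

At a 70% discount the break-even rework fraction comes out to about 233%, which is why even pessimistic interruption assumptions barely dent the savings.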

But the math only works if your checkpoint infrastructure is actually reliable. If your checkpoints are corrupted, or if resuming from a checkpoint is flaky, or if you lose more work per interruption than you budgeted for, the numbers flip. That's why the second half of this article focuses heavily on robustness. The infrastructure has to be bulletproof.

A p3.16xlarge on-demand in us-east-1 costs roughly $24.48/hour. The same instance as a spot runs about $7.35/hour - a 70% discount. If you're training a model for 100 hours, that's roughly $1,700 saved. Even if you need to rerun 15% of your training due to interruptions, spot still costs about $845 against $2,448 on-demand.

But which instances actually get interrupted? AWS doesn't publish exact interruption rates, but we can infer from observed patterns:

  • p3.16xlarge: ~5% hourly interruption probability (roughly 1.2 interruptions per day)
  • p3.2xlarge: ~3-4% hourly interruption
  • p4d.24xlarge: ~15% hourly interruption (newer, less capacity)
  • g4dn.12xlarge: ~8% hourly interruption
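To turn those hourly rates into planning numbers, a reasonable first approximation is to treat each hour as an independent trial. A quick sketch (using the estimated rates from the list above):

```python
def expected_interruptions(hourly_rate: float, hours: float) -> float:
    """Expected interruption count if each hour is an independent trial."""
    return hourly_rate * hours

def prob_at_least_one(hourly_rate: float, hours: float) -> float:
    """Probability of at least one interruption over the run."""
    return 1 - (1 - hourly_rate) ** hours

# 100-hour run on a p3.16xlarge at ~5%/hour
print(expected_interruptions(0.05, 100))      # ~5 expected interruptions
print(round(prob_at_least_one(0.05, 10), 3))  # ~0.4 within the first 10 hours
```

The takeaway: on a long run you should plan for several interruptions, not one.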

The trap is assuming your 100-hour training will be interrupted once and cost you 5 extra hours. In reality, if you checkpoint every 30 minutes, you lose at most 30 minutes of work per interruption. Over 4 interruptions in a 100-hour run, that's 2 hours of rework - still a massive win.

The break-even happens fast. Checkpoint overhead (S3 upload, disk I/O) typically adds 2-5% to training time. If saving a 50GB checkpoint takes 3 minutes and you do it every 30 minutes, that's 6 minutes of overhead per hour - 10% of your training time. But that overhead buys you a bounded worst-case loss, and the cost reduction is 70%. Even with four interruptions, you're winning.

The real insight: Spot instances aren't for one-off 10-hour training runs. They're for long-running experiments, hyperparameter sweeps, and production retraining pipelines where you expect multiple interruptions and have the infrastructure to handle them.

Why This Matters in Production

In a real-world ML team, spot instances aren't just a cost optimization - they're a forcing function for building robust infrastructure. When you design for interruption, you're forced to confront questions that plague distributed training: How do you coordinate state across nodes? What happens when the network partitions? How do you verify a checkpoint is valid before resuming?

These questions apply equally to on-demand training. By building for spots, you're building for resilience. A model that can survive interruptions can survive a kernel panic, a network blip, or a GPU memory corruption. The discipline pays dividends beyond cost savings.

Teams that practice spot-instance training also tend to have better observability. They track checkpoint frequency, recovery time, and interruption patterns. They can answer questions like "How much compute are we actually wasting to interruptions?" and "Which instance types are unreliable in our region?" This data drives smarter infrastructure decisions downstream.

Checkpoint Strategy: The Science of Frequency

Here's what most people get wrong: they checkpoint too often (killing throughput) or too infrequently (losing too much work per interruption). Finding the right frequency is part art, part science, but there's a framework that works well in practice.

Understanding the practical constraints of checkpoint management requires appreciating how the different components of your system interact. When you initiate a checkpoint, your training process must serialize the entire training state: all learnable parameters, batch normalization statistics, optimizer state such as Adam's moment estimates, and any other stateful components your architecture maintains. This serialization isn't instantaneous. On a modern GPU with NVMe storage, transferring 50GB of state to disk might take three to five minutes. During those minutes, your GPU is idle. You're paying for compute capacity that's sitting unused. This is the fundamental tension at the heart of checkpoint optimization: more frequent checkpoints mean better protection against interruption losses, but they also mean more total wasted compute time due to serialization overhead.

Consider what happens across a full training run when you vary checkpoint frequency. If you train for 100 hours and checkpoint every two hours, you have 50 checkpoints. Each checkpoint takes four minutes, so that's 200 minutes, or about 3.3 hours of idle GPU time. You've sacrificed 3.3 percent of your available compute to checkpointing. But if you get interrupted, you lose at most two hours of work. If you checkpoint every 30 minutes instead, you have 200 checkpoints, consuming 13.3 hours of compute to checkpointing. You've cut your worst-case loss to 30 minutes, but you're spending 13 percent of your compute time saving state. The math only works out if interruptions are frequent enough to justify the overhead. For the 5 percent hourly interruption rate typical of p3 instances, the every-two-hours strategy is better. For the 15 percent rate of newer p4d instances, the more-frequent strategy wins.
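The trade-off in the paragraph above can also be estimated analytically. A standard first-order rule, Young's approximation, picks the interval that roughly minimizes checkpoint overhead plus expected rework. Applied here with the article's numbers (the 4-minute checkpoint time and the 5%/15% hourly rates are the estimates from earlier):

```python
import math

def optimal_checkpoint_interval_hours(
    checkpoint_time_h: float,
    hourly_interruption_rate: float
) -> float:
    """
    Young's approximation: interval ~ sqrt(2 * C * MTBF), where C is the
    checkpoint cost and MTBF is the mean time between interruptions.
    """
    mtbf_h = 1.0 / hourly_interruption_rate
    return math.sqrt(2 * checkpoint_time_h * mtbf_h)

# 4-minute checkpoints on a p3 at ~5%/hour vs a p4d at ~15%/hour
print(optimal_checkpoint_interval_hours(4 / 60, 0.05))  # ~1.6 hours
print(optimal_checkpoint_interval_hours(4 / 60, 0.15))  # ~0.9 hours
```

The outputs line up with the intuition above: around two hours between checkpoints for the reliable instance type, closer to one hour for the flaky one.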

The additional complexity emerges when you consider distributed training. If you're training across eight GPUs, checkpoint synchronization becomes intricate. All eight GPUs must coordinate to write a consistent snapshot of the global model state. If one GPU finishes serialization while others are still writing, you've created a consistency problem. Some implementations use barriers - all GPUs wait for all others to complete - which adds latency on top of the serialization time itself.

The intuition is simple: every checkpoint costs time. You're saving model state, optimizer state, and random number generator state to disk. That's not free. The larger your model, the longer the checkpoint takes. Large language models can have checkpoints that take 5-30 minutes. During that time, your training loop is completely blocked. Your GPU sits idle, cooling fans whirring, but no gradient steps are happening. That's wasted compute.

On the flip side, if you checkpoint infrequently, say once per day, and you get interrupted just before the next checkpoint, you lose nearly a full day of compute. That's catastrophic. A week-long training run becomes two weeks. A month becomes two months. The cost savings from spot instances evaporate.

The magic happens in the middle. If you checkpoint frequently enough that any interruption loses at most 15 minutes of work, the worst-case scenario is bounded. Yes, if you get interrupted just before the next checkpoint, you lose 15 minutes. But your job has already trained for 100 hours, so losing 15 more minutes is a rounding error. You still come out ahead on spot economics.

The challenge is that "15 minutes of lost work" varies depending on your throughput. If you're doing 100 gradient steps per minute, 15 minutes is 1500 steps. If you're doing 1 step per minute (perhaps because you're working with very large models), 15 minutes is just 15 steps. For some workloads, that's enormous loss; for others, it's negligible.

The framework handles this by saying: checkpoint every N iterations where N is calculated to represent 15 minutes of typical throughput. This way, the worst-case loss is consistent regardless of how fast your training runs. You're encoding a time-based strategy (15 minutes worst-case) into an iteration-based checkpoint interval (N iterations).

The optimal frequency is: checkpoint every N steps such that the worst-case rework (an interruption just before the next checkpoint) caps at 15 minutes of lost training.

If your training runs 100 iterations per minute, that's 1,500 iterations per 15-minute window. Checkpoint every 1,500 iterations.

Here's the framework:

python
# Calculate checkpoint interval
def get_checkpoint_interval(
    throughput_iter_per_min: float,
    max_loss_minutes: float = 15.0
) -> int:
    """
    Returns: iterations between checkpoints

    Reasoning: in the worst case, an interruption happens just before the
    next checkpoint, and we lose (throughput * max_loss_minutes) iterations.
    """
    return int(throughput_iter_per_min * max_loss_minutes)
 
# Example: 100 iter/min throughput
interval = get_checkpoint_interval(100)  # Returns 1500 iterations

But there's a secondary constraint: storage cost. A 50GB checkpoint saved every 1,500 iterations over a 100-hour run adds up to roughly $0.50 in S3 if you keep only the last few (assuming $0.023 per GB-month and ~75GB average stored). That's noise compared to compute savings, but it matters at scale with very frequent checkpoints.

The practical strategy:

  1. Checkpoint every N steps (where N = 15 min of compute worst-case)
  2. Keep only the last 2-3 checkpoints (prune older ones as you go; S3 delete requests are free)
  3. Use async upload (never block training for I/O)
  4. Partition by experiment (separate S3 prefixes prevent accidental overwrites)
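The pruning in point 2 can be enforced on the S3 side too. Here's a sketch (bucket and prefix values are placeholders) that separates the "which keys to delete" decision - a pure function, easy to test - from the actual S3 calls:

```python
def keys_to_prune(keys: list, keep_last_n: int) -> list:
    """Sort checkpoint keys by their trailing YYYYMMDD_HHMMSS timestamp
    (the naming scheme used later in this article) and return all but
    the newest keep_last_n."""
    ordered = sorted(keys, key=lambda k: k.rsplit('_', 2)[-2:])
    return ordered[:-keep_last_n] if len(ordered) > keep_last_n else []

def prune_s3_checkpoints(bucket: str, prefix: str, keep_last_n: int = 3):
    """List one experiment's checkpoints and delete everything but the newest N."""
    import boto3
    s3 = boto3.client('s3')
    resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
    keys = [obj['Key'] for obj in resp.get('Contents', [])]
    for key in keys_to_prune(keys, keep_last_n):
        s3.delete_object(Bucket=bucket, Key=key)
```

Keeping the prefix per experiment (point 4) means a prune can never touch another run's checkpoints.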

Why Async Upload Is Non-Negotiable

When you save a 50GB checkpoint synchronously, your training loop blocks until the I/O completes - seconds on fast NVMe, minutes if you also wait for the S3 upload. That's dead compute time. The GPU sits idle. With distributed training across 8 nodes, that's 8 nodes × $X/hour all burning compute cost on nothing.

By uploading asynchronously, you save locally first (fast - bounded only by disk bandwidth), then queue the S3 upload to happen in the background. Your training loop barely pauses. The S3 upload might fail - the network drops, the bucket is misconfigured - but your checkpoint is safe on the local EBS volume. You can retry the upload later or fetch it manually if the instance is terminated.

This asymmetry - fast local save, slow async remote save - is the key to real-world checkpoint reliability.

Here's a minimal checkpoint manager:

python
import torch
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from datetime import datetime
 
class CheckpointManager:
    def __init__(
        self,
        checkpoint_dir: str,
        s3_bucket: str,
        s3_prefix: str,
        keep_last_n: int = 3
    ):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
 
        self.s3_bucket = s3_bucket
        self.s3_prefix = s3_prefix
        self.keep_last_n = keep_last_n
        self.local_checkpoints = []
        # Single-worker executor: uploads run in a background thread,
        # one at a time, without ever blocking the training loop.
        self._uploader = ThreadPoolExecutor(max_workers=1)
 
    def save_checkpoint(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        epoch: int,
        global_step: int,
        rng_state=None
    ):
        """Save checkpoint synchronously (fast local disk), upload async."""
        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        ckpt_name = f"ckpt_epoch{epoch}_step{global_step}_{timestamp}.pt"
        ckpt_path = self.checkpoint_dir / ckpt_name
 
        # Local save (fast, blocking)
        checkpoint = {
            'epoch': epoch,
            'global_step': global_step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'rng_state': rng_state if rng_state is not None else torch.get_rng_state(),
        }
        torch.save(checkpoint, ckpt_path)
 
        # Track locally
        self.local_checkpoints.append(ckpt_path)
 
        # Async upload to S3 in a background thread (don't block training)
        self._uploader.submit(self._upload_to_s3, ckpt_path, ckpt_name)
 
        # Clean up old local checkpoints
        self._cleanup_old_local(self.keep_last_n)
 
        return str(ckpt_path)
 
    def _upload_to_s3(self, local_path: Path, remote_name: str):
        """Runs in the background thread; failures never crash training."""
        try:
            import boto3
            from boto3.s3.transfer import TransferConfig
            s3 = boto3.client('s3')
            s3_key = f"{self.s3_prefix}/{remote_name}"
 
            # Use multipart upload for large files
            s3.upload_file(
                str(local_path),
                self.s3_bucket,
                s3_key,
                Config=TransferConfig(
                    multipart_threshold=25 * 1024 * 1024,  # 25MB
                    max_concurrency=8
                )
            )
            print(f"[CHECKPOINT] Uploaded to s3://{self.s3_bucket}/{s3_key}")
        except Exception as e:
            print(f"[CHECKPOINT] S3 upload failed: {e}. Checkpoint still on disk.")
 
    def _cleanup_old_local(self, keep_n: int):
        """Remove oldest local checkpoints, keep only N recent."""
        if len(self.local_checkpoints) > keep_n:
            to_delete = self.local_checkpoints[:-keep_n]
            for path in to_delete:
                try:
                    path.unlink()
                except FileNotFoundError:
                    pass
            self.local_checkpoints = self.local_checkpoints[-keep_n:]

Critical detail: The async upload means your checkpoint is on disk immediately (safe from instance loss), but S3 gets it in the background. If the instance dies before upload completes, you've still got a valid checkpoint locally - you just need to fetch it back from the instance's EBS volume if you want to preserve it. In practice, this is rare because interruption warnings give you ~120 seconds to react.

Catching the Interruption Signal: Beat AWS to the Punch

AWS gives you two minutes of notice before terminating a spot instance. GCP gives you 30 seconds for preemptible instances. That's your window to save state and exit gracefully.

Both clouds expose this via instance metadata:

AWS (IMDSv2):

python
import requests
 
def get_spot_interruption_warning():
    """
    AWS exposes spot termination notice at:
    http://169.254.169.254/latest/meta-data/spot/instance-action
 
    Returns: dict with action, time, etc., or None if no warning
    """
    token_url = "http://169.254.169.254/latest/api/token"
    metadata_url = "http://169.254.169.254/latest/meta-data/spot/instance-action"
 
    try:
        # Get IMDSv2 token (required for security)
        token_response = requests.put(
            token_url,
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=1
        )
        token = token_response.text
 
        # Fetch interruption notice
        response = requests.get(
            metadata_url,
            headers={"X-aws-ec2-metadata-token": token},
            timeout=1
        )
 
        if response.status_code == 200:
            return response.json()
        return None
    except requests.RequestException:
        return None

GCP (Preemption Signal):

python
import signal
 
should_exit = False  # polled by the main training loop
 
def handle_preemption_signal(signum, frame):
    """GCP sends SIGTERM ~30 seconds before preemption."""
    print("[INTERRUPTION] Preemption signal received. Triggering emergency checkpoint...")
    # Signal main training loop to save and exit
    global should_exit
    should_exit = True
 
signal.signal(signal.SIGTERM, handle_preemption_signal)

How Interruption Signals Work Across Clouds

Understanding the mechanics helps you design better recovery. On AWS, a well-known metadata endpoint starts returning HTTP 200 (instead of 404) once a spot interruption is scheduled. GCP sends a UNIX signal (SIGTERM). Both are fallible - networks can be partitioned, signals can be missed if the handler isn't registered correctly.

The key insight: you can't rely on a single mechanism. You need defense in depth. Monitor the metadata endpoint in a background thread. Register signal handlers. Implement a timeout - if training takes longer than expected on a known-interrupted instance type, assume it might terminate soon and save preemptively. The goal is to catch the majority of interruptions and fail gracefully when you can't.

Now the critical part: integrating this into your training loop. You need a background thread constantly polling for interruption signals, and when it detects one, it must communicate with the main training loop to stop and save.

Here's a production-grade approach:

python
import json
import os
import signal
import threading
import time
from typing import Callable
 
import requests
 
class InterruptionHandler:
    def __init__(self, checkpoint_callback: Callable, cloud: str = "aws"):
        self.checkpoint_callback = checkpoint_callback
        self.cloud = cloud
        self.should_exit = False
        self.monitor_thread = None
 
    def start(self):
        """Start background monitoring (plus a signal handler on GCP)."""
        if self.cloud == "gcp":
            # GCP delivers SIGTERM ~30s before preemption; no polling needed
            signal.signal(
                signal.SIGTERM,
                lambda signum, frame: self._handle_interruption({"source": "gcp"})
            )
        self.monitor_thread = threading.Thread(
            target=self._monitor_loop,
            daemon=True
        )
        self.monitor_thread.start()
 
    def _monitor_loop(self):
        """Periodically check for interruption signals."""
        check_interval = 5  # seconds
 
        while not self.should_exit:
            if self.cloud == "aws":
                warning = self._get_aws_interruption_warning()
                if warning:
                    self._handle_interruption(warning)
 
            time.sleep(check_interval)
 
    def _get_aws_interruption_warning(self):
        """Fetch AWS spot interruption notice (IMDSv2)."""
        try:
            token_response = requests.put(
                "http://169.254.169.254/latest/api/token",
                headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
                timeout=1
            )
            token = token_response.text
 
            response = requests.get(
                "http://169.254.169.254/latest/meta-data/spot/instance-action",
                headers={"X-aws-ec2-metadata-token": token},
                timeout=1
            )
 
            if response.status_code == 200:
                return json.loads(response.text)
        except requests.RequestException:
            pass
        return None
 
    def _handle_interruption(self, warning: dict):
        """Trigger emergency checkpoint and graceful shutdown."""
        print("[INTERRUPTION] Instance termination detected!")
        print(f"[INTERRUPTION] Warning: {warning}")
 
        # Callback to save state (blocks until the local save completes)
        self.checkpoint_callback()
 
        # Signal the training loop to exit after the checkpoint completes
        self.should_exit = True
 
        # Give the training loop time to notice and shut down gracefully
        time.sleep(10)
 
        # If still running, force exit. sys.exit() in a daemon thread only
        # kills that thread, so use os._exit (the instance is dying anyway)
        os._exit(0)

Usage in your training loop:

python
def main():
    model = MyModel()
    optimizer = torch.optim.AdamW(model.parameters())
    checkpoint_mgr = CheckpointManager(...)
 
    current_epoch = 0
    global_step = 0
 
    # Setup interruption handler (reads epoch/step from the enclosing scope)
    def emergency_checkpoint():
        print("[EMERGENCY CHECKPOINT] Saving state...")
        checkpoint_mgr.save_checkpoint(
            model, optimizer, current_epoch, global_step
        )
 
    handler = InterruptionHandler(emergency_checkpoint, cloud="aws")
    handler.start()
 
    # Training loop
    for epoch in range(num_epochs):
        current_epoch = epoch
        for batch_idx, (inputs, labels) in enumerate(dataloader):
            # Check if we've been interrupted
            if handler.should_exit:
                print("[TRAINING] Graceful exit triggered.")
                return
 
            # Normal training step
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
 
            global_step += 1
 
            # Periodic checkpoint (not emergency)
            if global_step % checkpoint_interval == 0:
                checkpoint_mgr.save_checkpoint(
                    model, optimizer, epoch, global_step
                )

The flow is simple: background thread notices termination warning → calls your callback to save → sets a flag → training loop checks flag and exits cleanly. By the time AWS actually kills the instance (120 seconds later), you're already safe.

Elastic Training: Dynamic Topology with torch.distributed.elastic

Here's where it gets sophisticated. When a spot instance dies mid-training in a distributed setup, the entire job crashes because the communication topology is broken. You can't just resume the checkpoint on the same nodes - new nodes might be assigned, and the old ones might be gone. Distributed training is inherently fragile when any single node can vanish without warning.

This fragility is the real challenge with distributed spot training. In single-node training, spot interruption is annoying but manageable. You get a warning, you save a checkpoint, you resume on a replacement node. The checkpoint has all the state you need. In distributed training, it's vastly more complex.

Imagine 8 nodes doing distributed training via NCCL (NVIDIA's collective communication library). Each node has a rank (0-7). They synchronize gradients at each step. Node 5 suddenly terminates. NCCL tries to communicate with all 8 nodes; node 5 doesn't respond. The collective fails. All nodes error out. Your entire training job crashes, even though nodes 0-4 and 6-7 are still healthy.

You can't just restart the job on the surviving nodes because the checkpoint was saved assuming rank 0-7 existed. If you resume with rank 0-6 (only 7 nodes), your distributed sampler thinks it needs to shard data across 7 nodes, but the checkpoint has sharded state for 8 nodes. The learning rate was calculated for a batch size of 32×8=256. Now it's 32×7=224. All these mismatches cause subtle bugs.

That's where torch.distributed.elastic comes in. It's a framework for distributed training that expects node failures and handles them gracefully. Instead of hardcoding 8 nodes upfront, you say "run with 2 to 4 nodes minimum, up to 8 nodes maximum." If a node dies, the training pauses. A rendezvous service (think of it as a coordination server) notices the node is missing. A replacement node comes online. The rendezvous waits until the new node checks in, the world reconvenes, and training resumes with the new topology.

The beauty of elastic training is that it makes distributed spot training practical. The overhead is real - you're pausing training to handle topology changes - but it's manageable. And for the cost savings of spot instances, it's worthwhile.

PyTorch's torch.distributed.elastic (invoked via torchrun) solves this with dynamic membership and rendezvous. When a node joins or leaves, the training loop pauses, the membership updates, and training resumes on the new topology.

The rendezvous mechanism is a distributed coordination service (PyTorch's built-in c10d store or an etcd backend, running wherever your scheduler - Kubernetes, ECS - places it) that:

  1. Tracks which nodes are currently alive
  2. Waits for min_nodes to be available before starting
  3. Pauses training when membership changes
  4. Reloads the checkpoint with the new world size
  5. Resumes training

Scaling Considerations with Elastic Training

When you enable elastic training, you're introducing new complexity: what happens to your learning rate when you lose a node? With distributed training, the effective batch size is batch_size × world_size. If you start with 4 nodes (batch 32, world 4, effective 128) and drop to 3 nodes, your effective batch size becomes 96. Smaller batches mean noisier gradient estimates, which can destabilize training.

The standard solution is learning rate scaling: lr_new = lr_old × (world_size_new / world_size_old). This keeps the effective learning rate constant relative to batch size. Empirically, this works well for most optimizers (Adam, AdamW, SGD with momentum).

Another consideration: data parallelism with a shrinking world size. Your distributed sampler needs to be aware of the new world size. If you're using DistributedSampler, it requires num_replicas and rank, which change when topology shifts. You need to reload the sampler state or at least skip to the right batch index in the new epoch.
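To make the sampler point concrete, here's a CPU-only sketch (toy tensor dataset, no GPUs or process group needed, since num_replicas and rank are passed explicitly) showing how per-rank shard sizes shift when the world shrinks:

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))

def rebuild_sampler(dataset, world_size: int, rank: int, epoch: int):
    """Recreate the sampler for the new topology; set_epoch keeps
    shuffling deterministic per epoch but different across epochs."""
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True
    )
    sampler.set_epoch(epoch)
    return sampler

# 8 replicas -> 7 replicas: each rank's (padded) shard grows
before = rebuild_sampler(dataset, world_size=8, rank=0, epoch=3)
after = rebuild_sampler(dataset, world_size=7, rank=0, epoch=3)
print(len(before), len(after))  # 13 per rank at 8 replicas, 15 at 7
```

After a topology change you rebuild the sampler exactly like this, with the new world size, before constructing the DataLoader for the resumed epoch.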

Here's how to set it up:

Training script (train_elastic.py):

python
import os
import torch
import torch.distributed as dist
 
def train(checkpoint_mgr):
    """Training function launched by torchrun on each worker."""
 
    # torchrun sets RANK / LOCAL_RANK / WORLD_SIZE in the environment
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
 
    # Initialize distributed group (rank/world size come from the env)
    dist.init_process_group(backend="nccl")
 
    model = MyModel().to(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank]
    )
    optimizer = torch.optim.AdamW(ddp_model.parameters())
 
    # Load checkpoint if resuming
    checkpoint_path = checkpoint_mgr.get_latest_checkpoint()
    if checkpoint_path:
        print(f"[ELASTIC] Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=f"cuda:{local_rank}")
        ddp_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_step = checkpoint['global_step']
    else:
        start_epoch = 0
        start_step = 0
 
    global_step = start_step
    checkpoint_interval = 1500  # from earlier calculation
 
    for epoch in range(start_epoch, num_epochs):
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=epoch  # fresh sampler per epoch, so shuffling still varies
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            sampler=sampler,
            batch_size=batch_size
        )
 
        for batch_idx, (inputs, labels) in enumerate(dataloader):
            # Skip already-processed batches, but only in the resume epoch
            if epoch == start_epoch and batch_idx < (start_step % len(dataloader)):
                continue
 
            outputs = ddp_model(inputs.to(local_rank))
            loss = criterion(outputs, labels.to(local_rank))
 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
 
            global_step += 1
 
            if global_step % checkpoint_interval == 0 and rank == 0:
                checkpoint_mgr.save_checkpoint(
                    ddp_model, optimizer, epoch, global_step
                )
 
    dist.destroy_process_group()
 
if __name__ == "__main__":
    # This script is launched by torchrun on every node
    checkpoint_mgr = CheckpointManager(...)
    train(checkpoint_mgr)

Launch script (launch.sh):

bash
#!/bin/bash
 
# torchrun handles node management, restarts workers automatically
torchrun \
  --nnodes=2:4 \
  --nproc_per_node=8 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=<your-rendezvous-endpoint> \
  --rdzv_id=training_run_001 \
  train_elastic.py

The key parameters:

  • --nnodes=2:4: minimum 2 nodes, maximum 4 nodes (elastic!)
  • --nproc_per_node=8: 8 GPUs per node
  • --rdzv_backend=c10d: use PyTorch's distributed store (works in Kubernetes, ECS, or with a static IP)
  • --rdzv_endpoint: where the rendezvous server is listening (or a Kubernetes service name)
  • --rdzv_id: unique ID for this training job (allows multiple concurrent runs)

When a spot instance is reclaimed:

  1. torchrun detects the worker process died
  2. The rendezvous mechanism notices the rank is missing
  3. Training pauses on remaining workers
  4. A new worker is launched on a replacement instance
  5. Rendezvous waits for the new worker to check in
  6. Training resumes on the new topology

The catch: The new topology might have a different world size. If you started with 4 nodes (32 GPUs) and one dies, you're now at 3 nodes (24 GPUs). Your DDP model and optimizer state need to be adjusted.

Here's a robust checkpoint loader:

python
def load_checkpoint_elastic(
    checkpoint_path: str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    old_world_size: int,
    new_world_size: int
) -> tuple:
    """Load checkpoint, handle world size change."""
 
    checkpoint = torch.load(checkpoint_path)
 
    # Model state dict is usually agnostic to world size
    # (DDP wraps the same model regardless of GPUs)
    model.load_state_dict(checkpoint['model_state_dict'])
 
    # Restore optimizer state first - it carries the checkpointed LR
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
 
    # Then adjust: if world size changed, scale learning rate proportionally
    if old_world_size != new_world_size:
        print(f"[ELASTIC] World size changed: {old_world_size} -> {new_world_size}")
 
        # Effective batch size scales with world size; adapt LR accordingly
        scale_factor = new_world_size / old_world_size
 
        for param_group in optimizer.param_groups:
            param_group['lr'] *= scale_factor
 
    return checkpoint['epoch'], checkpoint['global_step']

This is where most people stumble: they assume a checkpoint from 32 GPUs will work identically on 24 GPUs. It won't. The gradient accumulation dynamics are different. The solution is proportional learning rate scaling, which we've handled above.
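To make the arithmetic concrete, here's a small sketch. The per-GPU batch size of 64 and the base learning rate of 0.1 are assumed example numbers, not values from any particular run:

```python
def scaled_lr(base_lr: float, old_world_size: int, new_world_size: int) -> float:
    """Linear LR scaling: the effective batch size changes with the world
    size, so scale the learning rate by the same ratio (a common heuristic,
    not a guarantee of identical convergence)."""
    return base_lr * (new_world_size / old_world_size)

# Assumed numbers: per-GPU batch of 64, shrinking from 32 GPUs to 24
per_gpu_batch = 64
effective_before = 32 * per_gpu_batch   # 2048 samples per step
effective_after = 24 * per_gpu_batch    # 1536 samples per step
new_lr = scaled_lr(0.1, 32, 24)         # approximately 0.075
```

The same function handles scale-up: going from 24 back to 32 nodes multiplies the rate by 32/24.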

Fault-Tolerant Pipeline Storage: The Foundation

Your checkpoint is only as good as your ability to restore from it. The pipeline needs to handle partial restoration - what if the model state loads but the optimizer state is corrupted?

Design your checkpoint format for resilience:

python
import json
from datetime import datetime
 
def save_checkpoint_resilient(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    global_step: int,
    dataloader_state: dict,
    rng_state_cpu: torch.Tensor,
    rng_state_gpu: torch.Tensor,
    checkpoint_path: str
):
    """
    Save checkpoint with independent component saving.
    If one component corrupts, others are still valid.
    """
 
    # Save each component to a separate file
    components = {
        'model': {
            'path': checkpoint_path.replace('.pt', '_model.pt'),
            'data': model.state_dict()
        },
        'optimizer': {
            'path': checkpoint_path.replace('.pt', '_optimizer.pt'),
            'data': optimizer.state_dict()
        },
        'rng': {
            'path': checkpoint_path.replace('.pt', '_rng.pt'),
            'data': {
                'cpu': rng_state_cpu,
                'gpu': rng_state_gpu,
            }
        },
        'metadata': {
            'path': checkpoint_path.replace('.pt', '_metadata.json'),
            'data': {
                'epoch': epoch,
                'global_step': global_step,
                'dataloader_state': dataloader_state,
                'timestamp': datetime.utcnow().isoformat(),
            }
        }
    }
 
    for name, component in components.items():
        try:
            if name == 'metadata':
                with open(component['path'], 'w') as f:
                    json.dump(component['data'], f)
            else:
                torch.save(component['data'], component['path'])
            print(f"[CHECKPOINT] Saved {name}")
        except Exception as e:
            print(f"[CHECKPOINT] Failed to save {name}: {e}")
            # Don't fail the entire checkpoint; continue with others
 
def load_checkpoint_resilient(checkpoint_path: str, model, optimizer):
    """
    Load checkpoint components with graceful fallback.
    If model fails, don't load anything.
    If optimizer fails, continue with just the model.
    """
 
    epoch = 0
    global_step = 0
 
    # Load metadata first (lightweight)
    meta_path = checkpoint_path.replace('.pt', '_metadata.json')
    try:
        with open(meta_path, 'r') as f:
            metadata = json.load(f)
            epoch = metadata.get('epoch', 0)
            global_step = metadata.get('global_step', 0)
    except Exception as e:
        print(f"[CHECKPOINT] Warning: Could not load metadata: {e}")
 
    # Load model (critical)
    model_path = checkpoint_path.replace('.pt', '_model.pt')
    try:
        model_state = torch.load(model_path, map_location='cpu')
        model.load_state_dict(model_state)
        print(f"[CHECKPOINT] Loaded model state")
    except Exception as e:
        print(f"[CHECKPOINT] CRITICAL: Could not load model: {e}")
        raise
 
    # Load optimizer (optional; can retrain without perfect optimizer state)
    optimizer_path = checkpoint_path.replace('.pt', '_optimizer.pt')
    try:
        optimizer_state = torch.load(optimizer_path, map_location='cpu')
        optimizer.load_state_dict(optimizer_state)
        print(f"[CHECKPOINT] Loaded optimizer state")
    except Exception as e:
        print(f"[CHECKPOINT] Warning: Could not load optimizer: {e}")
        print(f"[CHECKPOINT] Continuing with fresh optimizer state. Training will resume but convergence may be slower.")
 
    # Load RNG state (for reproducibility)
    rng_path = checkpoint_path.replace('.pt', '_rng.pt')
    try:
        rng_state = torch.load(rng_path, map_location='cpu')
        torch.set_rng_state(rng_state['cpu'])
        torch.cuda.set_rng_state(rng_state['gpu'])
        print(f"[CHECKPOINT] Restored RNG state")
    except Exception as e:
        print(f"[CHECKPOINT] Warning: Could not load RNG state: {e}")
        print(f"[CHECKPOINT] RNG will be reinitialized. Batch order may differ from original run.")
 
    return epoch, global_step

Why separate files? YAML, JSON, and pickle can all corrupt. By separating components, a single corruption doesn't nuke the entire checkpoint. You can restore the model, skip the optimizer (retrain), and move on. This is exactly what happens in production when you're desperate to recover.

Common Pitfalls and How to Avoid Them

Training on spot instances introduces failure modes that don't exist with on-demand capacity. Understanding these pitfalls helps you design more robust systems.

1. Forgetting to handle world size changes: When nodes join/leave, your model and optimizer expect different gradient shapes. Always scale learning rate and validate shapes on resume.

2. Checkpoint frequency too high: Checkpointing every step kills throughput. Use the 15-minute worst-case formula. Every 1-5 minutes is typically right.

3. Blocking the training loop on I/O: Always upload asynchronously. Save to local disk first (fast), then push to S3 from a background thread.

4. Not testing the interruption path: Your handler only works if you test it. Manually terminate an instance during training and verify resumption works. This is non-negotiable.

5. Losing data loader state: When you resume, the dataloader doesn't know which batch it was on. You need to track the global step and skip already-processed batches (or use SeededDataLoader with a checkpoint of current position).

6. Assuming IMDSv2 is always available: Some VPCs have it disabled. Add try/except around metadata requests.
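For pitfall 5, a minimal way to fast-forward on resume is to skip the batches already consumed this epoch. This assumes the loader reproduces the same order as the original run (same seed, same DistributedSampler.set_epoch value); a checkpointable sampler is the more robust option:

```python
import itertools

def fast_forward(loader, batches_done: int):
    """Yield batches starting after the ones already processed this epoch.

    Only valid when the loader's iteration order is reproducible across
    restarts; otherwise you silently skip different samples.
    """
    return itertools.islice(iter(loader), batches_done, None)

# Usage sketch: resume mid-epoch after a restart
# for batch in fast_forward(train_loader, global_step % len(train_loader)):
#     ...
```

The cost is that the skipped batches are still loaded and decoded, so for very long epochs it's worth checkpointing the sampler position instead.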

Architecture Decisions and Patterns

As you build out spot-instance training at scale, architectural decisions compound. Here are the major ones:

Checkpoint placement: Local disk, shared EBS, S3? Local is fastest for small checkpoints (<10GB). Shared EBS works for multi-node with limited fault tolerance. S3 is mandatory for cloud-resilient training but slower. Hybrid (local + async S3) is the production standard.
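A sketch of the hybrid pattern: save locally (blocking but fast), then hand the file to a background thread for upload. The `upload_fn` parameter is a placeholder - with boto3 you would pass a wrapper around an S3 upload call - and `pickle` stands in for `torch.save` to keep the sketch self-contained:

```python
import pickle
import threading

def save_then_upload_async(state: dict, local_path: str, upload_fn) -> threading.Thread:
    """Blocking local save, non-blocking remote upload.

    upload_fn(local_path) is any callable that pushes the file to remote
    storage. Upload errors are logged, not raised, so a flaky network
    never crashes the training loop.
    """
    # Local save is fast; swap pickle for torch.save in a real harness
    with open(local_path, "wb") as f:
        pickle.dump(state, f)

    def _upload():
        try:
            upload_fn(local_path)
        except Exception as e:
            print(f"[CHECKPOINT] Background upload failed: {e}")

    t = threading.Thread(target=_upload, daemon=True)
    t.start()
    return t  # join() this before a graceful shutdown so the upload finishes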

State synchronization: With multiple nodes, who decides when to checkpoint? Single-node decision (rank 0 only) is simplest but creates bottlenecks. Distributed consensus (all nodes vote) is robust but adds complexity. The sweet spot: rank 0 checks interruption signals and broadcasts the decision to others.
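The rank-0-decides pattern comes down to a one-element broadcast. A sketch, assuming the process group is already initialized (the function name is illustrative):

```python
import torch
import torch.distributed as dist

def agree_on_checkpoint(local_decision: bool) -> bool:
    """Rank 0 decides whether to checkpoint; every other rank follows.

    Broadcasting a one-element tensor keeps all ranks in lockstep - if
    some ranks checkpoint on a step and others don't, subsequent
    collectives can deadlock.
    """
    flag = torch.tensor([1 if local_decision else 0], dtype=torch.int64)
    dist.broadcast(flag, src=0)  # all ranks now hold rank 0's value
    return bool(flag.item())
```

Non-zero ranks call this with a dummy value each checkpoint-eligible step; only rank 0's argument matters.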

Resumption logic: When resuming, should you start exactly where you left off, or skip a few steps for safety? Some systems skip 10-100 steps to avoid corruption from partial writes. This adds robustness at the cost of minimal compute overhead.

Putting It Together: The Complete Harness

Here's a minimal but production-ready training harness combining everything:

python
#!/usr/bin/env python3
"""
Fault-tolerant training harness for spot/preemptible instances.
Supports distributed training with dynamic topology.
 
Usage:
  Single-node:     python train_harness.py
  Multi-node:      torchrun --nnodes=2:4 --nproc_per_node=8 train_harness.py
"""
 
import os
import sys
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler
from datetime import datetime
import threading
import time
import json
from pathlib import Path
 
class CheckpointManager:
    def __init__(self, checkpoint_dir, s3_bucket=None, s3_prefix=None):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.s3_bucket = s3_bucket
        self.s3_prefix = s3_prefix
 
    def get_latest_checkpoint(self):
        """Find most recent checkpoint."""
        checkpoints = list(self.checkpoint_dir.glob("ckpt_*_metadata.json"))
        if not checkpoints:
            return None
        # Most recently modified
        return str(sorted(checkpoints, key=lambda p: p.stat().st_mtime)[-1]).replace('_metadata.json', '')
 
    def save_checkpoint(self, model, optimizer, epoch, global_step, rank=0):
        """Save checkpoint (main process only)."""
        if rank != 0:
            return
 
        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        ckpt_base = self.checkpoint_dir / f"ckpt_e{epoch}_s{global_step}_{timestamp}"
 
        # Model
        torch.save(model.state_dict(), str(ckpt_base) + "_model.pt")
 
        # Optimizer
        torch.save(optimizer.state_dict(), str(ckpt_base) + "_optimizer.pt")
 
        # RNG
        torch.save({
            'cpu': torch.get_rng_state(),
            'gpu': torch.cuda.get_rng_state() if torch.cuda.is_available() else None
        }, str(ckpt_base) + "_rng.pt")
 
        # Metadata
        with open(str(ckpt_base) + "_metadata.json", 'w') as f:
            json.dump({
                'epoch': epoch,
                'global_step': global_step,
                'timestamp': datetime.utcnow().isoformat()
            }, f)
 
        print(f"[CHECKPOINT] Saved at {ckpt_base}")
 
    def load_checkpoint(self, model, optimizer):
        """Load latest checkpoint."""
        ckpt_path = self.get_latest_checkpoint()
        if not ckpt_path:
            print("[CHECKPOINT] No existing checkpoint found. Starting fresh.")
            return 0, 0
 
        print(f"[CHECKPOINT] Loading {ckpt_path}")
 
        epoch, global_step = 0, 0
 
        try:
            with open(ckpt_path + "_metadata.json", 'r') as f:
                meta = json.load(f)
                epoch = meta['epoch']
                global_step = meta['global_step']
        except Exception:
            print("[CHECKPOINT] Could not load metadata.")
 
        try:
            model.load_state_dict(torch.load(ckpt_path + "_model.pt", map_location='cpu'))
        except Exception as e:
            print(f"[CHECKPOINT] FATAL: Could not load model: {e}")
            raise
 
        try:
            optimizer.load_state_dict(torch.load(ckpt_path + "_optimizer.pt", map_location='cpu'))
        except Exception as e:
            print(f"[CHECKPOINT] WARNING: Could not load optimizer: {e}")
 
        return epoch, global_step
 
class InterruptionHandler:
    def __init__(self, checkpoint_fn):
        self.checkpoint_fn = checkpoint_fn
        self.should_exit = False
 
    def start(self):
        """Start background interruption monitor."""
        thread = threading.Thread(target=self._monitor, daemon=True)
        thread.start()
 
    def _monitor(self):
        """Check for AWS spot interruption warning every 5 seconds."""
        while not self.should_exit:
            try:
                import requests  # cached after first import; ImportError is swallowed below
                token = requests.put(
                    "http://169.254.169.254/latest/api/token",
                    headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
                    timeout=1
                ).text
                response = requests.get(
                    "http://169.254.169.254/latest/meta-data/spot/instance-action",
                    headers={"X-aws-ec2-metadata-token": token},
                    timeout=1
                )
                if response.status_code == 200:
                    print("[INTERRUPTION] Spot termination warning received!")
                    self.checkpoint_fn()
                    # sys.exit() in a daemon thread only raises SystemExit in
                    # this thread; set the flag and let the training loop exit.
                    self.should_exit = True
                    return
            except Exception:
                # Metadata endpoint unreachable (e.g. IMDSv2 disabled) - keep polling
                pass
            time.sleep(5)
 
def train_epoch(model, train_loader, optimizer, epoch, checkpoint_mgr, handler, checkpoint_interval=1500):
    """Single training epoch."""
    global_step = 0
 
    for batch_idx, (data, target) in enumerate(train_loader):
        if handler.should_exit:
            return global_step, True  # Signal to exit
 
        data = data.to("cuda")
        target = target.to("cuda")
 
        optimizer.zero_grad()
        output = model(data)
        loss = nn.functional.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
 
        global_step += 1
 
        if global_step % checkpoint_interval == 0:
            checkpoint_mgr.save_checkpoint(model, optimizer, epoch, global_step, rank=torch.distributed.get_rank() if dist.is_initialized() else 0)
 
        if batch_idx % 100 == 0:
            print(f"[TRAIN] Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
 
    return global_step, False
 
def main():
    # Initialize distributed if running under torchrun
    if "RANK" in os.environ:
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
 
    if rank == 0:
        print(f"[MAIN] Rank {rank}/{world_size}")
 
    # Setup
    # torchrun sets LOCAL_RANK; fall back to the global rank for single-node runs
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank % torch.cuda.device_count())
 
    model = nn.Linear(10, 5)  # Dummy model
    model = model.to("cuda")
    if world_size > 1:
        model = nn.parallel.DistributedDataParallel(model)
 
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    checkpoint_mgr = CheckpointManager("./checkpoints")
 
    # Load checkpoint if available
    start_epoch, _ = checkpoint_mgr.load_checkpoint(model, optimizer)
 
    # Setup interruption handler
    def emergency_save():
        if rank == 0:
            print("[EMERGENCY] Saving checkpoint...")
            checkpoint_mgr.save_checkpoint(model, optimizer, start_epoch, 0, rank)
 
    handler = InterruptionHandler(emergency_save)
    handler.start()
 
    # Dummy training loop
    for epoch in range(start_epoch, 10):
        if rank == 0:
            print(f"[TRAIN] Starting epoch {epoch}")
 
        # Create dataset and loader (dummy data so the harness is runnable;
        # a TensorDataset yields the (data, target) tuples train_epoch expects)
        dataset = torch.utils.data.TensorDataset(
            torch.randn(1000, 10), torch.randint(0, 5, (1000,))
        )
        if world_size > 1:
            sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
            sampler.set_epoch(epoch)  # reshuffle differently each epoch
            loader = DataLoader(dataset, batch_size=32, sampler=sampler)
        else:
            loader = DataLoader(dataset, batch_size=32, shuffle=True)
 
        global_step, should_exit = train_epoch(model, loader, optimizer, epoch, checkpoint_mgr, handler)
 
        if should_exit:
            break
 
        if rank == 0:
            checkpoint_mgr.save_checkpoint(model, optimizer, epoch, global_step, rank)
 
    if world_size > 1:
        dist.destroy_process_group()
 
if __name__ == "__main__":
    main()

Diagrams: The Architecture

Figure 1: Training Checkpoint and Recovery Flow

mermaid
graph TD
    A["Normal Training"] -->|Every N steps| B["Checkpoint Save"]
    B -->|Async upload| C["S3 Storage"]
    B -->|Local disk| D["Instance EBS"]
 
    A -->|Background thread| E["Interruption Monitor"]
    E -->|AWS/GCP signal| F["Interruption Detected"]
 
    F -->|2 min warning| G["Emergency Checkpoint"]
    G -->|Async upload| C
    G -->|Immediate save| D
 
    F -->|After timeout| H["Instance Terminated"]
    H -->|New instance launches| I["Spot Replacement"]
 
    I -->|torchrun detects death| J["Rendezvous Updates"]
    J -->|Load checkpoint| K["Restore Model State"]
    K -->|Resume training| A
 
    C -->|If EBS lost| L["Fetch from S3"]
    L -->|Restore to new instance| K

Figure 2: Elastic Training Topology with Dynamic Nodes

mermaid
graph TB
    subgraph Initial["Initial: 4 Nodes (32 GPUs)"]
        N1["Node 1: 8 GPUs"]
        N2["Node 2: 8 GPUs"]
        N3["Node 3: 8 GPUs"]
        N4["Node 4: 8 GPUs"]
    end
 
    subgraph RDZ["Rendezvous Coordinator"]
        RDZ1["Tracks: 4 nodes<br/>World size: 4"]
    end
 
    Initial --> RDZ
 
    RDZ -->|Broadcasting gradients| Sync1["Sync Step"]
 
    Sync1 -->|Node 3 interrupted| Interrupt["Interruption"]
    Interrupt -->|Terminate N3| Remove["Remove from Rank"]
 
    Remove --> RDZ2["Rendezvous Updated<br/>3 nodes"]
 
    RDZ2 -->|New instance| Add["Node 5: 8 GPUs added"]
    Add -->|Wait for rendezvous| RDZ3["Rendezvous Updated<br/>4 nodes (different set)"]
 
    RDZ3 -->|Load checkpoint<br/>Scale LR| Restore["Restore Training State"]
    Restore -->|Resume with new topology| Resume["Resume Gradient Sync"]

The Hidden Costs of Spot Infrastructure

While the compute savings from spot instances are real and substantial, the infrastructure costs warrant honest accounting. Building reliable checkpoint systems requires careful design choices that impose their own overhead. Asynchronous uploads to S3 reduce training blocking, but they introduce complexity around buffer management. You need to ensure that checkpoints complete before the instance receives a termination signal. If a checkpoint is still uploading when interruption happens, that data might be lost or corrupted.

The monitoring layer adds another cost dimension. Every training script needs to watch for interruption signals (the AWS 2-minute warning, the GCP 30-second warning, the generic ACPI shutdown from other providers). This monitoring has to be rock solid because a missed signal means unrecoverable loss. Some teams implement multiple layers of monitoring: one at the application level watching AWS EC2 metadata, another at the system level watching ACPI events, a third via the orchestration platform watching instance lifecycle. This redundancy adds operational complexity but prevents the catastrophic case of missing an interruption signal entirely.

Recovery verification is another hidden cost. When you restore from a checkpoint, you need to verify that the checkpoint is actually valid. If a previous upload was corrupted, you don't discover it until you're already resuming training. The recovery takes 30 minutes, you run for an hour, then the training crashes with a NaN loss, and you realize the checkpoint was bad. Now you've wasted 90 minutes of GPU time. The solution is checksumming all saved data and verifying checksums on load. This adds I/O overhead but saves you from subtle data corruption issues.
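A minimal checksum scheme (illustrative; the `.sha256` sidecar naming is an assumption): write a SHA-256 digest next to each checkpoint file at save time, and refuse to load any file that doesn't match its recorded digest.

```python
import hashlib
from pathlib import Path

def write_checksum(path: str) -> str:
    """Compute SHA-256 of a saved file and store it alongside as <path>.sha256."""
    digest = hashlib.sha256(Path(path).read_bytes()).hexdigest()
    Path(path + ".sha256").write_text(digest)
    return digest

def verify_checksum(path: str) -> bool:
    """Return True only if the file still matches its recorded digest."""
    recorded = Path(path + ".sha256").read_text().strip()
    actual = hashlib.sha256(Path(path).read_bytes()).hexdigest()
    return recorded == actual
```

For multi-gigabyte checkpoints, hash in chunks (`hashlib`'s `update()` on fixed-size reads) rather than `read_bytes()`, and verify after the S3 upload as well as before the load, so corruption is caught at write time instead of resume time.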

The real kicker is that these costs are invisible until something fails. Your spot training works perfectly for three months. Then one day you get a corrupted checkpoint, or an interruption signal is missed, or an instance is reclaimed before the graceful shutdown triggers. These failures are rare in aggregate but devastating when they occur. The infrastructure needs to be defensive, and that defensiveness has costs.

Summary

Spot instances are the closest thing ML infrastructure has to a free lunch - for teams prepared to catch the interruptions. The key principles:

  • Checkpoint frequently enough that a worst-case interruption costs at most ~15 minutes of training
  • Always upload asynchronously (fast local save, S3 in the background)
  • Monitor for interruption signals (AWS 2 min, GCP 30 sec)
  • Use elastic training (torchrun + rendezvous for topology changes)
  • Handle world size changes (scale learning rate proportionally)
  • Save components separately (graceful fallback on corruption)

Build this once, drop it into any PyTorch training script, and you've cut your compute bill by 60-90%. The complexity is real, but so is the payoff.

What questions are you hitting in your own spot training? The production edge cases are always where the story gets interesting.

