July 25, 2025
AI/ML Infrastructure Inference Batching Model Serving

Request Batching Strategies for ML Inference

You've deployed a large language model to production. Requests arrive unpredictably - sometimes one at a time, sometimes in bursts. Each request, served alone, wastes GPU capacity. Your inference server silently bleeds money through idle compute cycles. The problem isn't your model; it's how you're scheduling work.

Request batching isn't optional anymore. It's the difference between serving 10 requests/second and 100, and between p99 latency measured in seconds and in hundreds of milliseconds. Modern inference frameworks like vLLM, SGLang, and TensorRT-LLM are built entirely around batching strategies - but not all batching works the same way.

We'll explore three fundamental approaches: static batching (the naive option), dynamic batching (the traditional compromise), and continuous batching (the GPU-maximizing modern standard). You'll learn to measure which one matters for your workload, calculate optimal batch sizes for your hardware, and implement a production-grade dynamic batcher in Python.


Table of Contents
  1. Why Batching Matters: The GPU Utilization Problem
  2. Static Batching: Simple but Wasteful
  3. Dynamic Batching: The Timeout Compromise
  4. Continuous Batching: Iteration-Level Scheduling
  5. Optimal Batch Size: The Math
  6. Memory Constraint
  7. Latency vs. Throughput Trade-off
  8. Padding: The Hidden Computational Tax
  9. Measuring Batching Efficiency with torch.profiler
  10. Graceful Overflow: What Happens When Queue Explodes
  11. Production Checklist: Static → Dynamic → Continuous
  12. Batching in the Real World: Trade-offs and Gotchas
  13. Variable Sequence Lengths
  14. Prefill vs. Decode Asymmetry
  15. Benchmarking Real Systems: What You'll Actually See
  16. Measurement Methodology
  17. Request Cancellation and Timeout Handling
  18. Conclusion
  19. Batching in Multi-Model Systems: Orchestration Complexity
  20. Heterogeneous Request Characteristics: Handling Diversity
  21. Memory Efficiency: Beyond Batch Size to Token Budgets
  22. Batching and Queueing Theory: Optimal Wait Times
  23. Batching Infrastructure as Part of ML Ops
  24. Advanced Topics: Batching with Dynamic Model Switching
  25. Batching in Serverless and Function-as-a-Service Environments
  26. Conclusion Revisited: Sustainable Optimization
  27. Observability: Knowing What Your Batching Is Actually Doing
  28. Debugging Batching Problems: A Pragmatic Approach
  29. Sources

Why Batching Matters: The GPU Utilization Problem

Modern GPUs are massively parallel machines. An A100 contains 6,912 CUDA cores. A single inference request barely touches that capacity.

Consider Llama-2-7B generating a token:

  • Single request: You're using roughly 3-5% of available GPU compute
  • Batch of 32: You're using roughly 70-85% of GPU compute
  • Batch of 128: You're using 90%+ of GPU compute

The math is brutal. Each additional request in your batch leverages matrix multiplication parallelism that already exists. The marginal cost per request drops steeply: a request that takes 50ms alone might add only 2-3ms to a batch.

But there's a catch: batching introduces latency. A request arriving while 127 others are queued doesn't start immediately. You wait for the batch to fill or a timeout to fire. This is the fundamental trade-off: throughput vs. latency.

Let's compare how static, dynamic, and continuous batching navigate this tension.

The GPU utilization problem is fundamentally an economics problem. Your GPU costs money whether you're using 5% of it or 100% of it. Leaving compute on the table is wasting customers' money (in cloud scenarios) or wasting capital investment (in on-premise scenarios). A 3x improvement in GPU utilization translates directly to serving 3x more requests with the same hardware, or serving the same requests with 1/3 the hardware cost. At scale, this is millions of dollars annually. The drive to improve batching is relentless because the payoff is so tangible.

The challenge is that improving throughput by batching requests hurts latency for individual requests. A user expects their inference request to complete in 50ms. If you make them wait in a queue for 100ms while other requests arrive so you can batch 128 together, technically you're improving total throughput (serving more requests per second), but you've ruined the user experience for that specific request. This is the core tension that makes batching a nuanced problem requiring careful thought about your actual workload and constraints.
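The tension above can be made concrete with a toy cost model. The numbers here are illustrative assumptions, not measurements: a lone request takes 50ms, each extra request in a batch adds ~2.5ms, and requests may wait up to a 100ms batching window before their batch launches.

```python
def batch_time_ms(batch_size: int) -> float:
    """Assumed wall-clock time for one batch (ms): strong amortization."""
    return 50.0 + 2.5 * (batch_size - 1)

def throughput_rps(batch_size: int) -> float:
    """Requests per second if batches run back to back."""
    return batch_size / (batch_time_ms(batch_size) / 1000.0)

def worst_case_latency_ms(batch_size: int, wait_window_ms: float) -> float:
    """Queue wait (for the batch to fill) plus batch execution time."""
    return wait_window_ms + batch_time_ms(batch_size)

for b in (1, 32, 128):
    print(f"batch={b:3d}  throughput={throughput_rps(b):6.0f} req/s  "
          f"worst-case latency={worst_case_latency_ms(b, 100.0):5.1f} ms")
```

In this model, throughput rises roughly 17x from batch 1 to batch 128 while worst-case latency roughly triples - exactly the trade-off the rest of this post navigates.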


Static Batching: Simple but Wasteful

Static batching waits for exactly N requests to arrive, then processes them together. After the batch finishes, it waits for the next N requests.

python
# Static batching conceptually
def static_batcher(queue, batch_size=32):
    while True:
        # Block until we have exactly batch_size requests
        batch = [queue.get() for _ in range(batch_size)]
 
        # Stack their prompts, pad to max length
        padded_inputs = pad_batch(batch)
 
        # One forward pass
        outputs = model(padded_inputs)
 
        # Return results
        for request, output in zip(batch, outputs):
            request.future.set_result(output)

Why static batching fails in production:

  • Incomplete batches: If only 28 requests arrive, you wait indefinitely or timeout
  • Padding waste: The shortest sequence pads to the longest. With 32 sequences of lengths 64–2048 tokens, you compute on 2048×32 = 65,536 token positions, even though the actual content is maybe 40% of that
  • Queueing latency: The first request to arrive waits for batch_size - 1 more requests before anything runs
  • Bursty traffic: Bursts overflow the batch and queue behind it; lulls leave the GPU idle waiting for the batch to fill

Throughput vs. latency: At batch size 32 with 128 concurrent requests, you achieve high throughput (100+ req/sec), but p99 latency is brutal (seconds). Individual requests get buried in the queue.


Dynamic Batching: The Timeout Compromise

Dynamic batching accumulates requests over a time window (e.g., 50ms) or until max_batch_size is reached, whichever comes first. It's the solution most traditional inference servers (NVIDIA Triton, TensorFlow Serving) implemented.

python
import asyncio
import torch
from dataclasses import dataclass
from typing import List

# Assumes `tokenize` and `pad_sequence` helpers are defined elsewhere
 
@dataclass
class InferenceRequest:
    prompt: str
    future: asyncio.Future
    max_new_tokens: int = 256
 
class DynamicBatcher:
    def __init__(self, model, max_batch_size=32, max_wait_ms=50):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms / 1000.0  # Convert to seconds
        self.queue = asyncio.Queue()
        self.batch_event = asyncio.Event()
 
    async def add_request(self, request: InferenceRequest):
        """Queue an inference request."""
        await self.queue.put(request)
        self.batch_event.set()
 
    async def batching_loop(self):
        """Accumulate requests and process batches."""
        while True:
            batch = []
            deadline = asyncio.get_event_loop().time() + self.max_wait_ms
 
            # Accumulate requests until timeout or batch full
            while len(batch) < self.max_batch_size:
                timeout = max(0, deadline - asyncio.get_event_loop().time())
                try:
                    request = await asyncio.wait_for(
                        self.queue.get(),
                        timeout=timeout
                    )
                    batch.append(request)
                except asyncio.TimeoutError:
                    break
 
            if not batch:
                # No requests—wait for next arrival
                self.batch_event.clear()
                await self.batch_event.wait()
                continue
 
            # Process batch
            await self._process_batch(batch)
 
    async def _process_batch(self, batch: List[InferenceRequest]):
        """Execute inference on a batch."""
        # Pad sequences to longest in batch
        prompts = [req.prompt for req in batch]
        max_length = max(len(tokenize(p)) for p in prompts)
 
        padded_inputs = [
            pad_sequence(tokenize(p), max_length) for p in prompts
        ]
 
        # Forward pass
        inputs = torch.stack(padded_inputs)  # Shape: [batch, max_length]
        with torch.no_grad():
            outputs = self.model.generate(inputs)
 
        # Return results
        for request, output in zip(batch, outputs):
            request.future.set_result(output.tolist())

Improvements over static:

  • Timeout mechanism: If 10 requests arrive in 50ms, fire immediately. If 5 arrive and none follow, fire anyway after 50ms
  • Variable batch sizes: Batches aren't fixed - they adapt to load
  • Better latency distribution: Low-traffic periods see fast responses (no waiting for batch to fill)

Why dynamic batching still isn't optimal:

  1. Padding overhead: With sequences of lengths [64, 512, 2048], all pad to 2048. Computational waste: (2048-64) + (2048-512) = 3,520 unnecessary token positions

  2. Request blocking: Short sequences in a batch wait for the longest to finish generation. You process 2048 tokens when some requests only needed 128

  3. Prefill/decode mismatch: Modern LLMs have two phases:

    • Prefill: Process entire prompt (high parallelism, memory-intensive)
    • Decode: Generate one token at a time (low parallelism, KV-cache bound)

    Dynamic batching treats both the same - suboptimal for both

For high-concurrency workloads (>64 concurrent requests), dynamic batching leaves 20-30% of GPU capacity unused.
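The padding arithmetic in point 1 can be verified in a few lines:

```python
# Sequence lengths from the example above, padded to the longest
lengths = [64, 512, 2048]
pad_to = max(lengths)
wasted = sum(pad_to - n for n in lengths)
print(f"{wasted} token positions computed on padding")  # 3520
```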


Continuous Batching: Iteration-Level Scheduling

Continuous batching reshuffles the batch at every decode step (not every request arrival). When a sequence finishes, its spot immediately fills with a new sequence. This is the innovation that powers vLLM, SGLang, and modern production systems.

The key insight: don't wait for requests. Don't wait for batches to fill. Instead, at each token generation iteration, pack the GPU as full as possible.

Continuous batching represents a paradigm shift in how we think about inference scheduling. Traditional batch scheduling asks: "How many requests do we need before we run?" Continuous batching asks: "How do we maximize GPU utilization across all requests simultaneously?" The answer is ruthlessly simple: don't think about requests. Think about sequences. Think about tokens.

Here's the mental model: instead of waiting for 32 requests to arrive (static batching) or accumulating requests for 50ms (dynamic batching), you maintain a pool of in-flight sequences. At each time step, you generate one token for every sequence in the pool. As sequences complete, their slots fill with new sequences from the queue. The batch composition changes at every time step, but the GPU stays full. This is why vLLM, which pioneered this approach, can serve 3-5x more requests than traditional systems on the same hardware.

The breakthrough is understanding that you don't need to wait for request arrival to keep the GPU busy. As long as you have any sequences in flight, you can keep generating tokens. New requests join the pool, get their prompts "prefilled" (processed in one batch), then start generating tokens. While they're generating, older requests finish and leave, freeing up space. It's a continuous flow, not a batch-and-wait cycle.

This works because of the asymmetry between prefill and decode phases. When you process a new prompt, you do a lot of computation all at once (prefill). When you generate subsequent tokens, each request needs just a tiny bit of computation (one matrix multiply per token per sequence). You can do a quick prefill of a new request, then have 50-100 other sequences in active generation. The GPU never sits idle.

python
import asyncio
import torch
from typing import Dict, List

# Assumes a `tokenize` helper is defined elsewhere
 
class ContinuousBatcher:
    """
    Token-level batching: sequences enter/leave the running batch
    independently at each decode step.
    """
    def __init__(self, model, max_batch_tokens=8192):
        self.model = model
        self.max_batch_tokens = max_batch_tokens
        self.queue = asyncio.Queue()
 
        # Track in-flight sequences
        self.running_sequences: Dict[str, Sequence] = {}
        self.completed: Dict[str, List[str]] = {}
 
    async def add_request(self, request_id: str, prompt: str, max_tokens: int):
        """Enqueue a new request."""
        await self.queue.put({
            'request_id': request_id,
            'prompt': prompt,
            'max_tokens': max_tokens
        })
 
    async def decode_loop(self):
        """Continuously generate tokens, reshuffling batch each iteration."""
        while True:
            # Step 1: Prefill new requests
            await self._prefill_new_requests()
 
            # Step 2: Check if anything to decode
            if not self.running_sequences:
                await asyncio.sleep(0.001)
                continue
 
            # Step 3: Decode one token for active sequences
            self._decode_one_step()
 
            # Step 4: Remove completed sequences
            self._cleanup_completed()
 
    async def _prefill_new_requests(self):
        """Add new requests to running pool if space available."""
        while not self.queue.empty():
            try:
                req = self.queue.get_nowait()
                request_id = req['request_id']
 
                # Check if batch has space
                current_tokens = sum(s.seq_len for s in self.running_sequences.values())
                prompt_tokens = len(tokenize(req['prompt']))
 
                if current_tokens + prompt_tokens <= self.max_batch_tokens:
                    # Prefill this prompt
                    seq = Sequence(
                        request_id=request_id,
                        tokens=tokenize(req['prompt']),
                        max_new_tokens=req['max_tokens']
                    )
                    with torch.no_grad():
                        seq.kv_cache = self.model.prefill(seq.tokens)
 
                    self.running_sequences[request_id] = seq
                else:
                    # Put it back—batch is full
                    await self.queue.put(req)
                    break
            except asyncio.QueueEmpty:
                break
 
    def _decode_one_step(self):
        """Generate one token for all active sequences."""
        # Build ragged batch: just indices, no padding
        batch_ids = list(self.running_sequences.keys())
 
        # Collect KV caches
        kv_caches = [self.running_sequences[sid].kv_cache for sid in batch_ids]
 
        # Get current token IDs (last generated)
        last_tokens = torch.tensor(
            [self.running_sequences[sid].tokens[-1] for sid in batch_ids]
        )
 
        # Decode: no padding, no masking
        with torch.no_grad():
            next_tokens, new_kv_caches = self.model.decode_step(
                last_tokens,
                kv_caches
            )
 
        # Update sequences
        for i, request_id in enumerate(batch_ids):
            self.running_sequences[request_id].tokens.append(next_tokens[i].item())
            self.running_sequences[request_id].kv_cache = new_kv_caches[i]
            self.running_sequences[request_id].generated_tokens += 1
 
    def _cleanup_completed(self):
        """Move finished sequences out of running pool."""
        for request_id in list(self.running_sequences.keys()):
            seq = self.running_sequences[request_id]
            if seq.generated_tokens >= seq.max_new_tokens or seq.is_eos():
                self.completed[request_id] = seq.tokens
                del self.running_sequences[request_id]
 
class Sequence:
    def __init__(self, request_id: str, tokens: list, max_new_tokens: int):
        self.request_id = request_id
        self.tokens = tokens  # Prompt + generated
        self.max_new_tokens = max_new_tokens
        self.generated_tokens = 0
        self.kv_cache = None
        self.seq_len = len(tokens)
 
    def is_eos(self) -> bool:
        return self.tokens[-1] == 2  # EOS token ID

Why continuous batching dominates:

  1. No padding within batches: Sequences are managed individually. Each stores its KV cache separately
  2. GPU always full: As soon as a sequence finishes, its slot fills with a new one. Zero idle time
  3. Throughput ~3-5x higher than dynamic batching for high-concurrency (>32 concurrent requests)
  4. Flexible latency: New requests start prefilling immediately; they don't wait for a batch deadline
  5. Token-level scheduling: You're not waiting for the longest sequence. Each sequence progresses at its own rate

Real-world results (vLLM on Llama-2-7B):

  • Static batching, batch=32: ~30 requests/sec, p99 latency 8-10 seconds
  • Dynamic batching, max_wait=50ms: ~60 requests/sec, p99 latency 2-3 seconds
  • Continuous batching: ~150-200 requests/sec, p99 latency <500ms
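The slot-refill behavior is easy to see in a toy discrete-time simulation. The setup is an illustrative assumption, not a real scheduler: 4 slots, each request needs a random 1-20 decode steps, and a freed slot refills from the queue on the very next step.

```python
import random

def slot_utilization(num_requests: int = 100, slots: int = 4, seed: int = 0) -> float:
    """Fraction of slot-steps doing useful work when freed slots refill immediately."""
    rng = random.Random(seed)
    pending = [rng.randint(1, 20) for _ in range(num_requests)]  # decode steps per request
    active = []
    used = total = 0
    while pending or active:
        # Refill freed slots from the queue before each decode step
        while pending and len(active) < slots:
            active.append(pending.pop())
        # One decode step: every active sequence generates one token
        active = [steps - 1 for steps in active]
        used += len(active)
        total += slots
        active = [s for s in active if s > 0]
    return used / total

print(f"Slot utilization: {slot_utilization():.0%}")
```

Only the final drain, when the queue empties, leaves slots idle; with static batching the same workload would also stall whenever a batch waited on its slowest sequence.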

Optimal Batch Size: The Math

Batch size isn't arbitrary. It depends on three factors: model size, GPU memory, and latency requirements.

Finding your optimal batch size is where theory meets practice. The theory says: larger batches use GPU memory more efficiently, so bigger is better. Practice says: larger batches increase latency because more requests are queued, waiting for the batch to process. You're optimizing for a moving target - the batch size that maximizes throughput while keeping latency acceptable.

The challenge is that "optimal" depends on your workload. If you're serving a real-time chatbot, you optimize for low latency (p99 < 200ms). If you're serving an offline batch scoring system, you optimize for throughput (max requests/sec). These are fundamentally different objectives. The batch size that hits 99% GPU utilization might make your real-time users miserable with 5-second latencies. The batch size that keeps real-time latencies acceptable might leave your GPU at 20% utilization.

The professional approach is measurement. You run your workload under different batch sizes, profile both throughput and latency, and plot the curve. Usually, you find an inflection point - batch sizes below which latency explodes exponentially, and batch sizes above which latency is already terrible. Your sweet spot is usually just below that inflection point. You're trading a little bit of potential throughput for acceptable latency. This isn't a mathematical formula - it's an empirical decision based on your actual hardware, model, and constraints.
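A sketch of that empirical procedure: given measured batch-size → p99 latency pairs (the numbers below are invented for illustration, not benchmarks), pick the largest batch size still under your latency budget.

```python
def pick_sweet_spot(p99_by_batch, latency_budget_ms):
    """Largest measured batch size whose p99 latency fits the budget."""
    acceptable = [b for b, p99 in p99_by_batch.items() if p99 <= latency_budget_ms]
    return max(acceptable) if acceptable else min(p99_by_batch)

# Hypothetical profiling results: latency explodes past the inflection point
measured_p99_ms = {1: 60, 8: 85, 16: 120, 32: 190, 64: 450, 128: 1800}
print(pick_sweet_spot(measured_p99_ms, latency_budget_ms=200))  # 32
```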

Memory Constraint

python
# Rough memory model for a decoding batch:
#   GPU memory ≈ model weights + ~30% activation overhead + KV cache
#   KV cache per sequence ≈ 2 (K and V) × seq_len × hidden_dim × num_layers × bytes
 
def estimate_batch_size(
    model_params_billion: float,
    gpu_memory_gb: float,
    avg_seq_length: int,
    hidden_dim: int = 4096,
    num_layers: int = 32,
    precision_bytes: int = 2  # float16
) -> int:
    """
    Conservative estimate of max batch size for inference.
    Formula: GPU_memory ≈ weights + overhead + batch × KV_cache_per_sequence
    """
    model_bytes = model_params_billion * 1e9 * precision_bytes
    overhead = model_bytes * 0.3  # ~30% overhead for activations
 
    available = gpu_memory_gb * 1e9 - model_bytes - overhead
    kv_cache_per_seq = 2 * avg_seq_length * hidden_dim * num_layers * precision_bytes
 
    max_batch = int(available / kv_cache_per_seq)
    return max(1, max_batch)
 
# Example: Llama-2-7B on A100 (40GB)
batch_size = estimate_batch_size(
    model_params_billion=7,
    gpu_memory_gb=40,
    avg_seq_length=512,
    hidden_dim=4096,
    num_layers=32,
    precision_bytes=2
)
print(f"Safe batch size: {batch_size}")  # ~81

Latency vs. Throughput Trade-off

python
from typing import Dict
 
def optimal_batch_size_for_latency(
    max_acceptable_latency_ms: float,
    step_time_per_batch_size: Dict[int, float],
    tokens_per_request: int = 100
) -> int:
    """
    Find the largest batch size whose per-request latency stays acceptable.
    step_time_per_batch_size maps batch size -> measured ms per decode step
    (one token for every sequence in the batch); step time grows with batch size.
    """
    for batch_size in sorted(step_time_per_batch_size.keys(), reverse=True):
        latency_ms = step_time_per_batch_size[batch_size] * tokens_per_request
        if latency_ms <= max_acceptable_latency_ms:
            return batch_size
    return 1  # Default to unbatched
 
# Measure empirically: ms per decode step at each batch size
step_times = {
    1: 4.5,
    8: 4.6,
    16: 4.8,
    32: 4.9,
    64: 6.0,
    128: 8.5   # Step time climbs as the batch grows
}
 
optimal = optimal_batch_size_for_latency(500, step_times)
print(f"Optimal batch for <500ms latency: {optimal}")  # 32

Rule of thumb for inference:

  • Throughput-optimized workload (batch processing): use 80% of memory limit
  • Latency-sensitive workload (<200ms p99): use 20-30% of memory, rely on continuous batching
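As arithmetic, with a hypothetical helper: the ~0.5 MB/token default below is the fp16 KV-cache cost of a Llama-2-7B-class model (2 × 32 layers × 4096 hidden × 2 bytes = 524,288 bytes per token).

```python
def token_budget(free_gpu_gb: float, memory_fraction: float,
                 kv_bytes_per_token: float = 524_288) -> int:
    """How many tokens can sit in flight given a KV-cache memory budget."""
    return int(free_gpu_gb * 1e9 * memory_fraction / kv_bytes_per_token)

# Assuming ~22 GB free after weights on a 40 GB A100 with a 7B fp16 model:
print(token_budget(22, 0.8))    # throughput-optimized: use 80% of free memory
print(token_budget(22, 0.25))   # latency-sensitive: keep headroom
```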

Padding: The Hidden Computational Tax

Modern inference includes padding overhead that's often ignored.

python
from typing import List
 
def padding_waste_analysis(sequences: List[int]) -> float:
    """
    Fraction of computed token positions that are padding
    when a batch pads every sequence to its longest member.
    """
    padded_length = max(sequences)
    actual_tokens = sum(sequences)
    total_padded = padded_length * len(sequences)
    return (total_padded - actual_tokens) / total_padded
 
# Example: realistic prompt distribution
prompts = [64, 128, 256, 512, 1024, 1024, 512, 256, 128, 64]
waste = padding_waste_analysis(prompts)
print(f"Padding waste: {waste:.1%}")  # ~61% waste!
 
# With buckets (128, 512, 1024), each sequence pads only to its bucket
# and batches are formed per bucket:
bucketed_lengths = [128, 128, 512, 512, 1024, 1024, 512, 512, 128, 128]
bucketed_waste = (sum(bucketed_lengths) - sum(prompts)) / sum(bucketed_lengths)
print(f"Bucketing reduces waste to: {bucketed_waste:.1%}")  # ~14% waste

Solution: Sequence bucketing

Instead of padding every sequence to the batch maximum (2048), create buckets: 128, 256, 512, 1024, 2048. Pad each sequence up to the next bucket and form batches within a bucket. In a toy distribution like the one above, waste drops from ~61% to single digits; in production traffic, expect more like 25-35%, since bucket batches rarely fill perfectly and length distributions are messier.

python
from typing import List
 
def bucket_sequences(sequences: List[int], bucket_sizes: List[int]) -> List[int]:
    """Pad each sequence to the next bucket size at or above it."""
    buckets = sorted(set(bucket_sizes))
    return [next(b for b in buckets if b >= seq_len) for seq_len in sequences]
 
original = [64, 128, 256, 512, 1024, 1024, 512, 256, 128, 64]
buckets = [128, 256, 512, 1024, 2048]
bucketed = bucket_sequences(original, buckets)
 
# Pad-to-max waste vs. pad-to-bucket waste (batches formed per bucket)
print(f"Original waste: {padding_waste_analysis(original):.1%}")  # ~61%
bucketed_waste = (sum(bucketed) - sum(original)) / sum(bucketed)
print(f"Bucketed waste: {bucketed_waste:.1%}")  # ~3%

Measuring Batching Efficiency with torch.profiler

You can't optimize what you can't measure. Use the PyTorch profiler to identify whether you're compute-bound or memory-bound, then right-size your batch.

python
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from typing import List
 
def profile_batched_inference(
    model,
    batch_sizes: List[int],
    seq_length: int = 512
):
    """Profile inference across different batch sizes."""
 
    for batch_size in batch_sizes:
        # Create dummy input
        inputs = torch.randn(batch_size, seq_length, 4096).cuda()
 
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
            profile_memory=True
        ) as prof:
            with record_function("inference"):
                with torch.no_grad():
                    _ = model(inputs)
 
        # Extract metrics
        cuda_time = prof.key_averages().table(
            sort_by="cuda_time_total",
            row_limit=5
        )
 
        print(f"\n=== Batch Size {batch_size} ===")
        print(cuda_time)
 
        # Identify bottleneck
        total_cuda = sum(evt.cuda_time_total for evt in prof.key_averages())
        memory_ops = sum(
            evt.cuda_time_total for evt in prof.key_averages()
            if 'copy' in evt.key.lower() or 'memcpy' in evt.key.lower()
        )
 
        memory_ratio = memory_ops / total_cuda if total_cuda > 0 else 0
 
        if memory_ratio > 0.3:
            print(f"Status: MEMORY-BOUND ({memory_ratio:.1%} time in memory ops)")
        else:
            print(f"Status: COMPUTE-BOUND ({1 - memory_ratio:.1%} in compute kernels)")
 
# Run it
model = torch.nn.Linear(4096, 4096).cuda()
profile_batched_inference(model, batch_sizes=[1, 8, 32, 64, 128])

Interpreting results:

  • Memory-bound (>30% time in copies/loads): Increase batch size aggressively
  • Compute-bound (<30% memory overhead): Batch size is optimal; focus on other optimizations
  • Transition point: Where inference switches from memory-bound to compute-bound (usually batch 16-32 for LLMs)
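The transition point follows from arithmetic intensity. A back-of-envelope check for a [batch, 4096] × [4096, 4096] fp16 matmul, under the simplifying assumption that only weight and activation traffic counts:

```python
def arithmetic_intensity(batch: int, d: int = 4096, bytes_per_el: int = 2) -> float:
    """FLOPs per byte moved for a [batch, d] x [d, d] matmul."""
    flops = 2 * batch * d * d                       # multiply-accumulates
    bytes_moved = bytes_per_el * (batch * d         # input activations
                                  + d * d           # weight matrix
                                  + batch * d)      # output activations
    return flops / bytes_moved

for b in (1, 16, 32, 128):
    print(f"batch={b:3d}  intensity={arithmetic_intensity(b):6.1f} FLOPs/byte")
```

At batch 1 the kernel does about 1 FLOP per byte moved - hopelessly memory-bound on a GPU whose compute/bandwidth balance point sits in the hundreds of FLOPs/byte - and intensity grows nearly linearly with batch at first, which is why increasing batch size converts memory-bound inference into compute-bound inference.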

Graceful Overflow: What Happens When Queue Explodes

In continuous batching, sometimes requests arrive faster than you can process them. Your queue grows. What then?

python
import asyncio
import time
from typing import Optional, Tuple
 
class BackpressureAwareBatcher:
    """
    Continuous batching with graceful overflow handling.
    """
    def __init__(
        self,
        model,
        max_queue_size: int = 1000,
        max_batch_tokens: int = 8192,
        reject_if_queue_exceeds: float = 0.8
    ):
        self.model = model
        self.queue = asyncio.Queue(maxsize=max_queue_size)
        self.max_batch_tokens = max_batch_tokens
        self.reject_threshold = int(max_queue_size * reject_if_queue_exceeds)
 
    async def add_request(
        self,
        request_id: str,
        prompt: str,
        max_tokens: int
    ) -> Tuple[bool, Optional[str]]:
        """
        Try to queue request. Returns (success, error_message).
        """
        if self.queue.qsize() >= self.reject_threshold:
            return False, f"Service overloaded; queue beyond {self.reject_threshold} pending requests"
 
        try:
            await asyncio.wait_for(
                self.queue.put({
                    'request_id': request_id,
                    'prompt': prompt,
                    'max_tokens': max_tokens,
                    'enqueue_time': time.time()
                }),
                timeout=0.5
            )
            return True, None
        except asyncio.TimeoutError:
            return False, "Could not queue request; system saturated"
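On the client side, the (False, error) response above is a signal to shed load or back off. A hypothetical retry helper with exponential backoff and jitter (names and defaults are illustrative, not part of the batcher API):

```python
import random

def backoff_delays(max_retries: int = 5, base_s: float = 0.1,
                   cap_s: float = 5.0, seed: int = 0):
    """Exponential backoff schedule with jitter: base * 2^i, capped, then jittered."""
    rng = random.Random(seed)
    return [min(cap_s, base_s * 2 ** i) * rng.uniform(0.5, 1.0)
            for i in range(max_retries)]

print([round(d, 3) for d in backoff_delays()])
```

Jitter matters here: if every rejected client retries on the same schedule, the retries arrive as a synchronized burst and re-saturate the queue.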

Production Checklist: Static → Dynamic → Continuous

Use static batching if:

  • Workload is purely batch (e.g., offline scoring, not real-time)
  • Latency is not a constraint
  • Requests arrive in large synchronized groups

Use dynamic batching if:

  • Real-time requests, moderate concurrency (<32 concurrent)
  • p99 latency requirement >1 second
  • You need simple, predictable scheduling

Use continuous batching if:

  • High concurrency (>32 concurrent requests)
  • Sub-500ms latency requirements
  • Throughput is the primary metric
  • You're serving LLMs (prefill/decode asymmetry)
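The checklist condenses to a small decision helper. The thresholds mirror the rules of thumb above; treat them as starting points to validate with benchmarks, not gospel:

```python
def choose_batching_strategy(concurrency: int, p99_budget_ms: float,
                             realtime: bool) -> str:
    """Map workload characteristics to a batching strategy."""
    if not realtime:
        return "static"          # offline/batch workloads: latency doesn't matter
    if concurrency > 32 or p99_budget_ms < 500:
        return "continuous"      # high concurrency or tight latency budgets
    return "dynamic"             # moderate real-time traffic

print(choose_batching_strategy(concurrency=8, p99_budget_ms=2000, realtime=True))   # dynamic
print(choose_batching_strategy(concurrency=64, p99_budget_ms=400, realtime=True))   # continuous
```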

Batching in the Real World: Trade-offs and Gotchas

Theory is clean. Production is messy. You need to understand where batching strategies break down.

Variable Sequence Lengths

In continuous batching, every sequence in the pool advances one token per decode step, and the step's wall-clock time is shared: a sequence with a 2,000-token KV cache and one with 50 tokens finish the step together. This is the beauty and the limitation.

python
# In practice: track memory per sequence
def get_kv_cache_size(seq_len: int, num_heads: int, head_dim: int, layers: int) -> int:
    """Estimate KV cache memory in bytes."""
    # KV cache: 2 (K and V) × seq_len × num_heads × head_dim × layers × bytes_per_value
    # For Llama-2-7B: 32 heads, 128 head_dim, 32 layers
    return 2 * seq_len * num_heads * head_dim * layers * 2  # float16
 
# Llama-2-7B: KV cache grows ~512 KB per token per sequence in fp16
kv_size_100_tokens = get_kv_cache_size(100, 32, 128, 32)    # ~52 MB
kv_size_2000_tokens = get_kv_cache_size(2000, 32, 128, 32)  # ~1.05 GB
 
print(f"100-token sequence: {kv_size_100_tokens/1e6:.1f} MB KV cache")
print(f"2000-token sequence: {kv_size_2000_tokens/1e9:.2f} GB KV cache")

The implication: a single long-context sequence can consume as much KV-cache memory as dozens of short ones, crowding them out of the token budget, and each decode step's attention cost is dominated by the longest KV caches in the batch. Long and short sequences can coexist in the pool, but they don't cost the same.

Solutions:

  • Selective batching: Don't mix very short (<64 tokens) with very long (>1024 tokens) sequences in the same batch
  • Priority queues: Prioritize short sequences to free up batch slots faster
  • Chunked processing: For very long sequences, split into 512-token chunks, process separately
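A minimal sketch of the selective-batching idea, using the thresholds suggested above:

```python
def partition_by_length(seq_lengths, short_max=64, long_min=1024):
    """Split queued sequence lengths into short / medium / long pools
    so extremes never share a batch."""
    short = [n for n in seq_lengths if n < short_max]
    medium = [n for n in seq_lengths if short_max <= n < long_min]
    long_ = [n for n in seq_lengths if n >= long_min]
    return short, medium, long_

s, m, l = partition_by_length([32, 48, 256, 512, 2048, 1500])
print(s, m, l)  # [32, 48] [256, 512] [2048, 1500]
```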

Prefill vs. Decode Asymmetry

This is the hidden complexity that separates "works" from "production-grade" batching.

Prefill phase (processing the initial prompt):

  • High parallelism: The entire prompt is processed in one pass; batches of 64+ sequences work well
  • Compute-intensive: Large matrix multiplications saturate the GPU's arithmetic units
  • Throughput is the metric: Process 1000s of prompt tokens per second
  • GPU utilization: 90%+ with proper batching

Decode phase (generating tokens one at a time):

  • Low parallelism: Only 1 token computed per sequence per step
  • KV-cache bound: The bottleneck is streaming weights and cached keys/values from memory, not arithmetic
  • Latency matters: Every decode step adds 10-50ms to p99 latency
  • Memory efficiency: Hundreds of sequences can decode simultaneously, since each adds only its KV cache

The fundamental difference: prefill is compute-bound (limited by GPU arithmetic throughput on large matmuls), while decode is memory-bandwidth bound (limited by how fast weights and KV caches stream from GPU memory).
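The memory-bandwidth bound on decode can be sanity-checked with one line of arithmetic, under assumed numbers (a 7B-parameter fp16 model on an A100-class GPU with ~2 TB/s memory bandwidth): every decode step must stream all weights once, regardless of batch size.

```python
weight_bytes = 7e9 * 2               # 7B params in fp16
hbm_bandwidth = 2.0e12               # ~2 TB/s (A100-class)
floor_ms = weight_bytes / hbm_bandwidth * 1e3
print(f"Bandwidth floor per decode step: {floor_ms:.1f} ms")  # 7.0 ms
```

At batch 1 that ~7ms floor is the whole story; with 64 sequences decoding together, the same weight traffic is amortized across 64 tokens. Prefill escapes the floor differently: it reuses each loaded weight across every prompt token at once.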

Mixing them naively tanks performance. If you prefill 64 sequences simultaneously (saturating GPU memory) while decoding 8 others, your decode batch is starved, and requests stall waiting for their turn to generate tokens.

Modern systems handle this through scheduler-level separation. Dedicate GPU resources: 80% capacity to prefill, 20% to decode. Or use separate GPUs. The point is isolation - don't let prefill and decode compete for the same compute/memory bottleneck.

python
# Separate prefill and decode queues for optimal scheduling
import asyncio
import torch
 
class HybridBatcher:
    """
    Separate prefill and decode scheduling.
    """
    def __init__(self, model, max_prefill_batch=128, max_decode_batch=64):
        self.model = model
        self.max_prefill_batch = max_prefill_batch
        self.max_decode_batch = max_decode_batch
 
        self.prefill_queue = asyncio.Queue()
        self.decode_batch = {}  # sequences actively decoding
 
    async def schedule_step(self):
        """One scheduling step: prefill, then decode."""
        # Step 1: Prefill new requests (high batch, one-shot)
        prefill_batch = await self._gather_prefill_batch()
        if prefill_batch:
            await self._prefill(prefill_batch)
 
        # Step 2: Decode active sequences (lower batch, iterative)
        if self.decode_batch:
            await self._decode_one_step()
 
    async def _gather_prefill_batch(self):
        """Collect up to max_prefill_batch for prefilling."""
        batch = []
        while len(batch) < self.max_prefill_batch:
            try:
                req = self.prefill_queue.get_nowait()
                batch.append(req)
            except asyncio.QueueEmpty:
                break
        return batch
 
    async def _prefill(self, batch):
        """Prefill prompts in bulk."""
        prompts = [r['prompt'] for r in batch]
        with torch.no_grad():
            kv_caches = self.model.prefill(prompts)
 
        for req, kv in zip(batch, kv_caches):
            self.decode_batch[req['id']] = {
                'tokens': req['tokens'],
                'kv_cache': kv,
                'generated': 0,
                'max_new': req['max_tokens']
            }
 
    async def _decode_one_step(self):
        """Decode one token for all active sequences."""
        if len(self.decode_batch) > self.max_decode_batch:
            # Too many active - decode only a subset this step
            active = list(self.decode_batch.keys())[:self.max_decode_batch]
        else:
            active = list(self.decode_batch.keys())
 
        batch_data = [self.decode_batch[sid] for sid in active]
 
        with torch.no_grad():
            next_tokens = self.model.decode([d['kv_cache'] for d in batch_data])
 
        for sid, token in zip(active, next_tokens):
            self.decode_batch[sid]['tokens'].append(token)
            self.decode_batch[sid]['generated'] += 1
 
            if self.decode_batch[sid]['generated'] >= self.decode_batch[sid]['max_new']:
                del self.decode_batch[sid]

This hybrid approach achieves 30-40% better throughput than naive continuous batching on mixed workloads.

Benchmarking Real Systems: What You'll Actually See

Theory predicts performance. Measurement reveals reality. We need to understand how batching performs on real hardware across realistic workloads.

Measurement Methodology

When benchmarking batching strategies, avoid these common pitfalls:

  1. Synthetic workload bias: Fixed sequence lengths, uniform load, no network latency. Real traffic is bursty, variable, and noisy.
  2. Single-GPU testing: Batching behavior changes with multi-GPU setups, NVLink bandwidth, and async communication overhead.
  3. Warm cache assumptions: First requests incur kernel launch overhead. Later requests benefit. Average the last N iterations, not the first.
  4. Ignoring queueing: Throughput and latency are coupled. High throughput workloads see increased p99 latency if queueing isn't managed.

Proper benchmark structure:

python
import asyncio
import random
import statistics
import time

def benchmark_batching_strategy(
    batcher,
    num_requests: int = 10000,
    arrival_rate: float = 100.0  # requests/sec
):
    """
    Benchmark with realistic arrival patterns.
    """
    latencies = []

    # Generate request arrival times: exponential inter-arrival gaps
    # yield a Poisson process with the target mean rate
    request_times = []
    t = 0.0
    for _ in range(num_requests):
        t += random.expovariate(arrival_rate)
        request_times.append(t)

    async def submit_and_wait(request_id, prompt):
        # Bind submit_time per request so each latency is measured
        # end-to-end, not against a shared loop variable
        submit_time = time.time()
        await batcher.add_request(request_id, prompt, max_tokens=256)
        result = await batcher.get_result(request_id)
        latencies.append(time.time() - submit_time)
        return result

    async def run_simulation():
        start = time.time()
        tasks = []
        for i, arrival_time in enumerate(request_times):
            # Pace submissions to match the simulated arrival process
            delay = arrival_time - (time.time() - start)
            if delay > 0:
                await asyncio.sleep(delay)
            prompt = generate_random_prompt()  # assumed harness helper
            tasks.append(asyncio.create_task(
                submit_and_wait(f"req_{i}", prompt)
            ))

        # Wait for every request to finish before stopping the clock
        await asyncio.gather(*tasks)
        return time.time() - start

    total_time = asyncio.run(run_simulation())

    # Analyze
    return {
        'throughput_req_per_sec': num_requests / total_time,
        'p50_latency_ms': statistics.median(latencies) * 1000,
        'p99_latency_ms': statistics.quantiles(latencies, n=100)[98] * 1000,
        'p999_latency_ms': statistics.quantiles(latencies, n=1000)[998] * 1000,
        'max_latency_ms': max(latencies) * 1000,
        'mean_latency_ms': statistics.mean(latencies) * 1000
    }
 
# Run benchmarks
configs = [
    ('static_batch_32', static_batcher(32)),
    ('dynamic_batch_32_50ms', dynamic_batcher(32, 50)),
    ('continuous_8k_tokens', continuous_batcher(max_batch_tokens=8192)),
]
 
for name, batcher in configs:
    results = benchmark_batching_strategy(batcher, num_requests=10000, arrival_rate=200)
    print(f"\n{name}:")
    print(f"  Throughput: {results['throughput_req_per_sec']:.1f} req/sec")
    print(f"  P50 latency: {results['p50_latency_ms']:.1f} ms")
    print(f"  P99 latency: {results['p99_latency_ms']:.1f} ms")
    print(f"  P99.9 latency: {results['p999_latency_ms']:.1f} ms")

Expected results on A100 with Llama-2-7B:

  • Static (batch=32): 40 req/sec, p50=1200ms, p99=6000ms
  • Dynamic (batch=32, 50ms): 80 req/sec, p50=400ms, p99=1500ms
  • Continuous (8k tokens): 200+ req/sec, p50=80ms, p99=300ms

The continuous batching wins are dramatic - 3-5x better throughput, 20-50x better p99 latency.

Request Cancellation and Timeout Handling

Users cancel requests. Network connections drop. Your batcher must handle graceful degradation.

python
import asyncio
import time

class RobustBatcher:
    def __init__(self, model, timeout_ms=60000):
        self.model = model
        self.timeout = timeout_ms / 1000.0
        self.prefill_queue = asyncio.Queue()  # requests awaiting prefill
        self.decode_batch = {}  # request_id -> active sequence state
        self.in_flight = {}  # request_id -> enqueue_time
 
    async def add_request(self, request_id: str, prompt: str, max_tokens: int):
        """Queue a request with timeout."""
        enqueue_time = time.time()
        await self.prefill_queue.put({
            'id': request_id,
            'prompt': prompt,
            'max_tokens': max_tokens,
            'enqueue_time': enqueue_time
        })
        self.in_flight[request_id] = enqueue_time
 
    async def cancel_request(self, request_id: str):
        """Cancel an in-flight request."""
        if request_id in self.in_flight:
            del self.in_flight[request_id]
            # Sequence gets cleaned up next decode step
            return True
        return False
 
    async def _decode_one_step(self):
        """Decode with timeout enforcement."""
        now = time.time()
        active = list(self.decode_batch.keys())
 
        # Remove cancelled and timed-out sequences
        for sid in active:
            enqueue = self.in_flight.get(sid)
            if enqueue is None:
                # Cancelled via cancel_request; no longer tracked
                del self.decode_batch[sid]
            elif (now - enqueue) > self.timeout:
                del self.decode_batch[sid]
                del self.in_flight[sid]
 
        # Proceed with remaining
        if self.decode_batch:
            # ... normal decode ...
            pass

In production, you'll find that 2-5% of requests get cancelled. Build for it from the start.

Conclusion

GPU inference isn't free. Every idle cycle is money burning. Batching strategies are how modern systems reclaim that wasted capacity. Static batching is too coarse. Dynamic batching works but leaves performance on the table. Continuous batching - scheduling at the token iteration level - is the standard for a reason.

But continuous batching isn't a silver bullet. It demands attention to sequence length distribution, prefill/decode asymmetry, memory management, and graceful degradation. The cost of getting it wrong is a production outage or silent latency inflation.

Start by measuring your workload's concurrency, latency requirements, and GPU utilization with torch.profiler. Let that data guide your choice. Implement conservatively; optimize with evidence. Build monitoring around batch composition, queue depth, and p99 latency per batch size.

Your inference server doesn't hum because of smarter models. It hums because of smarter scheduling. The wins are enormous - but they require discipline and observation to sustain.

Batching in Multi-Model Systems: Orchestration Complexity

Many production systems don't run a single model. They orchestrate multiple models in sequence or parallel. Request batching becomes substantially more complex when you have dependent models. An embedding model feeds results to a ranker, which feeds results to a generator. Each model has different characteristics and bottlenecks. How do you batch across this pipeline?

The naive approach is to batch each model independently. But this creates synchronization problems. If the embedding model processes requests in batches of 32 but the ranker prefers batches of 16, you get buffering. Results pile up between stages. Latency increases. You need global optimization that considers the entire pipeline, not just local optimization at each stage.

One approach is to think of the entire pipeline as a dataflow graph. Requests flow through stages. At each stage, you collect requests and decide when to fire. Some systems use scheduling theory from manufacturing to solve this. You compute the optimal batch size at each stage given upstream and downstream demand. You tune these dynamically based on queue depths and latency targets.

Another approach is to flatten the stages. Instead of treating embedding and ranking as separate batching problems, you combine them into one super-batch. An embedding for all 32 requests, then immediate ranking of all results, then generation. This minimizes intermediate buffering and maintains end-to-end latency. The tradeoff is that you need to load all models into memory or orchestrate multi-GPU serving.
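The flattened-stage idea can be sketched in a few lines. This is a hypothetical illustration, not any particular framework's API; embed, rank, and generate stand in for whatever model-invoking callables your serving stack provides:

```python
# Hypothetical flattened pipeline: one super-batch flows through all
# three stages back-to-back, with no inter-stage queues or buffering.
def run_super_batch(requests, embed, rank, generate):
    # Stage 1: embed every request in a single batch
    embeddings = embed([r["text"] for r in requests])
    # Stage 2: rank all embeddings immediately
    ranked = rank(embeddings)
    # Stage 3: generate for the whole batch at once
    return generate(ranked)
```

Because the batch never waits between stages, end-to-end latency is the sum of three forward passes rather than three forward passes plus two queue waits.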

Heterogeneous Request Characteristics: Handling Diversity

Real workloads aren't homogeneous. Some requests need short responses (10 tokens). Others need long responses (1000 tokens). Some have short input prompts (50 tokens). Others have long input prompts (5000 tokens). Batching systems that assume uniformity struggle with this diversity.

The impact is significant. If you batch a request that needs 10 tokens with requests that need 1000 tokens, you compute 1000 tokens for all of them. That 10-token request wastes 990 token computations waiting for the others to finish. At scale, this inefficiency compounds dramatically.

Modern systems handle this through adaptive batching. You might maintain separate batches based on expected token count. All requests expecting short responses go to one batch. All requests expecting long responses go to another. You schedule them differently. Short-response requests get lower latency even if they have to wait a moment for their batch to fill.

This requires predicting output length before running the model. You can use heuristics - longer input usually means longer output. You can use a previous model that specializes in predicting output length. Some systems let the user specify expected output length. The prediction doesn't have to be perfect; it just has to be reasonably good at separating short from long responses.
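A minimal length-based router might look like the following. The dict fields and queue types are hypothetical; the heuristic (respect a user-specified max_tokens, else assume output scales with prompt length) is the simple one described above:

```python
# Hypothetical router: send each request to a short- or long-response
# batch queue based on a crude predicted output length.
def route_by_expected_length(request, short_queue, long_queue, threshold=128):
    # Prefer the user-specified cap; fall back to a prompt-length heuristic
    expected = request.get("max_tokens") or len(request["prompt"].split()) * 2
    target = short_queue if expected <= threshold else long_queue
    target.append(request)
    return expected
```

The prediction only needs to separate short from long; misrouted requests still complete, they just share a batch with mismatched peers.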

Memory Efficiency: Beyond Batch Size to Token Budgets

Most systems think about batching in terms of batch size - number of requests per batch. But at the GPU level, what matters is total tokens being processed simultaneously. A batch of 32 short sequences might consume less memory than a batch of 4 long sequences.

This is why advanced systems moved from batch size to token budget. Instead of saying "batch 32 requests," you say "process up to 8192 tokens total." The scheduler packs as many requests as fit within that token budget. If you have 32 short requests of 64 tokens each, that's 2048 tokens total - you're only using a quarter of your budget. You can add more requests. If you have 4 long requests of 2048 tokens each, that's 8192 tokens exactly - that's your full batch.
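A greedy token-budget packer is only a few lines. This sketch assumes each queued request carries a num_tokens field (a hypothetical name):

```python
from collections import deque

# Greedy token-budget packing: admit requests from the front of the
# queue while the running token total stays within the budget.
def pack_token_budget(queue, budget=8192):
    batch, used = [], 0
    while queue and used + queue[0]["num_tokens"] <= budget:
        req = queue.popleft()
        batch.append(req)
        used += req["num_tokens"]
    return batch, used

# 32 short requests of 64 tokens use only 2048 of the budget;
# 4 long requests of 2048 tokens each fill it exactly.
```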

Token budgeting is more efficient because it maximizes GPU utilization given your memory constraints. It naturally handles diversity without requiring separate batches. It works whether requests are homogeneous or wildly variable.

The challenge is implementation. You need to track token counts per request, manage a token-budget queue, and enforce the budget during batch construction. Most inference frameworks now support this (vLLM, SGLang), but if you're building custom systems, it's a critical optimization to include.

Batching and Queueing Theory: Optimal Wait Times

Batching inherently introduces queueing. The fundamental question is how long to wait for a batch to fill. Wait too long and you have good GPU utilization but bad latency. Don't wait long enough and you have good latency but poor GPU utilization.

Queueing theory offers insights. In an M/M/1 queue (Poisson arrivals, exponential service times), the average wait time depends on the utilization of the system. At 80 percent utilization, wait times are reasonable. At 95 percent utilization, they explode. This suggests that for dynamic batching, you should target something like 80 percent GPU utilization, accept whatever latency comes with it, rather than trying to maximize utilization at the cost of latency.
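The M/M/1 average queue wait, Wq = ρ / (μ − λ), makes that utilization cliff concrete. A quick calculation, assuming a server that completes 100 batches per second:

```python
# M/M/1 average time spent waiting in queue: Wq = rho / (mu - lambda),
# where mu is the service rate and lambda = rho * mu the arrival rate.
def mm1_queue_wait(service_rate, utilization):
    arrival_rate = service_rate * utilization
    return utilization / (service_rate - arrival_rate)

# For a server completing 100 batches/sec:
#   80% utilization -> ~40ms average queue wait
#   95% utilization -> ~190ms: nearly five times worse
```

The last 15 points of utilization cost almost 5x in queue wait, which is why targeting ~80 percent is the pragmatic choice.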

You can also think about this in terms of latency targets. If you need p99 latency under 200ms and your batch processing takes 100ms, you can afford to wait 100ms in the queue. With a known batch throughput, you can calculate the batch size that gives you that queue wait time. Smaller batches mean less queue wait.
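That back-of-envelope calculation is easy to automate. A sketch, with hypothetical parameter names:

```python
# Back out the queue-wait budget from a latency target, then size the
# batch to what can arrive (on average) within that window.
def wait_budget_and_batch(p99_target_ms, processing_ms, arrival_rate_per_sec):
    wait_budget_ms = p99_target_ms - processing_ms
    # Average number of requests that arrive within the wait budget
    max_batch = int(arrival_rate_per_sec * wait_budget_ms / 1000)
    return wait_budget_ms, max_batch

# 200ms p99 target, 100ms batch processing, 200 req/sec arrivals:
# 100ms of queue headroom, so waiting for more than ~20 requests
# risks blowing the target.
```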

The math gets more complex when you have multiple queues with different batch sizes and arrival patterns. Most production systems use heuristics and measurement rather than trying to solve optimal queueing equations. You measure your actual request distribution, try a few configurations, and pick the one that meets your latency targets while maintaining acceptable throughput.

Batching Infrastructure as Part of ML Ops

Batching strategies aren't just technical optimizations - they're part of your MLOps infrastructure. Changes to batching strategies should be tracked, versioned, and monitored just like model changes.

You might maintain a batching configuration alongside your model configuration. Different models might have different optimal batch sizes. A small model might want batch size 256 while a large model wants 64. When you deploy a new model, you also deploy a batching configuration tuned for that model.
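One way to make the batching configuration a first-class, versioned artifact is a small dataclass per model. The model names and numbers below are illustrative, not recommendations:

```python
from dataclasses import dataclass

# Batching settings deployed and versioned alongside each model,
# so a model rollout always ships with its tuned batching config.
@dataclass(frozen=True)
class BatchingConfig:
    max_batch_size: int
    max_batch_tokens: int
    accumulation_timeout_ms: int

BATCHING_CONFIGS = {
    "small-model-v3": BatchingConfig(256, 16384, 20),
    "large-model-v1": BatchingConfig(64, 8192, 50),
}
```

Because the dataclass is frozen, a config can't drift in place; changing it means shipping a new, reviewable version.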

You monitor batching metrics the way you monitor model metrics. What's the actual batch size being used? Is it close to the configured batch size or smaller? Are we accumulating queue? Are we hitting memory limits? These signals tell you whether your batching configuration is actually working as intended in production.

You also version batching strategies in your CI/CD pipeline. You test new batching configurations before deploying them to production. You might run A/B tests where different clients experience different batching strategies and you measure the impact on latency and throughput.

Advanced Topics: Batching with Dynamic Model Switching

Some systems serve multiple models simultaneously. You might have a small fast model for simple queries and a large accurate model for complex queries. How do you batch when you don't know ahead of time which model each request will use?

One approach is to route first, then batch. You quickly classify each incoming request and route it to the appropriate model. Each model maintains its own batch. This prevents mixing requests that should use different models. The downside is the overhead of the routing step and the potential for load imbalance if most requests route to one model.

Another approach is speculative batching. You start with the small model but tag requests that might need the large model. If the small model's response isn't confident, you re-route to the large model. This lets you exploit the small model's efficiency for easy cases while maintaining accuracy for hard cases. The tradeoff is that some requests go through both models, increasing total latency.
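Speculative batching reduces to a confidence-gated fallback. A sketch, assuming each model returns a (text, confidence) pair (a hypothetical interface):

```python
# Confidence-gated speculative routing: try the cheap model first and
# fall back to the large model only when confidence is low.
def speculative_generate(request, small_model, large_model, threshold=0.7):
    text, confidence = small_model(request)
    if confidence >= threshold:
        return text, "small"
    text, _ = large_model(request)
    return text, "large"
```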

Some systems use ensemble approaches where both models run and their outputs are combined. This guarantees quality but doubles compute cost. It only makes sense when the ensemble quality is significantly better than either model alone and users need that quality.

Batching in Serverless and Function-as-a-Service Environments

Traditional batching assumes persistent processes. You have a server running continuously, accumulating requests, batching them together. Serverless architectures challenge this assumption. Each function invocation is independent. You don't have persistent state between invocations.

Some teams solve this by maintaining a batching service separate from the function. The function forwards requests to the batching service, which accumulates and schedules batches. The function then polls or receives a callback when results are ready. This adds complexity but allows serverless workloads to benefit from batching.

Other teams accept that serverless is fundamentally misaligned with batching and design for small batch sizes instead. They might batch only requests that arrive in the same function invocation, which could be 1-2 requests. The throughput is lower but the architecture is simpler.

The key insight is that batching strategies aren't one-size-fits-all. Serverless architectures enable different trade-offs than traditional persistent servers. Choose your architecture and batching strategy together, not separately.

Conclusion Revisited: Sustainable Optimization

Batching is one of the highest-impact optimizations available for ML inference. A 3-5x throughput improvement from continuous batching is life-changing for system capacity and costs. But the complexity is real, and the penalties for getting it wrong are significant.

The teams that succeed maintain a focus on measurement and iteration. They don't assume optimal settings from documentation. They measure their actual workload. They profile their actual hardware. They try different configurations and measure the results. They track which configuration works best and why. They revisit decisions as their workload changes.

They also maintain clear mental models of what's happening. They understand that batching trades off latency for throughput. They know the critical difference between prefill and decode. They recognize when their workload is memory-bound versus compute-bound. They use this understanding to make good batching decisions rather than tuning parameters blindly.

The infrastructure supporting batching gets maintained like any other critical infrastructure. It's monitored. It's tuned. It's improved over time. Batching isn't a one-time configuration; it's an ongoing optimization that drives operational efficiency and user experience.

The most important lesson is that infrastructure improvements compound. Small gains in batching efficiency multiply across millions of requests. A 5 percent improvement in GPU utilization might not sound impressive, but across 100 million daily requests, it frees roughly 5 million requests' worth of capacity you no longer have to buy hardware for. That's real capital savings. That's real improvement in what you can serve with your current capacity. That's why world-class inference systems are obsessively optimized around batching and scheduling.

Observability: Knowing What Your Batching Is Actually Doing

Implementing a batching strategy is one thing. Understanding what it's doing in production is another. Without observability, you can't know if your batching configuration is working as intended.

Key metrics to instrument include:

  • Actual batch size distribution. Are you targeting batches of 32 but actually getting mostly batches of 8? That suggests your batch accumulation timeout is firing too early.
  • Queue depth over time. Are requests accumulating in the queue? That's a sign that batching can't keep up with traffic.
  • Actual latency per batch size. Does latency scale the way you predicted?

You should also instrument the cost of batching. How much time is spent waiting in queues versus actually being processed? This tells you whether batching is causing latency inflation. If you're targeting 50ms p99 latency and half of that is queue wait, you might choose smaller batches.

Instrument padding waste. If your average sequence is 256 tokens but batches are padding to 512 tokens, you have 50 percent waste. This signals that bucketing or dynamic padding might help.
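Padding waste is cheap to compute from the per-batch sequence lengths. A minimal sketch:

```python
# Fraction of computed tokens that are padding when a batch is
# padded up to its longest sequence.
def padding_waste(seq_lengths):
    padded_total = max(seq_lengths) * len(seq_lengths)
    return 1 - sum(seq_lengths) / padded_total
```

Log this per batch; a persistently high value is the signal that bucketing by length would pay off.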

Monitor cache efficiency. If your GPU has KV cache for 8000 tokens and you're only using 4000, you're wasting memory that could serve more sequences.

These metrics should feed into dashboards and alerts. When queue depth exceeds thresholds, you want to know. When average batch size drops significantly, you want to investigate. When latency degrades, you want to understand whether it's a batching problem or a model problem.

Debugging Batching Problems: A Pragmatic Approach

When your batching system isn't working as expected, debugging requires systematic thinking. Start by measuring the actual behavior, not guessing.

If latency is high, measure: What's the batch size? Is the batch size as configured? What's the queue wait time versus processing time? Where is latency coming from?

If throughput is low, measure: Are batches filling? Are we hitting memory limits and having to reduce batch size? Is something blocking the batching loop?

If GPU utilization is low, measure: Are we running small batches due to memory constraints? Are we running infrequently due to timeout? Are we spending time on non-batch operations like data loading or model loading?

For each problem, instrument deeply. Add timing spans around each operation in your batching loop. Understand which operation is taking time. Maybe the forward pass takes 80ms and your accumulation timeout is 100ms: each cycle leaves only 20ms to gather requests, so batches close nearly empty and your batch size is smaller than intended.
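Timing spans don't require a tracing framework; a context manager that accumulates wall-clock time per stage is enough to start. A minimal sketch:

```python
import time
from contextlib import contextmanager

timings = {}  # stage name -> cumulative seconds

@contextmanager
def span(name):
    # Accumulate wall-clock time spent in each named stage
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[name] = timings.get(name, 0.0) + time.perf_counter() - start

# Usage inside the batching loop:
#   with span("accumulate"): batch = gather_batch()
#   with span("forward"):    outputs = model(batch)
```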

Fix one thing at a time. Change batch size or timeout, measure the impact. Change batching strategy, measure the impact. Don't make multiple changes at once or you won't know which one helped.

Document what you learn. Write down why you chose certain batch sizes and timeouts. Write down what problems you encountered and how you fixed them. This knowledge helps the next person working on your batching system and prevents you from repeating mistakes.

