Secure ML Serving: TLS, Input Validation, and Rate Limiting
You've just deployed your ML model to production. It's fast, accurate, and users love it. Then at 2 AM, your alerting system goes haywire: someone's hammering your inference endpoint with 50,000 requests per second, your GPU is melting, and your API certificates are about to expire. Sound familiar?
Securing ML APIs isn't just about slapping HTTPS on a FastAPI server and calling it a day. Production ML systems face unique threats: adversaries crafting malicious inputs to trigger model exploits, legitimate users accidentally sending payloads that crash your inference pipeline-pipelines-training-orchestration)-fundamentals)), and coordinated attacks designed to exhaust your GPU compute. This article walks you through the layered security strategy that separates toy projects from production systems.
We'll cover TLS termination, input validation strategies, rate limiting that understands compute costs, WAF rules tailored for ML workloads, and DoS protection patterns. By the end, you'll have a blueprint for building ML APIs that withstand real-world threats.
Let's ground this in a concrete scenario you've probably encountered. You're running a content moderation API. It classifies images as safe or flagged. Your SLA guarantees 100ms latency per request. On a standard A100 GPU, each inference uses about 2GB of memory and takes 80ms. You can fit roughly 8 concurrent requests on the GPU before you start queuing. Now imagine an attacker discovers they can send high-resolution images that force full-image processing instead of downsampling, pushing processing time to 250ms per request. Suddenly, your capacity drops by 60%. Send 100 such requests per second, and you've effectively taken your service offline through resource exhaustion without triggering any traditional rate limits.
This is why ML API security requires a different playbook. You can't just set a global request/second limit. You need to understand the compute cost of each request, validate inputs before they hit your GPU, and implement rate limiting that accounts for heterogeneous workload costs. An image classification request and a large language model prompt have wildly different compute costs, but traditional rate limiters treat them identically.
Throughout this guide, we'll build a layered security architecture. At the edge, TLS ensures encrypted transport and client authentication. Behind that, input validation gates prevent obviously malicious inputs from reaching your model. In the middle, compute-aware rate limiting ensures that even valid requests don't overwhelm your infrastructure. Behind that sits your model serving layer with circuit breakers and fallback-fallback) strategies. This layering means no single point of failure, and each layer stops a different class of attack before it reaches the next layer.
Table of Contents
- The Security Challenge Unique to ML APIs
- The Economics of Defense
- Defense-in-Depth Architecture
- Understanding Your Threat Model
- Security Layers: Defense in Depth
- Part 1: TLS Termination and Certificate Management
- Why TLS Matters for ML APIs
- The True Cost of Not Having Security
- Setting Up TLS with cert-manager
- Mutual TLS (mTLS) for Internal Service Communication
- TLS 1.3 Configuration in Nginx/Envoy
- Part 2: Input Validation for ML Models
- Common Validation Failures in Production
- The Validation Problem
- Schema Validation with Pydantic v2
- Shape and Dtype Checks for Tensor Inputs
- String Length Limits for LLM APIs
- Injection Attack Prevention
- Part 3: Rate Limiting Strategies
- Why Generic Rate Limiting Fails for ML
- Per-User and Per-API-Key Rate Limiting
- Token Bucket vs Sliding Window Algorithms
- GPU-Aware Rate Limits Based on Compute Cost
- Part 4: WAF Rules for ML Endpoints
- WAF Basics
- OWASP Core Rule Set Customization
- Blocking Oversized Payloads
- CloudFront and Cloudflare WAF for Public APIs
- Part 5: DoS Protection and Circuit Breakers
- Request Queuing with Backpressure
- Circuit Breakers with Istio/Envoy
- Abuse Detection Using Request Pattern Analysis
- Architecture Diagram
- Putting It All Together: Complete Example
- Summary: The Security Stack
The Security Challenge Unique to ML APIs
Traditional web services worry about three things: authentication (who are you?), authorization (what are you allowed to do?), and rate limiting (how fast can you go?). ML APIs have to worry about those, plus a fourth: adversarial inputs.
Here's the reality: your ML API sits at the intersection of three distinct threat vectors that most infrastructure teams aren't prepared to defend. First, there's the traditional API attack surface - SQL injection, authentication bypasses, rate-based denial of service. These are well-understood by security teams, and the defenses are mature. But ML APIs introduce two new problems that traditional security practices don't fully address.
The second threat vector is adversarial input manipulation. An attacker who understands your model's decision boundary can craft inputs specifically designed to fool it. Unlike a bug in your code that affects everyone equally, an adversarial input might fool only your model while looking completely legitimate to humans. Your fraud detection system classifies a crafted transaction as benign when))-ml-model-testing)-scale)-real-time-ml-features)-apache-spark))-training-smaller-models)) it should flag it. Your content moderation model misses harmful content that's been slightly perturbed. Your computer vision system misidentifies a stop sign. These aren't theoretical attacks - they're actively exploited in the wild.
The third threat vector is resource exhaustion. An attacker could hammer your inference endpoint with requests designed to maximize GPU utilization per request. They might send enormous batch sizes, request models that consume unusual amounts of memory, or submit inputs that trigger worst-case algorithmic behavior. Unlike traditional APIs where a 1MB payload takes similar time to process as a 1KB payload, ML models have explosive resource consumption curves. A slightly larger image can trigger dense matrix multiplications that drain your GPU. An attacker knowing this can craft payloads optimized for GPU exhaustion.
What makes ML API security fundamentally different is that you can't just "patch" your model. A vulnerability in your code gets fixed with a deployment-production-inference-deployment). A vulnerability in your model's decision boundary requires retraining on adversarial examples, and even then you're never fully sure the vulnerability is gone. This means your defense strategy has to shift from "prevent all attacks" to "rate limit attacks faster than they can happen."
The Economics of Defense
Let's talk about why this matters financially. Assume you're running a content moderation API that costs you $0.001 per inference on an A100 GPU. An attacker sends 1,000 adversarial image requests per second. If those requests look legitimate and your rate limiting doesn't catch them, you burn $1 per second processing them. Over 24 hours, that's $86,400 in direct compute costs. But the real cost is higher - while your infrastructure is busy processing adversarial requests, legitimate users are getting degraded service. If moderation latency hits your SLA, you owe customers credits. If your service becomes unavailable, they churn.
This is why compute-aware security isn't optional. You need to understand what each request costs, validate inputs before they burn resources, and rate-limit based on compute budget, not just request count. A small image processed against a lightweight model might use 100ms of GPU time. A large image against a heavyweight model might use 2 seconds. One request/second of small images is sustainable; one request/second of large images breaks your infrastructure. Traditional rate limiting that treats these identically is broken.
Defense-in-Depth Architecture
Your defense strategy needs layers. If any single layer fails, others still protect you. This is the principle of defense-in-depth, and it's essential for production ML systems. At the perimeter, TLS and authentication ensure you're talking to who you think you are. Behind that, input validation gates prevent obviously malicious requests from reaching compute. Next comes compute-aware rate limiting that understands the cost of each request and stops cascading failures. Finally, inside your model serving layer, circuit breakers and fallback strategies ensure one model failure doesn't tank your entire system.
The goal isn't to prevent all attacks - that's mathematically impossible. The goal is to make attacks expensive enough that they're not worth the attacker's time. When you layer defenses correctly, the attacker has to get past TLS, then input validation, then rate limiting, then circuit breakers. Each layer adds complexity and cost. Most attackers give up. The sophisticated ones who don't? You've bought time to respond while you're analyzing logs and understanding the attack pattern.
Traditional APIs process data semantically. A request to /users/123 is either valid (user 123 exists) or invalid (doesn't exist). Invalid requests fail safely. ML models work differently. They process continuous inputs and always produce an output. There's no concept of "invalid" input - only inputs where the model's output is unreliable or adversarial.
This creates unique attack vectors. An attacker can craft inputs that fool the model into making wrong predictions (adversarial examples). They can send large payloads that crash the tokenizer or cause GPU OOM (resource exhaustion). They can send carefully timed requests that expose timing side-channels to extract model weights. They can use your API as a service to generate embeddings for their own uses (model theft).
Traditional rate limiting ("X requests per second per user") doesn't capture the compute cost of inference. A simple linear classifier is cheap. A large language model is expensive. Rate limits should be aware of cost, not just request count. Some teams implement "compute tokens" - each request consumes tokens based on model size and batch size, and you can only consume X tokens per hour.
Understanding Your Threat Model
Before you build defenses, understand what you're defending against. Your threat model depends on what your model does and who might attack it.
If you're running a public API (anyone on the internet can call it), your threat model is broad: attackers with no special knowledge can try common attacks. If you're running an internal API (only authenticated employees can call it), your threat model is narrower but includes sophisticated insiders who understand your infrastructure.
Some specific threats to consider. Availability attacks (denial of service) try to make your service unavailable - they overwhelm your infrastructure with requests or craft requests that consume maximum resources. Integrity attacks try to corrupt your predictions or steal your model weights. Confidentiality attacks try to extract information about your model, training data, or users. Compliance attacks try to exploit your systems to violate regulations like GDPR or HIPAA.
For most ML APIs, availability threats are the most common. An attacker can send malicious requests faster than you can serve legitimate ones. The sophisticated part is that they don't need to overwhelm you with raw request volume - they can craft requests that consume compute inefficiently. One request that takes 100x longer than a normal request is worth 100 normal requests in terms of resource consumption.
A good threat model also considers insider threats. Could a disgruntled employee extract your model? Could a contractor leak your training data? Could an employee with access to logs see sensitive predictions? These insider scenarios are often more likely than sophisticated external attacks, yet many teams focus entirely on external threats.
Security Layers: Defense in Depth
The approach we'll cover here is "defense in depth." You don't rely on any single mechanism to stop attacks. Instead, you layer defenses so that if one is bypassed, others catch the attack.
Layer 1: TLS encryption ensures that data in transit can't be intercepted or modified. But it doesn't authenticate the client or prevent application-level attacks.
Layer 2: Input validation ensures that only well-formed requests reach your model. A malformed tensor-parallelism) shape gets rejected before it ever touches the GPU. But validation alone can't stop resource exhaustion attacks (you could send millions of valid requests).
Layer 3: Rate limiting ensures that users can't hammer your service. But it needs to be aware of resource costs - rate limiting by request count makes sense for simple APIs but misses the reality of ML inference where complexity varies wildly.
Layer 4: Monitoring and alerting ensure that attacks are detected quickly. If someone is trying to extract your model weights via side-channel attacks, monitoring latency variance might reveal it. Together, these layers create a system where attacks face multiple barriers and detection happens early.
Part 1: TLS Termination and Certificate Management
Why TLS Matters for ML APIs
Your model API transmits sensitive data: user queries, embeddings, predictions. You need encryption. But you also need certificate automation, because manual renewal at 3 AM is a recipe for incidents.
TLS is table-stakes for production systems. If you're not encrypting data in transit, attackers can intercept requests, read user data, and inject malicious inputs. Your model's security posture collapses without it.
Modern TLS also provides performance benefits. TLS 1.3 uses only one round-trip for connection establishment (vs. two for TLS 1.2), saving latency. Hardware acceleration for TLS operations means encryption/decryption is nearly free. The only cost is management overhead - which automation like cert-manager eliminates.
The modern approach uses cert-manager in Kubernetes-nvidia-kai-scheduler-gpu-job-scheduling)-ml-gpu-workloads) to automate certificate provisioning and renewal with Let's Encrypt (for public APIs) or your internal CA (for internal services).
The True Cost of Not Having Security
Before we implement anything, understand the real cost of security failures. A leaked API key that gives access to your model inference endpoint can cost you thousands in compute. An undetected adversarial attack that forces your model to make wrong predictions at scale can damage your reputation. A compromise that exposes user data can trigger GDPR fines of 4% of revenue or more. A DoS that knocks your model offline during critical business hours can cost you revenue directly.
The business case for security isn't theoretical - it's financial. The cost of implementing TLS, input validation, and rate limiting is measured in hours of engineering time and minimal operational overhead (automated certificate renewal, rate limiter infrastructure). The cost of not having these defenses is measured in actual money when something breaks. Most organizations would spend ten times as much on incident response and recovery as they would have spent on prevention.
The other angle is competitive. Customers increasingly care about security. If you're selling an ML service and your competitors can say "we have TLS, input validation, rate limiting" while you can't, you lose deals. Security becomes a feature, a selling point. Organizations that don't take it seriously get left behind.
Setting Up TLS with cert-manager
# cert-issuer.yaml: Configure Let's Encrypt for automatic renewal
apiVersion: cert-manager.io/v1
kind: ClusterIssuer
metadata:
name: letsencrypt-prod
spec:
acme:
server: https://acme-v02.api.letsencrypt.org/directory
email: ops@yourcompany.com
privateKeySecretRef:
name: letsencrypt-prod
solvers:
- http01:
ingress:
class: nginxOnce you define the issuer, your Ingress certificate is managed automatically:
# model-ingress.yaml: ML API exposed via HTTPS
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: ml-model-api
annotations:
cert-manager.io/cluster-issuer: "letsencrypt-prod"
spec:
ingressClassName: nginx
tls:
- hosts:
- api.models.yourcompany.com
secretName: ml-model-tls
rules:
- host: api.models.yourcompany.com
http:
paths:
- path: /predict
pathType: Prefix
backend:
service:
name: ml-inference
port:
number: 8000Cert-manager automatically provisions the certificate from Let's Encrypt, stores it in the ml-model-tls secret, and refreshes it before expiry. No manual intervention needed.
What's happening: You declare the desired TLS state in your Ingress manifest. Cert-manager watches for changes, creates a certificate request, solves the ACME challenge, and installs the cert. When renewal comes due (typically at 30 days before expiry), it repeats automatically.
Mutual TLS (mTLS) for Internal Service Communication
For internal service-to-service calls (model API to feature store, model API to analytics pipeline), TLS provides encryption but not client authentication. If someone compromises your internal network, they can impersonate any service.
mTLS fixes this: both client and server present certificates. The server verifies the client's identity, and vice versa. This is a powerful security control because now an attacker can't just call your model API from any compromised system - they need the specific client certificate, which they're unlikely to have unless they've also compromised the authorized client service. It dramatically raises the bar for attackers trying to use your model API from within your network.
The lifecycle of certificate management also improves with mTLS. Certificates need renewal, rotation, revocation. Manual management is painful. Istio's automatic certificate injection makes it transparent - certificates are generated, installed, and rotated automatically. Your applications never see the certificates; Istio handles it. You gain security without operational burden.
# model-service.yaml with mTLS via Istio
apiVersion: security.istio.io/v1beta1
kind: PeerAuthentication
metadata:
name: ml-model-api-mtls
namespace: ml-platform
spec:
mtls:
mode: STRICT # Require mTLS for all traffic
---
apiVersion: networking.istio.io/v1beta1
kind: DestinationRule
metadata:
name: ml-model-dr
namespace: ml-platform
spec:
host: ml-inference.ml-platform.svc.cluster.local
trafficPolicy:
tls:
mode: MUTUAL # Client must present cert
sni: ml-inference.ml-platform.svc.cluster.localWhen enabled, Istio automatically injects sidecars that:
- Generate certificates for each pod
- Mount certificates into workload containers
- Terminate and establish mTLS connections
- Rotate certificates before expiry
Your application code doesn't change - mTLS is transparent.
TLS 1.3 Configuration in Nginx/Envoy
TLS 1.2 is fine, but TLS 1.3 is faster (fewer round trips) and more secure (removed weak ciphers). Here's how to enforce it.
TLS version negotiation matters because older versions have known vulnerabilities. TLS 1.2 supports some weak ciphers that are only there for historical compatibility. TLS 1.3 removed all the weak ones, keeping only the strong ciphers. It also simplified the handshake, reducing latency by eliminating an extra round trip during connection setup. For ML APIs where latency matters, this is meaningful - every millisecond of reduction helps when you're trying to hit sub-second latency SLOs.
The challenge is that some old clients still use TLS 1.2. If you enforce TLS 1.3 exclusively, old clients break. The solution is to support both TLS 1.3 and 1.2 but prefer TLS 1.3. Modern clients upgrade, old clients still work. Over time, you migrate to TLS 1.3 only.
Performance also depends on your cipher choices. Some ciphers are computationally more expensive than others. Modern GPUs have dedicated hardware for AES-GCM encryption, making it nearly free. CHACHA20 is newer and might not have the same hardware support. Choose ciphers that balance security and performance for your deployment.
# nginx-tls-config.conf
server {
listen 443 ssl http2;
server_name api.models.yourcompany.com;
# TLS 1.3 and 1.2 only (no 1.1 or earlier)
ssl_protocols TLSv1.3 TLSv1.2;
ssl_ciphers 'TLS13-AES-256-GCM-SHA384:TLS13-AES-128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384';
ssl_prefer_server_ciphers on;
ssl_session_cache shared:SSL:10m;
ssl_session_timeout 10m;
# HSTS: Tell browsers to always use HTTPS
add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always;
location /predict {
proxy_pass http://ml-inference:8000;
}
}For Envoy (often used as the sidecar in Istio):
# envoy-tls-config.yaml
listeners:
- name: listener_0
address:
socket_address:
address: 0.0.0.0
port_number: 443
filter_chains:
- transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext
common_tls_context:
tls_certificates:
- certificate_chain: { filename: "/etc/ssl/certs/cert.pem" }
private_key: { filename: "/etc/ssl/private/key.pem" }
min_protocol_version: TLSv1_3
cipher_suites:
- TLS_AES_256_GCM_SHA384
- TLS_AES_128_GCM_SHA256Key takeaway: TLS 1.3 reduces handshake latency for model APIs, which is critical when you're serving time-sensitive predictions.
Part 2: Input Validation for ML Models
Common Validation Failures in Production
Let me walk you through real failures I've seen in production ML systems. These aren't theoretical - they're problems that have cost organizations real money and downtime.
Validation failures often sneak into production because individual failures seem harmless. One bad request doesn't crash the system. But patterns of bad requests degrade service. And once patterns appear, they become attacks - either accidental (buggy client sending mal formed requests constantly) or intentional (attacker discovering that certain payloads cause problems).
One team deployed a recommendation model that expected embeddings in a specific range. The embedding pipeline had a subtle bug that occasionally produced embeddings with values of negative infinity. The model handled them gracefully (clipped to a valid range), so no one noticed. But the invalid embeddings meant the model wasn't making good recommendations. It took three weeks to find because the system "worked" - it just made bad predictions.
Another team had an NLP model that expected tokenized inputs of a specific length. When someone sent a payload with a malformed token sequence, the model's input validation layer didn't catch it. The malformed tokens propagated through the embedding layer, through the transformer, and produced nonsense predictions that downstream systems used. By the time the bad predictions affected users, they'd already been cached in a recommendation system, affecting thousands of users.
A third team had a computer vision API that accepted JPEG images. They didn't validate image dimensions. An attacker sent a 100,000x100,000 pixel image. Decoding it consumed all available memory, crashing the server. The attacker repeated every five minutes, keeping the service down for two hours.
These failures aren't glamorous security breaches - no one stole data. But they're expensive in terms of downtime, damage to trust, and time spent debugging. Proper input validation prevents all of them.
The benefit of input validation is that it's upstream. Validation catches problems before they hit your model, before they degrade inference, before they affect users. It's the easiest place to stop bad things - right at the gate, before resources are consumed.
The Validation Problem
Your model expects a tensor of shape (batch_size, 768) and dtype float32. But attackers don't care about your specifications. They send:
- Gigabyte-sized payloads to trigger out-of-memory (OOM) crashes
- Strings with SQL injection payloads to exploit LLM tokenizers
- NaN and infinity values to break numerical stability
- Wrong tensor shapes that cause silent production bugs
Input validation catches these before they hit your model.
Schema Validation with Pydantic v2
Pydantic validates Python objects against schemas. Your API defines request shapes, and Pydantic enforces them:
# model_api.py: Define your API contract with Pydantic v2
from pydantic import BaseModel, Field, validator
from typing import Optional
import numpy as np
class EmbeddingRequest(BaseModel):
"""Input contract for embedding API."""
text: str = Field(
...,
min_length=1,
max_length=10000, # Prevent DoS via huge strings
description="Input text to embed"
)
batch_id: Optional[str] = Field(
default=None,
max_length=50,
description="Optional batch identifier"
)
@validator('text')
def validate_text_safe(cls, v):
"""Prevent prompt injection in LLM prompts."""
# Block common injection patterns
dangerous_patterns = ['<script', 'javascript:', 'onerror=', 'onclick=']
v_lower = v.lower()
for pattern in dangerous_patterns:
if pattern in v_lower:
raise ValueError(f"Text contains potentially dangerous pattern: {pattern}")
return v
class TensorRequest(BaseModel):
"""Input contract for tensor-based models."""
features: list = Field(
...,
description="Input features as flat list"
)
batch_id: Optional[str] = None
@validator('features')
def validate_features(cls, v):
"""Check tensor shape, dtype, and value ranges."""
# Convert to numpy for analysis
arr = np.array(v, dtype=np.float32)
# Check shape
if arr.shape != (768,):
raise ValueError(f"Expected shape (768,), got {arr.shape}")
# Check for NaN, inf
if np.isnan(arr).any() or np.isinf(arr).any():
raise ValueError("Features contain NaN or infinity values")
# Check value ranges (adjust for your model)
if (np.abs(arr) > 100).any():
raise ValueError("Features out of expected range [-100, 100]")
return v.tolist()
class PredictionResponse(BaseModel):
"""Output contract."""
prediction: float
confidence: float
batch_id: Optional[str] = NoneNow your API enforces this:
# fastapi_endpoint.py
from fastapi import FastAPI, HTTPException, status
from fastapi.responses import JSONResponse
app = FastAPI()
@app.post("/embed", response_model=dict)
async def embed(request: EmbeddingRequest):
"""
Embed text. Pydantic automatically validates the request.
"""
try:
# By this point, request.text is guaranteed to be:
# - A string between 1 and 10,000 characters
# - Free of injection patterns
# - Valid JSON (parsed by Pydantic)
# Call your model
embedding = model.embed(request.text)
return {
"embedding": embedding.tolist(),
"batch_id": request.batch_id,
"status": "success"
}
except Exception as e:
# This shouldn't happen if validation passed
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"error": str(e)}
)
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: TensorRequest):
"""
Make predictions. Pydantic validates shape, dtype, and ranges.
"""
try:
arr = np.array(request.features, dtype=np.float32).reshape(1, -1)
prediction = model.predict(arr)
return PredictionResponse(
prediction=float(prediction[0]),
confidence=float(prediction[1]),
batch_id=request.batch_id
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)What's happening: Pydantic parses the incoming JSON, validates each field against your schema, runs custom validators, and either returns a validated object or raises a 422 error. Invalid requests never reach your model code.
Shape and Dtype Checks for Tensor Inputs
ML models are fragile. They assume specific tensor shapes and types. A shape mismatch causes silent bugs or crashes. Here's a robust validator:
# tensor_validators.py
import numpy as np
from typing import List, Tuple
from pydantic import validator, BaseModel, Field
class StrictTensorRequest(BaseModel):
"""Strict tensor validation for production models."""
values: List[float] = Field(..., description="Flat list of floats")
expected_shape: Tuple[int, ...] = Field(
default=(768,),
description="Expected shape after reshape"
)
expected_dtype: str = Field(default="float32", description="Expected dtype")
@validator('values')
def check_values(cls, v):
"""Validate value list before shape/dtype checks."""
if len(v) == 0:
raise ValueError("Values list cannot be empty")
if len(v) > 1_000_000:
raise ValueError("Values list exceeds maximum size (1M elements)")
return v
@validator('expected_dtype')
def check_dtype(cls, v):
"""Validate dtype is one we support."""
supported = ['float32', 'float64', 'int32', 'int64']
if v not in supported:
raise ValueError(f"Unsupported dtype: {v}. Supported: {supported}")
return v
def to_tensor(self) -> np.ndarray:
"""Convert and validate to tensor."""
# Convert to array with specified dtype
arr = np.array(self.values, dtype=self.expected_dtype)
# Reshape to expected shape
try:
arr = arr.reshape(self.expected_shape)
except ValueError as e:
raise ValueError(
f"Cannot reshape {len(self.values)} elements to shape {self.expected_shape}: {e}"
)
return arrUsage:
@app.post("/predict-strict")
async def predict_strict(request: StrictTensorRequest):
"""Predict with strict tensor validation."""
try:
tensor = request.to_tensor()
# Now we're guaranteed:
# - tensor.shape == request.expected_shape
# - tensor.dtype == request.expected_dtype
# - tensor.size <= 1,000,000
prediction = model.predict(tensor)
return {"prediction": prediction.tolist()}
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))String Length Limits for LLM APIs
LLM APIs are particularly vulnerable to input-based DoS: send a massive prompt, the tokenizer blows up, and your GPU OOMs.
# llm_validators.py
from pydantic import BaseModel, Field, validator
class ChatCompletionRequest(BaseModel):
"""LLM chat request with DoS-resistant limits."""
messages: list = Field(
...,
min_items=1,
max_items=50, # Limit conversation depth
description="Chat messages"
)
max_tokens: int = Field(
default=512,
le=4096, # Never allow >4K output tokens
description="Max tokens to generate"
)
system_prompt: str = Field(
default="",
max_length=5000,
description="System prompt (capped at 5K chars)"
)
@validator('messages')
def validate_messages(cls, v):
"""Validate message format and total token budget."""
total_chars = 0
for i, msg in enumerate(v):
if not isinstance(msg, dict):
raise ValueError(f"Message {i} must be a dict")
if 'content' not in msg:
raise ValueError(f"Message {i} missing 'content' field")
content = msg['content']
if not isinstance(content, str):
raise ValueError(f"Message {i} content must be string")
# Limit individual message size
if len(content) > 10000:
raise ValueError(f"Message {i} exceeds 10K character limit")
total_chars += len(content)
# Cap total conversation size
if total_chars > 100000:
raise ValueError("Total conversation exceeds 100K character limit")
return v
class CompletionResponse(BaseModel):
"""LLM response."""
content: str
finish_reason: str
tokens_used: intTest it:
# Test valid request
valid_req = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello, how are you?"}],
max_tokens=100
)
print(f"Valid request: {valid_req}")
# Test invalid request (too long)
try:
invalid_req = ChatCompletionRequest(
messages=[{"role": "user", "content": "x" * 20000}],
max_tokens=100
)
except ValueError as e:
print(f"Caught validation error: {e}")
# Output: Caught validation error: Message 0 exceeds 10K character limitExpected output:
Valid request: messages=[{'role': 'user', 'content': 'Hello, how are you?'}] max_tokens=100 system_prompt=''
Caught validation error: Message 0 exceeds 10K character limit
Injection Attack Prevention
LLM APIs are vulnerable to prompt injection: an attacker embeds malicious instructions in input that overrides your system prompt.
# injection_prevention.py
import re
from pydantic import validator, BaseModel
class SafeLLMRequest(BaseModel):
"""LLM request with injection prevention."""
user_input: str
@validator('user_input')
def prevent_injection(cls, v):
"""
Detect and block common prompt injection patterns.
This is a basic example; production systems need more sophisticated detection.
"""
# Pattern 1: Attempts to re-prompt the model
injection_patterns = [
r'ignore.*previous.*instructions?',
r'forget.*prompt',
r'system.*message',
r'you.*are.*actually',
r'follow.*new.*instructions?',
r'act.*as.*if',
]
v_lower = v.lower()
for pattern in injection_patterns:
if re.search(pattern, v_lower):
raise ValueError(f"Input contains suspected injection pattern: {pattern}")
# Pattern 2: Look for base64-encoded payloads (obscured injection)
if re.search(r'[A-Za-z0-9+/]{100,}={0,2}', v):
raise ValueError("Input may contain base64-encoded payload")
return vReal production systems use:
- Instruction tuning to make models resist injection
- Prompt templating that separates user input from instructions
- Separate verification models that detect injection attempts
- User intent classification before passing to the main model
Part 3: Rate Limiting Strategies
Why Generic Rate Limiting Fails for ML
A standard rate limiter (e.g., "100 requests per minute") treats all requests equally. But your ML API isn't equal: a /lightweight-classify request takes 10ms, while a /large-model-generate request takes 30 seconds and consumes 1 GPU.
Smart rate limiting accounts for compute cost.
Per-User and Per-API-Key Rate Limiting
Kong API Gateway makes this easy:
# kong-rate-limit.yaml: Configure per-user and per-key limits
apiVersion: configuration.konghq.com/v1
kind: KongPlugin
metadata:
name: ml-api-rate-limit
spec:
plugin: rate-limiting
config:
minute: 1000 # 1000 requests per minute (overall)
policy: redis # Use Redis for distributed counting
fault_tolerant: true # Fail open (allow) if Redis is down
redis_host: redis.redis.svc.cluster.local
redis_port: 6379
---
apiVersion: configuration.konghq.com/v1
kind: KongPlugin
metadata:
name: ml-api-per-key-limit
spec:
plugin: rate-limiting
config:
minute: 100 # 100 requests per minute per API key
header_name: "X-API-Key" # Use API key header for identity
policy: redis
redis_host: redis.redis.svc.cluster.local
redis_port: 6379
---
apiVersion: configuration.konghq.com/v1
kind: KongConsumer
metadata:
name: high-tier-user
spec:
username: high-tier-user
---
apiVersion: configuration.konghq.com/v1
kind: KongConsumerCredential
metadata:
name: high-tier-api-key
consumerRef: high-tier-user
spec:
type: key-auth
config:
key: sk_prod_abc123xyz789Now when requests come in:
# High-tier user (100 requests/min)
curl -H "X-API-Key: sk_prod_abc123xyz789" https://api.models.yourcompany.com/predict
# Regular user (automatic rate limit applies)
curl https://api.models.yourcompany.com/predictKong's Redis backend counts requests per key and returns 429 Too Many Requests when limits are exceeded.
Token Bucket vs Sliding Window Algorithms
Token Bucket: Clients get N tokens per time window. Each request costs 1 token. Allows brief bursts.
# token_bucket.py: Simple token bucket rate limiter
import time
from typing import Dict
class TokenBucket:
"""Token bucket rate limiter."""
def __init__(self, capacity: int, refill_rate: float):
"""
Args:
capacity: Max tokens in bucket
refill_rate: Tokens added per second
"""
self.capacity = capacity
self.refill_rate = refill_rate
self.tokens = float(capacity)
self.last_refill = time.time()
def allow_request(self, tokens_needed: int = 1) -> bool:
"""Check if request is allowed, consume tokens if yes."""
now = time.time()
elapsed = now - self.last_refill
# Refill tokens based on elapsed time
self.tokens = min(
self.capacity,
self.tokens + elapsed * self.refill_rate
)
self.last_refill = now
# Check if enough tokens
if self.tokens >= tokens_needed:
self.tokens -= tokens_needed
return True
return False
# Usage: 100 requests per 60 seconds (refill_rate = 100/60 = 1.67 tokens/sec)
limiter = TokenBucket(capacity=100, refill_rate=100/60)
# Simulate requests
for i in range(150):
allowed = limiter.allow_request()
if not allowed:
print(f"Request {i}: RATE LIMITED")
else:
print(f"Request {i}: ALLOWED")Sliding Window: Count requests in the last N seconds. Simpler but less burst-friendly.
# sliding_window.py: Sliding window rate limiter
import time
from collections import deque
class SlidingWindowLimiter:
"""Sliding window rate limiter."""
def __init__(self, max_requests: int, window_seconds: int):
"""
Args:
max_requests: Max requests in window
window_seconds: Window size in seconds
"""
self.max_requests = max_requests
self.window_seconds = window_seconds
self.requests = deque() # Timestamps of requests
def allow_request(self) -> bool:
"""Check if request is allowed."""
now = time.time()
# Remove old requests outside the window
while self.requests and self.requests[0] < now - self.window_seconds:
self.requests.popleft()
# Check if under limit
if len(self.requests) < self.max_requests:
self.requests.append(now)
return True
return False
# Usage: 100 requests per 60 seconds
limiter = SlidingWindowLimiter(max_requests=100, window_seconds=60)
# Simulate requests at varying rates
import time
for i in range(120):
allowed = limiter.allow_request()
print(f"Request {i}: {'ALLOWED' if allowed else 'RATE LIMITED'}")
if i == 50:
time.sleep(10) # Simulate pauseToken Bucket is better for ML: It allows brief bursts (useful for load spikes) while enforcing long-term limits. Sliding Window is stricter and better for preventing abuse patterns.
GPU-Aware Rate Limits Based on Compute Cost
Different endpoints consume different compute. A token bucket with variable costs:
# compute_aware_limiter.py: Rate limit based on GPU cost
from dataclasses import dataclass
from enum import Enum
import time
class EndpointCost(Enum):
"""GPU compute cost per request."""
LIGHTWEIGHT = 1 # Classify: ~10ms, low GPU usage
STANDARD = 5 # Embed: ~100ms, moderate GPU usage
HEAVY = 50 # Generate: ~3s, high GPU usage
VERY_HEAVY = 100 # Finetune step: ~10s, very high GPU usage
@dataclass
class RateLimitConfig:
"""Config for GPU-aware rate limiting."""
total_budget: int # Total "cost budget" per minute (e.g., 10,000)
refill_rate: float # Budget refilled per second
class GPUAwareRateLimiter:
"""Rate limit based on compute cost."""
def __init__(self, config: RateLimitConfig):
self.total_budget = config.total_budget
self.refill_rate = config.refill_rate
self.available_budget = float(config.total_budget)
self.last_refill = time.time()
def allow_request(self, cost: EndpointCost) -> bool:
"""Check if request is allowed given its cost."""
now = time.time()
elapsed = now - self.last_refill
# Refill budget
self.available_budget = min(
self.total_budget,
self.available_budget + elapsed * self.refill_rate
)
self.last_refill = now
# Check if enough budget
cost_value = cost.value
if self.available_budget >= cost_value:
self.available_budget -= cost_value
return True
return False
def budget_remaining(self) -> float:
"""Get remaining budget."""
now = time.time()
elapsed = now - self.last_refill
return min(
self.total_budget,
self.available_budget + elapsed * self.refill_rate
)
# Configure: 10,000 cost units per minute = 166.67 per second
config = RateLimitConfig(total_budget=10000, refill_rate=10000/60)
limiter = GPUAwareRateLimiter(config)
# Simulate requests
requests = [
("lightweight-classify", EndpointCost.LIGHTWEIGHT),
("lightweight-classify", EndpointCost.LIGHTWEIGHT),
("embed", EndpointCost.STANDARD),
("generate", EndpointCost.HEAVY),
("generate", EndpointCost.HEAVY),
("lightweight-classify", EndpointCost.LIGHTWEIGHT),
]
for endpoint, cost in requests:
allowed = limiter.allow_request(cost)
remaining = limiter.budget_remaining()
status = "ALLOWED" if allowed else "RATE LIMITED"
print(f"{endpoint} (cost={cost.value}): {status}, budget_remaining={remaining:.1f}")Expected output:
lightweight-classify (cost=1): ALLOWED, budget_remaining=10008.6
lightweight-classify (cost=1): ALLOWED, budget_remaining=10008.6
embed (cost=5): ALLOWED, budget_remaining=10008.6
generate (cost=50): ALLOWED, budget_remaining=9958.6
generate (cost=50): ALLOWED, budget_remaining=9908.6
lightweight-classify (cost=1): ALLOWED, budget_remaining=9907.6
Part 4: WAF Rules for ML Endpoints
WAF Basics
A Web Application Firewall (WAF) inspects HTTP requests and blocks malicious patterns before they reach your application. For ML APIs, you need custom rules beyond)) the standard OWASP ruleset.
OWASP Core Rule Set Customization
ModSecurity is the standard WAF. Here's a basic setup with custom ML-specific rules:
# waf-config.yaml: ModSecurity configuration
apiVersion: v1
kind: ConfigMap
metadata:
name: modsecurity-config
data:
modsecurity.conf: |
SecRuleEngine On
SecRequestBodyLimit 10485760 # 10MB limit
SecRequestBodyNoAction Off
# Include OWASP CRS (Core Rule Set)
Include /etc/modsecurity/rules/crs-setup.conf
Include /etc/modsecurity/rules/rules.conf
# Custom ML API rules
# Rule 1: Block requests larger than 50MB (prevent OOM)
SecRule REQUEST_HEADERS:Content-Length "@gt 52428800" \
"id:1000,phase:2,deny,status:413,msg:'Payload too large for ML API'"
# Rule 2: Block requests to paths that don't exist
SecRule REQUEST_URI "!@rx ^/(predict|embed|generate|health|metrics)" \
"id:1001,phase:2,deny,status:404,msg:'Invalid endpoint'"
# Rule 3: Rate limit by IP (100 requests per 10 seconds)
SecRule IP:@collection "!@inspectFile /tmp/ip-ratelimit.json" \
"id:1002,phase:1,nolog,pass"
# Rule 4: Block payloads with suspicious base64 patterns
SecRule ARGS "@rx ^[A-Za-z0-9+/]{10000,}" \
"id:1003,phase:2,deny,status:400,msg:'Suspiciously large base64 payload'"
# Rule 5: Enforce JSON Content-Type for POST /predict
SecRule REQUEST_URI "@eq /predict" \
"chain,id:1004,phase:1,deny,status:400,msg:'POST to /predict requires JSON'"
SecRule REQUEST_HEADERS:Content-Type "!@contains application/json" \
"nolog"Deploy with Nginx:
# nginx-waf-ingress.yaml
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: ml-api-waf
annotations:
nginx.ingress.kubernetes.io/enable-modsecurity: "true"
nginx.ingress.kubernetes.io/enable-owasp-core-rules: "true"
nginx.ingress.kubernetes.io/modsecurity-snippet: |
SecRuleEngine On
SecRequestBodyLimit 52428800
spec:
ingressClassName: nginx
rules:
- host: api.models.yourcompany.com
http:
paths:
- path: /predict
pathType: Prefix
backend:
service:
name: ml-inference
port:
number: 8000Blocking Oversized Payloads
OOM attacks send huge payloads to crash your inference server-inference-server-multi-model-serving). WAF rules catch this at the edge:
# modsec-oom-protection.rules
# Block POST requests with body >50MB
SecRule ARGS "@gt 52428800" \
"id:2000,phase:1,deny,status:413,msg:'Payload exceeds 50MB limit'"
# Block requests claiming massively large bodies (Content-Length header)
SecRule REQUEST_HEADERS:Content-Length "@gt 52428800" \
"id:2001,phase:1,deny,status:413,msg:'Content-Length exceeds 50MB'"
# Block JSON payloads with deep nesting (can cause parser OOM)
SecRule ARGS:json "@rx ^[\s{]+{[\s{]+{[\s{]+{[\s{]+{[\s{]+{[\s{]+{[\s{]+{" \
"id:2002,phase:2,deny,status:400,msg:'Deeply nested JSON detected'"
# Block requests claiming unreasonable array sizes
SecRule ARGS:array_length "@gt 1000000" \
"id:2003,phase:2,deny,status:400,msg:'Array size exceeds 1M elements'"CloudFront and Cloudflare WAF for Public APIs
For public-facing ML APIs, use managed WAF services:
# terraform-cloudfront-waf.tf
resource "aws_cloudfront_distribution" "ml_api" {
origin {
domain_name = "api.models.yourcompany.com"
origin_id = "ml_api_origin"
}
enabled = true
default_cache_behavior {
allowed_methods = ["GET", "HEAD", "POST", "PUT", "DELETE"]
cached_methods = ["GET", "HEAD"]
target_origin_id = "ml_api_origin"
forwarded_values {
query_string = true
headers = ["Authorization", "X-API-Key"]
cookies {
forward = "all"
}
}
viewer_protocol_policy = "redirect-to-https"
min_ttl = 0
default_ttl = 0
max_ttl = 0
}
# Attach AWS WAF
web_acl_id = aws_wafv2_web_acl.ml_api_waf.arn
restrictions {
geo_restriction {
restriction_type = "none"
}
}
viewer_certificate {
cloudfront_default_certificate = true
}
}
resource "aws_wafv2_web_acl" "ml_api_waf" {
name = "ml-api-waf"
scope = "CLOUDFRONT"
default_action {
allow {}
}
rule {
name = "RateLimitRule"
priority = 0
action {
block {}
}
statement {
rate_based_statement {
limit = 2000
aggregate_key_type = "IP"
}
}
visibility_config {
cloudwatch_metrics_enabled = true
metric_name = "RateLimitRule"
sampled_requests_enabled = true
}
}
rule {
name = "AWSManagedRulesCommonRuleSet"
priority = 1
override_action {
none {}
}
statement {
managed_rule_group_statement {
name = "AWSManagedRulesCommonRuleSet"
vendor_name = "AWS"
}
}
visibility_config {
cloudwatch_metrics_enabled = true
metric_name = "AWSManagedRulesCommonRuleSet"
sampled_requests_enabled = true
}
}
visibility_config {
cloudwatch_metrics_enabled = true
metric_name = "ml-api-waf"
sampled_requests_enabled = true
}
}Part 5: DoS Protection and Circuit Breakers
Request Queuing with Backpressure
When demand exceeds capacity, queue requests and reject excess:
# backpressure_handler.py: Queue with backpressure
import asyncio
from collections import deque
from dataclasses import dataclass
from typing import Optional
@dataclass
class QueuedRequest:
"""A queued request."""
request_id: str
payload: dict
timestamp: float
class RequestQueue:
"""Queue with backpressure and timeout."""
def __init__(self, max_queue_size: int = 1000, timeout_sec: float = 30.0):
self.max_queue_size = max_queue_size
self.timeout_sec = timeout_sec
self.queue = deque()
self.processing = False
def enqueue(self, request_id: str, payload: dict) -> tuple[bool, Optional[str]]:
"""
Try to queue a request.
Returns: (success, error_message)
"""
import time
if len(self.queue) >= self.max_queue_size:
return False, f"Queue full ({self.max_queue_size} items). Try again later."
self.queue.append(QueuedRequest(
request_id=request_id,
payload=payload,
timestamp=time.time()
))
return True, None
def dequeue(self) -> Optional[QueuedRequest]:
"""Get next request, removing old ones."""
import time
now = time.time()
# Remove timed-out requests
while self.queue and self.queue[0].timestamp < now - self.timeout_sec:
self.queue.popleft()
if self.queue:
return self.queue.popleft()
return None
def queue_size(self) -> int:
return len(self.queue)
# FastAPI integration
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
app = FastAPI()
request_queue = RequestQueue(max_queue_size=500)
@app.post("/predict-queued")
async def predict_queued(request_id: str, payload: dict):
"""Predict with backpressure queuing."""
success, error = request_queue.enqueue(request_id, payload)
if not success:
return JSONResponse(
status_code=429,
content={"error": error, "retry_after": 5}
)
# Process from queue
queued_req = request_queue.dequeue()
if not queued_req:
return JSONResponse(
status_code=503,
content={"error": "Processing unavailable"}
)
# Call model
prediction = model.predict(queued_req.payload)
return {"prediction": prediction, "request_id": queued_req.request_id}Circuit Breakers with Istio/Envoy
A circuit breaker stops sending requests to a failing service, letting it recover:
# istio-circuit-breaker.yaml
apiVersion: networking.istio.io/v1beta1
kind: DestinationRule
metadata:
name: ml-model-circuit-breaker
spec:
host: ml-inference.ml-platform.svc.cluster.local
trafficPolicy:
connectionPool:
tcp:
maxConnections: 100
http:
http1MaxPendingRequests: 100
maxRequestsPerConnection: 2
h2UpgradePolicy: UPGRADE
outlierDetection:
consecutive5xxErrors: 5 # Trip after 5 consecutive errors
interval: 30s
baseEjectionTime: 30s
maxEjectionPercent: 50 # At most 50% of instances ejected
minRequestVolume: 5 # Require at least 5 requests before detection
splitExternalLocalOriginErrors: trueWhen deployed:
- Envoy tracks error rates for each backend instance
- If an instance has 5 consecutive errors, it's ejected (not sent traffic)
- After 30 seconds, Envoy tries to bring it back
- If more than 50% of instances are ejected, none are ejected (fail-safe)
Abuse Detection Using Request Pattern Analysis
Detect and block attackers based on request patterns:
# abuse_detector.py: Pattern-based abuse detection
from collections import defaultdict
from dataclasses import dataclass
import time
@dataclass
class RequestPattern:
"""Pattern of requests from an IP."""
request_count: int
error_count: int
last_request_time: float
payload_sizes: list[int]
class AbuseDetector:
"""Detect abusive request patterns."""
def __init__(self, window_seconds: int = 60):
self.window_seconds = window_seconds
self.patterns = defaultdict(RequestPattern)
def record_request(
self,
client_ip: str,
success: bool,
payload_size: int
) -> bool:
"""
Record a request. Return True if request is OK, False if abuse detected.
"""
now = time.time()
pattern = self.patterns[client_ip]
# Initialize if new
if pattern.request_count == 0:
pattern.request_count = 0
pattern.error_count = 0
pattern.last_request_time = now
pattern.payload_sizes = []
# Update counts
pattern.request_count += 1
if not success:
pattern.error_count += 1
pattern.payload_sizes.append(payload_size)
pattern.last_request_time = now
# Check for abuse patterns
# Pattern 1: Too many requests per second
if pattern.request_count > 100:
return False, "Rate limit exceeded"
# Pattern 2: High error rate (>50% errors)
if pattern.request_count >= 10:
error_rate = pattern.error_count / pattern.request_count
if error_rate > 0.5:
return False, "High error rate detected"
# Pattern 3: Payload sizes increasing (scanning for limits)
if len(pattern.payload_sizes) >= 5:
recent_sizes = pattern.payload_sizes[-5:]
if all(recent_sizes[i] < recent_sizes[i+1] for i in range(4)):
return False, "Suspicious payload size escalation"
# Pattern 4: Requests at regular intervals (scanning tool)
if pattern.request_count >= 3:
intervals = [
pattern.payload_sizes[i+1] - pattern.payload_sizes[i]
for i in range(len(pattern.payload_sizes) - 1)
]
if intervals and max(intervals) - min(intervals) < 10: # Very regular
return False, "Suspicious regular request pattern"
return True, None
# FastAPI middleware
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
app = FastAPI()
detector = AbuseDetector()
class AbuseDetectionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
client_ip = request.client.host
# Read body to get size
body = await request.body()
# Call model (for now, assume success)
is_ok, error = detector.record_request(
client_ip=client_ip,
success=True,
payload_size=len(body)
)
if not is_ok:
return JSONResponse(
status_code=429,
content={"error": f"Abuse detected: {error}"}
)
response = await call_next(request)
return response
app.add_middleware(AbuseDetectionMiddleware)
@app.post("/predict")
async def predict(request_data: dict):
return {"prediction": 0.5}Test the abuse detector:
detector = AbuseDetector()
# Simulate normal requests
print("=== Normal requests ===")
for i in range(5):
ok, _ = detector.record_request(
client_ip="192.168.1.1",
success=True,
payload_size=1024
)
print(f"Request {i}: {'OK' if ok else 'BLOCKED'}")
# Simulate scanning attack (escalating payload sizes)
print("\n=== Scanning attack (escalating payloads) ===")
for i in range(6):
ok, error = detector.record_request(
client_ip="192.168.1.2",
success=True,
payload_size=1000 + i * 5000
)
print(f"Request {i}: {'OK' if ok else 'BLOCKED'} {error}")
# Simulate error-rate attack
print("\n=== Error-rate attack ===")
for i in range(15):
ok, error = detector.record_request(
client_ip="192.168.1.3",
success=(i % 2 == 0), # 50% error rate
payload_size=1024
)
print(f"Request {i}: {'OK' if ok else 'BLOCKED'} {error}")Expected output:
=== Normal requests ===
Request 0: OK
Request 1: OK
Request 2: OK
Request 3: OK
Request 4: OK
=== Scanning attack (escalating payloads) ===
Request 0: OK None
Request 1: OK None
Request 2: OK None
Request 3: OK None
Request 4: OK None
Request 5: BLOCKED Suspicious payload size escalation
=== Error-rate attack ===
Request 0: OK None
Request 1: OK None
...
Request 10: BLOCKED High error rate detected
Architecture Diagram
Here's how all these layers work together:
graph TB
Client["Client / User"]
CloudFront["CloudFront / Cloudflare WAF"]
WAF["ModSecurity WAF<br/>- OOM protection<br/>- Injection blocking<br/>- Oversized payload blocking"]
Kong["Kong API Gateway<br/>- Per-key rate limiting<br/>- GPU-aware cost budgets<br/>- Request queuing"]
Nginx["Nginx Ingress<br/>- TLS 1.3 termination<br/>- HSTS headers<br/>- Connection pooling"]
Istio["Istio Service Mesh<br/>- mTLS<br/>- Circuit breaker<br/>- Outlier detection"]
AppServer["FastAPI Server<br/>- Pydantic v2 validation<br/>- Tensor shape/dtype checks<br/>- String length limits"]
Model["ML Model<br/>- GPU inference<br/>- Response generation"]
Metrics["Prometheus<br/>- Request latency<br/>- Error rates<br/>- Queue depth<br/>- GPU utilization"]
Client -->|HTTPS Request| CloudFront
CloudFront -->|Filter malicious| WAF
WAF -->|Valid traffic| Kong
Kong -->|Check rate limits<br/>Queue if needed| Nginx
Nginx -->|Establish mTLS| Istio
Istio -->|Route + monitor| AppServer
AppServer -->|Validate input<br/>Prevent injection| Model
Model -->|Prediction| AppServer
AppServer -->|Response| Istio
Istio -->|Monitor| Metrics
Istio -->|Circuit break if<br/>error rate high| AppServerPutting It All Together: Complete Example
Here's a production-ready ML serving setup:
# complete_secure_api.py
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator
import numpy as np
import time
from typing import Optional
# ============================================================================
# 1. VALIDATION LAYER (Pydantic)
# ============================================================================
class PredictRequest(BaseModel):
"""Validated prediction request."""
features: list[float] = Field(..., description="Input features")
request_id: Optional[str] = Field(default=None, max_length=50)
@validator('features')
def validate_features(cls, v):
if len(v) != 10:
raise ValueError("Expected 10 features")
arr = np.array(v, dtype=np.float32)
if np.isnan(arr).any() or np.isinf(arr).any():
raise ValueError("Features contain NaN or inf")
return v
class PredictResponse(BaseModel):
"""Validated response."""
prediction: float
confidence: float
request_id: Optional[str] = None
# ============================================================================
# 2. RATE LIMITING LAYER
# ============================================================================
class GPUAwareRateLimiter:
"""Rate limit by compute cost."""
def __init__(self, budget_per_minute: int = 10000):
self.budget_per_minute = budget_per_minute
self.available_budget = float(budget_per_minute)
self.last_refill = time.time()
def allow_request(self, cost: int) -> bool:
now = time.time()
elapsed = now - self.last_refill
self.available_budget = min(
self.budget_per_minute,
self.available_budget + elapsed * (self.budget_per_minute / 60)
)
self.last_refill = now
if self.available_budget >= cost:
self.available_budget -= cost
return True
return False
limiter = GPUAwareRateLimiter(budget_per_minute=10000)
# ============================================================================
# 3. ABUSE DETECTION LAYER
# ============================================================================
from collections import defaultdict
class AbuseDetector:
"""Detect abusive patterns."""
def __init__(self):
self.patterns = defaultdict(lambda: {
'count': 0,
'errors': 0,
'sizes': []
})
def check(self, ip: str, success: bool, payload_size: int) -> bool:
pattern = self.patterns[ip]
pattern['count'] += 1
if not success:
pattern['errors'] += 1
pattern['sizes'].append(payload_size)
# Block if too many requests
if pattern['count'] > 100:
return False
# Block if high error rate
if pattern['count'] >= 10 and pattern['errors'] / pattern['count'] > 0.5:
return False
return True
detector = AbuseDetector()
# ============================================================================
# 4. APP SETUP
# ============================================================================
app = FastAPI(title="Secure ML API")
# Mock model (replace with your actual model)
def predict_model(features: np.ndarray) -> tuple[float, float]:
"""Simple mock model."""
prediction = float(np.mean(features))
confidence = float(np.random.random())
return prediction, confidence
# ============================================================================
# 5. ENDPOINTS WITH LAYERED SECURITY
# ============================================================================
@app.post("/predict", response_model=PredictResponse)
async def predict(request_data: PredictRequest, request: Request):
"""
Secure prediction endpoint with:
- TLS encryption (handled by Nginx/cert-manager)
- Input validation (Pydantic)
- Rate limiting (GPU-aware)
- Abuse detection (pattern analysis)
- mTLS (handled by Istio)
"""
client_ip = request.client.host
# 1. Abuse detection
if not detector.check(client_ip, True, len(str(request_data.features))):
raise HTTPException(status_code=429, detail="Abuse detected")
# 2. Rate limiting (cost = 5 for standard prediction)
if not limiter.allow_request(cost=5):
raise HTTPException(status_code=429, detail="Rate limit exceeded")
# 3. Input validation already done by Pydantic
# 4. Call model
features_arr = np.array(request_data.features, dtype=np.float32).reshape(1, -1)
prediction, confidence = predict_model(features_arr)
# 5. Return validated response
return PredictResponse(
prediction=prediction,
confidence=confidence,
request_id=request_data.request_id
)
@app.get("/health")
async def health():
"""Health check endpoint (lightweight, cost=0)."""
return {"status": "ok"}
# ============================================================================
# 6. ERROR HANDLERS
# ============================================================================
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content={"error": exc.detail}
)
# ============================================================================
# 7. RUN
# ============================================================================
if __name__ == "__main__":
import uvicorn
# In production, use Gunicorn + Uvicorn workers
# gunicorn complete_secure_api:app --workers 4 --worker-class uvicorn.workers.UvicornWorker
print("Starting Secure ML API...")
print("TLS: cert-manager (automatic)")
print("Input validation: Pydantic v2")
print("Rate limiting: GPU-aware token bucket")
print("Abuse detection: Pattern analysis")
print("mTLS: Istio sidecar injection")
uvicorn.run(app, host="0.0.0.0", port=8000)Test it:
# Install dependencies
pip install fastapi uvicorn pydantic numpy
# Run the API
python complete_secure_api.py
# In another terminal:
# Test 1: Valid request
curl -X POST http://localhost:8000/predict \
-H "Content-Type: application/json" \
-d '{"features": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], "request_id": "req_001"}'
# Test 2: Invalid features (wrong shape)
curl -X POST http://localhost:8000/predict \
-H "Content-Type: application/json" \
-d '{"features": [1.0, 2.0]}'
# Test 3: Health check
curl http://localhost:8000/healthExpected output:
{
"prediction": 5.5,
"confidence": 0.723,
"request_id": "req_001"
}
{
"error": "1 validation error for PredictRequest\nfeatures\n Value error, Expected 10 features [type=value_error, input_value=[1.0, 2.0], input_type=list]"
}
{
"status": "ok"
}Summary: The Security Stack
You now have a comprehensive security strategy for ML serving:
- TLS/mTLS: Automatic certificates via cert-manager, mutual TLS for internal services, TLS 1.3 for performance
- Input Validation: Pydantic schemas enforce shape/dtype/length, custom validators block injection patterns
- Rate Limiting: GPU-aware cost budgets that account for inference complexity, per-key tracking
- WAF: OWASP rules plus custom ML-specific rules for OOM protection and payload size limits
- DoS Protection: Backpressure queuing, circuit breakers, abuse detection via request patterns
Stack these layers, and your ML API becomes resilient against the threats that hit production systems daily.