January 30, 2026
Python MLOps ONNX PyTorch

Model Serialization: Pickle, ONNX, and TorchScript

You've trained the perfect model. It generalizes beautifully, your validation metrics are stellar, and your boss is already dreaming about deployment. But there's a problem: how do you actually save it? And more importantly, how do you load it somewhere else, especially if you want blazing-fast inference?

That's where model serialization comes in. It's unsexy compared to building architectures or tuning hyperparameters, but it's absolutely critical. You can't serve a model in production if you can't serialize it. And you can't get fast inference if you don't know which serialization format fits your use case.

Most tutorials skip past serialization like it's a footnote. They show you one pickle.dump() call, wave their hands, and move on. That works in a Jupyter notebook. It fails spectacularly in production. The version you trained on vanishes into Python's environment chaos, the model you serialized runs fine on your laptop and silently produces wrong predictions on the server, and your security team has a small aneurysm when they find out you're loading pickle files from an external registry.

We're going to fix all of that. In this article, we cover the full serialization toolkit from the ground up: why each format exists, what trade-offs it makes, and exactly when to reach for it. We'll walk through Pickle and joblib for scikit-learn workflows, move into PyTorch's state_dict approach, compile models with TorchScript, and then explore ONNX, the format that lets you train in PyTorch and serve almost anywhere, often with substantially faster CPU inference as a bonus. We'll also look at compression techniques like quantization and pruning, run benchmarks to back up the claims, and finish with a production serialization pipeline you can actually use.

By the end, you'll know how to serialize responsibly: with checksums, version metadata, output parity verification, and a clear decision tree for choosing the right format. Whether you're deploying a scikit-learn classifier to a REST API or shipping a ResNet to an edge device, the patterns here translate directly. Let's get into it.



Table of Contents
  1. Why Serialization Matters
  2. Pickle and Joblib: The Scikit-Learn Approach
  3. The Naive Approach (And Why It Works... Until It Doesn't)
  4. The Security Trap
  5. The Better Path: Joblib
  6. Version Pinning: Your Lifeline
  7. The Best Practices Checklist
  8. PyTorch: state_dict vs Full Model Serialization
  9. Option 1: Save the Entire Model
  10. Option 2: Save Only Weights (state_dict), Recommended
  11. The Production Pattern
  12. TorchScript: Compiling Python to Production-Ready Code
  13. Tracing: The Easy Path
  14. Scripting: The Comprehensive Path
  15. Serialization Size and Speed
  16. ONNX: The Universal Language for Models
  17. Converting PyTorch to ONNX
  18. Loading and Running ONNX Models with ONNX Runtime
  19. ONNX Runtime CPU vs PyTorch
  20. Validating Output Parity
  21. Pickle vs ONNX vs TorchScript: Choosing the Right Tool
  22. Model Compression: Quantization and Pruning
  23. Quantization: Shrink Weights Numerically
  24. Pruning: Remove Unimportant Weights
  25. Cross-Platform Deployment
  26. Common Serialization Mistakes
  27. The Ultimate Benchmark: PyTorch vs TorchScript vs ONNX Runtime
  28. Putting It All Together: A Production Serialization Pipeline
  29. Choosing Your Format: A Decision Tree
  30. Summary

Why Serialization Matters

Before we dive into the how, let's spend some real time on the why. Serialization is the bridge between a model that exists in memory during a training run and one that can be deployed, shared, audited, and served at scale. Skip this foundation and you'll eventually build on sand.

Security is the first concern most engineers underestimate. Pickle files are not inert data blobs: they are programs for pickle's own small virtual machine, and loading one can import and call arbitrary Python functions. A malicious pickle can wipe your filesystem, exfiltrate credentials, or install backdoors. If your pipeline ever loads a model from an untrusted source (a public S3 bucket, a GitHub release, user uploads) with raw pickle, you have a critical vulnerability. This alone is reason enough to learn safer alternatives.

Portability is the second surprise. A pickled PyTorch model depends on your Python version, your library versions, and the exact import paths of your classes. Load that same file six months later after an upgrade cycle, and it may load but produce subtly different predictions, or it may crash with an opaque error message. ONNX was designed to shrink this category of failure: the model is stored as a graph of standardized operators tied to an explicit operator-set (opset) version, not as pickled Python objects.

Performance is the third motivation, and it's quantifiable. Eager PyTorch inference runs through the Python interpreter, paying per-operator dispatch overhead and, unless you disable it, autograd bookkeeping. At inference time you need none of that. TorchScript and ONNX Runtime execute a serialized graph without Python in the hot path, and ONNX Runtime applies inference-specific graph optimizations on top. The gains depend on the model, batch size, and hardware, but on CPU they can be large: in the ResNet-18 benchmark later in this article, ONNX Runtime runs about 5x faster than eager PyTorch, which matters enormously when you're paying per CPU cycle.

Framework independence compounds the portability argument. The ML ecosystem is fragmented, and that's not going away. Your model trained in PyTorch today may need to run in TensorFlow Serving tomorrow, in CoreML for an iOS app next month, or in TensorRT on an NVIDIA accelerator after that. ONNX is a shared intermediate representation that makes many of those handoffs possible without retraining, either directly or through converters. Think of it as the assembly language of neural networks.

Dependency tracking forces the rigor that protects you. When you joblib-serialize a scikit-learn model trained on sklearn 1.0, you must record that version. Can you load it with sklearn 1.5? Sometimes yes, sometimes silently no. ONNX makes dependencies explicit at the format level (every model declares the opset versions it requires), which turns what was a runtime surprise into an up-front compatibility check. These aren't edge cases; they're daily realities in any team running ML in production.


Pickle and Joblib: The Scikit-Learn Approach

Let's start with the most common path: you've trained a scikit-learn model (RandomForest, SVM, whatever), and you need to save it.

The Naive Approach (And Why It Works... Until It Doesn't)

The simplest possible serialization in Python is a single pickle.dump() call. You've probably seen this pattern in tutorials, and there's a reason it's everywhere: it genuinely works for local development and throwaway experiments. The problem is that "works for experiments" and "safe for production" are two different bars, and pickle only clears the first one. Here's the naive approach so you know exactly what you're dealing with:

python
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
 
# Train a model
X, y = load_iris(return_X_y=True)
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X, y)
 
# Save it (simple!)
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)
 
# Load it (also simple!)
with open('model.pkl', 'rb') as f:
    model = pickle.load(f)
 
# Make predictions
predictions = model.predict(X[:5])
print(predictions)

This works. It's convenient. And it's how a large share of models end up saved, production models included. But there are hidden costs.

Those hidden costs compound over time. The first six months everything looks fine. Then you upgrade sklearn for another project, your model starts behaving oddly, and you spend two days debugging a numerical discrepancy that was actually a silent API change in the library. The pattern above has no guard rails against any of this.

The Security Trap

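A pickle is not passive data. It is a small program for pickle's stack-based virtual machine, and that program can import any importable Python callable and call it while loading. If you load a pickle from an untrusted source, say, a model from GitHub or a user upload, you're running arbitrary code.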

python
# Never do this with untrusted input!
malicious_pickle = """
cos
system
(S'rm -rf /'
tR.
"""
 
# If you unpickle this, your filesystem burns.

The fix? Never unpickle untrusted data. And if you must load models from external sources, use ONNX or a text-based format instead.

The security issue here is not theoretical: it has been demonstrated in multiple real-world incidents involving ML model sharing platforms. If your team runs any kind of model hub or allows external model uploads, this is a live vulnerability worth taking seriously. Note that joblib is built on top of pickle, so joblib.load on an untrusted file is exactly as dangerous as pickle.load. The safe path is to define a policy: pickle and joblib files are only ever loaded if your own pipeline produced them, and anything from outside goes through a format that doesn't execute code on load, such as ONNX.
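To make the danger concrete without doing anything destructive, here's a minimal sketch. The hypothetical Payload class uses __reduce__ to tell pickle "rebuild me by calling eval on this string". The harmless "6 * 7" expression is our stand-in; a real attack would put something like an os.system call there instead.

```python
import pickle
import pickletools


class Payload:
    """Illustrative stand-in for a malicious object (harmless payload)."""

    def __reduce__(self):
        # Tells pickle: "to rebuild this object, call eval('6 * 7')".
        # An attacker would substitute a destructive call here.
        return (eval, ("6 * 7",))


blob = pickle.dumps(Payload())

# The opcode stream references a global callable (builtins.eval) and a REDUCE step,
# which is the instruction that actually calls it during loading.
opcodes = [op.name for op, arg, pos in pickletools.genops(blob)]
print(opcodes)

loaded = pickle.loads(blob)  # eval runs right here, inside pickle.loads
print(loaded)                # the result of the call, not a Payload instance
```

Scanning for opcodes like REDUCE is not a defense on its own: legitimate scikit-learn and numpy pickles use the same opcodes to rebuild arrays and estimators.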

The Better Path: Joblib

The serializer most commonly used with scikit-learn is joblib. It uses pickle under the hood but handles large numpy arrays much more efficiently. Joblib was designed for scientific Python workloads: it writes the raw buffers of large arrays directly instead of pushing them through the generic pickling machinery, which usually makes saving and loading faster, especially for models with large internal arrays like Random Forests.

python
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
 
X, y = load_iris(return_X_y=True)
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X, y)
 
# Save with compression (smaller file, slight overhead)
joblib.dump(model, 'model.joblib', compress=3)
 
# Load
model = joblib.load('model.joblib')
predictions = model.predict(X[:5])

Joblib compresses tree structures efficiently. A RandomForest with 100 trees might drop from 5 MB (pickle) to 1 MB (joblib with compress=3), though the actual ratio depends on the model and the data it was trained on.

The compress parameter accepts values from 0 (no compression) to 9 (maximum compression). In practice, compress=3 is a good default: it cuts file size substantially with modest CPU overhead during save and load. For very large models being served in latency-sensitive contexts, you might drop to compress=1 to keep load times down.
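To see what the compress knob does on your own data, a sketch like this dumps the same object at several levels and compares file sizes. The arrays are synthetic, highly repetitive stand-ins for a model's internals (an assumption chosen so compression has something to find); real models will usually compress less dramatically.

```python
import os
import tempfile

import joblib
import numpy as np

# Synthetic, repetitive arrays standing in for tree thresholds and child indices
payload = {
    "thresholds": np.tile(np.arange(1000, dtype=np.float64), 200),
    "children": np.repeat(np.arange(-1, 999, dtype=np.int64), 200),
}

sizes = {}
with tempfile.TemporaryDirectory() as tmp:
    for level in (0, 1, 3, 9):
        path = os.path.join(tmp, f"payload_{level}.joblib")
        joblib.dump(payload, path, compress=level)
        sizes[level] = os.path.getsize(path)
        # Round-trip check: compression must never change the data
        restored = joblib.load(path)
        assert np.array_equal(restored["thresholds"], payload["thresholds"])
        assert np.array_equal(restored["children"], payload["children"])

for level, size in sizes.items():
    print(f"compress={level}: {size / 1024:.1f} KiB")
```

Timing the dump and load calls alongside the sizes is the quickest way to find the level where extra compression stops paying for itself.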

Version Pinning: Your Lifeline

The real danger with pickle/joblib: dependency drift.

python
# Saved with sklearn 1.0
model = joblib.load('model.joblib')
 
# If you're on sklearn 1.5 and the API changed...
# you might get silent errors or wrong predictions.

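Recent scikit-learn releases do warn when you unpickle an estimator saved with a different version (InconsistentVersionWarning, from 1.3 onward), but warnings are easy to miss in production logs. The solution is obsessive version pinning: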

scikit-learn==1.0.2
joblib==1.2.0
numpy==1.23.5

Every time you load a pickled model, check versions:

python
import joblib
import sklearn
 
model = joblib.load('model.joblib')
 
# In production, verify versions match training
required_sklearn = "1.0.2"
if sklearn.__version__ != required_sklearn:
    raise RuntimeError(
        f"Model requires sklearn {required_sklearn}, "
        f"but you have {sklearn.__version__}"
    )

This version check pattern looks overly cautious until the day it catches a real discrepancy and saves you from shipping broken predictions. Make it a habit: every model file should have a corresponding requirements snapshot, and every load should verify that snapshot against the current environment. Automate this in your deployment pipeline so it runs without anyone needing to remember.
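The exact-match check above only covers sklearn. Here's a minimal sketch that generalizes it to every package you recorded at training time. verify_environment is a hypothetical helper; it uses importlib.metadata, which looks packages up by distribution name ("scikit-learn", not "sklearn"). The lookup function is injectable so the example can simulate a drifted environment deterministically.

```python
from importlib import metadata


def verify_environment(recorded, get_version=metadata.version):
    """Compare versions recorded at training time against the current environment.

    recorded:    mapping like {"scikit-learn": "1.0.2", "numpy": "1.23.5"}
    get_version: lookup function; defaults to importlib.metadata.version
    Returns a list of mismatch messages (empty when everything matches).
    """
    problems = []
    for package, expected in recorded.items():
        try:
            installed = get_version(package)
        except metadata.PackageNotFoundError:
            problems.append(f"{package}: expected {expected}, but it is not installed")
            continue
        if installed != expected:
            problems.append(f"{package}: expected {expected}, found {installed}")
    return problems


# Versions recorded at training time (e.g. read back from the model's metadata)
recorded = {"scikit-learn": "1.0.2", "numpy": "1.23.5", "joblib": "1.2.0"}

# Simulated production environment so the example is deterministic
fake_env = {"scikit-learn": "1.5.0", "numpy": "1.23.5"}


def fake_lookup(name):
    if name not in fake_env:
        raise metadata.PackageNotFoundError(name)
    return fake_env[name]


issues = verify_environment(recorded, get_version=fake_lookup)
for line in issues:
    print(line)
```

In a real deployment you would call verify_environment(recorded) with the default lookup and refuse to start the service if the list is non-empty.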

The Best Practices Checklist

  1. Always use joblib for sklearn, not pickle.

  2. Pin all dependency versions in a requirements.txt.

  3. Never load pickles from untrusted sources.

  4. Document what was used to train the model:

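    python
    import datetime
    import joblib
    import numpy as np
    import sklearn
     
    # `model` is the fitted RandomForestClassifier from above
    metadata = {
        "sklearn_version": sklearn.__version__,
        "numpy_version": np.__version__,
        "trained_date": datetime.date.today().isoformat(),
        "features": ["sepal_length", "sepal_width", "petal_length", "petal_width"],
    }
     
    # Store the model and its metadata together in one file
    joblib.dump({
        "model": model,
        "metadata": metadata
    }, 'model_with_meta.joblib')

  5. Test loading on the target environment before deployment.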
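The checklist covers versions, but not whether the file on disk is the file you saved. Here's a minimal sketch of recording a SHA-256 checksum at save time and verifying it just before loading. sha256_of and verify_checksum are hypothetical helpers, and the three-byte file is a stand-in for real model bytes. A checksum proves integrity against a hash you recorded yourself; it does not make loading an untrusted pickle safe.

```python
import hashlib
import hmac
import tempfile
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 so large model files never sit in memory whole."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checksum(path, expected_hex):
    """Raise before loading if the file's hash doesn't match the recorded one."""
    actual = sha256_of(path)
    if not hmac.compare_digest(actual, expected_hex):
        raise ValueError(f"checksum mismatch for {path}: got {actual}")
    return True


with tempfile.TemporaryDirectory() as tmp:
    model_path = Path(tmp) / "model.joblib"
    model_path.write_bytes(b"abc")        # stand-in for real model bytes
    recorded = sha256_of(model_path)      # store this alongside the metadata
    ok = verify_checksum(model_path, recorded)

    model_path.write_bytes(b"abd")        # simulate corruption or tampering
    try:
        verify_checksum(model_path, recorded)
        tamper_detected = False
    except ValueError:
        tamper_detected = True

print(recorded)
print(ok, tamper_detected)
```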


PyTorch: state_dict vs Full Model Serialization

PyTorch gives you options. That's powerful. It's also confusing.

Option 1: Save the Entire Model

The full-model approach is appealing for its simplicity: one call, everything included. But it comes with strings attached that bite you in production. The file stores your model object via pickle, which records a reference to your class by its module path and name, so the file is tightly coupled to your exact Python environment and code layout.

python
import torch
import torch.nn as nn
 
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)
 
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
 
model = SimpleNet()
# Train model on your data before saving
 
# Save the entire model (architecture + weights)
torch.save(model, 'model_full.pth')
 
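# Load. PyTorch 2.6+ defaults to weights_only=True, which refuses arbitrary
# pickled objects, so loading a full model needs an explicit opt-out.
# Only do this for files you produced yourself.
model = torch.load('model_full.pth', weights_only=False)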
model.eval()

Pros: One line. Everything included.

Cons: The file is a pickle, so loading it can execute code, and it stores only a reference to your class, not its source. Loading requires the exact class definition to be importable at the same module path. If you rename, move, or significantly change the class, old models break.

The practical consequence is that full-model serialization creates tight coupling between your model files and your codebase. If you refactor SimpleNet in six months, every old .pth file becomes unloadable unless you keep the old class definition around. This is the kind of tech debt that silently grows until it suddenly becomes urgent.

Option 2: Save Only Weights (state_dict), Recommended

The state_dict approach separates concerns cleanly: the architecture lives in your code (which you version control), and the weights live in the file. This is the pattern used by serious ML teams and it's what you should default to.

python
# Save only the weights
torch.save(model.state_dict(), 'model_weights.pth')
 
# To load, you need the architecture
model = SimpleNet()  # Instantiate the same architecture
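# weights_only=True restricts unpickling to tensors and plain containers
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True))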
model.eval()

Pros: Slightly smaller file, cleaner separation, a more restricted (and safer) load path, and tolerance of code refactors as long as parameter names and shapes stay the same.

Cons: Requires keeping your model class definition in sync.

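This is the professional approach. You're not serializing your model's Python objects, just named tensors. The file still uses pickle as a container, but torch.load(..., weights_only=True) restricts what it will unpickle to tensors and basic types. That's reproducible, portable, and much safer.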
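The parameter names in a state_dict are the contract between the file and your code. Here's a minimal sketch showing both sides of that contract: it reuses the SimpleNet architecture from above, with the hidden width turned into a constructor argument (our change, so we can simulate a refactor), and shows that a shape change makes load_state_dict fail loudly rather than load silently.

```python
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    """Same architecture as above; `hidden` is added so we can simulate a refactor."""

    def __init__(self, hidden=128):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden)
        self.fc2 = nn.Linear(hidden, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        return self.fc2(torch.relu(self.fc1(x)))


trained = SimpleNet()
state = trained.state_dict()
print(list(state.keys()))  # these names must exist in whatever class loads the file

# Same architecture: strict loading (the default) succeeds and copies every tensor
restored = SimpleNet()
restored.load_state_dict(state)
same = all(torch.equal(a, b) for a, b in zip(trained.parameters(), restored.parameters()))

# Hypothetical refactor that widens the hidden layer: the shapes no longer match,
# so load_state_dict raises instead of loading mismatched weights
try:
    SimpleNet(hidden=256).load_state_dict(state)
    mismatch_caught = False
except RuntimeError as err:
    mismatch_caught = True
    print(str(err).splitlines()[0])

print(same, mismatch_caught)
```

Passing strict=False tolerates missing or unexpected keys, but it does not rescue shape mismatches, so it's a tool for deliberate partial loading, not a way around refactors.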

The Production Pattern

For production workflows, you want more than just weights. You want optimizer state (so you can resume training), epoch tracking, the latest loss, and environment metadata, all in one coherent checkpoint. Here's a pattern that works well in practice:

python
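import torch
 
class ModelCheckpoint:
    """Bundle weights, optimizer state, and metadata into one checkpoint file."""
    def __init__(self, model, optimizer, path):
        self.model = model
        self.optimizer = optimizer
        self.path = path
 
    def save(self, epoch, loss):
        checkpoint = {
            "epoch": epoch,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "loss": loss,
            # str() stores a plain string rather than torch's version object
            "pytorch_version": str(torch.__version__),
        }
        torch.save(checkpoint, self.path)
 
    def load(self, map_location="cpu"):
        # Only tensors and plain Python types inside, so the restricted
        # weights_only loader is enough
        checkpoint = torch.load(self.path, map_location=map_location, weights_only=True)
        self.model.load_state_dict(checkpoint["model_state"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state"])
        return checkpoint["epoch"], checkpoint["loss"]
 
# Usage (`model` is the SimpleNet instance from above)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
checkpoint = ModelCheckpoint(model, optimizer, 'checkpoint.pth')
checkpoint.save(epoch=10, loss=0.045)
 
# Later, resume from checkpoint
epoch, loss = checkpoint.load()
print(f"Resumed from epoch {epoch}, loss {loss}")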

The checkpoint bundles metadata, optimizer state (for resuming training), and everything else needed to pick up where you left off.

The optimizer state is particularly important when you're training large models over days or weeks. Without it, resuming after a hardware failure means restarting your optimizer's momentum and adaptive learning-rate statistics from zero, which can knock the run off course. If you use a learning-rate scheduler, save its state_dict() too: the position in the schedule lives there, not in the optimizer. With both, you pick up exactly where you left off.
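To make "momentum and adaptive statistics" concrete, here's a small sketch that inspects what Adam and a StepLR scheduler actually store after one training step. The tiny nn.Linear model and the hyperparameters are arbitrary stand-ins.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)  # tiny stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# One training step so Adam accumulates per-parameter state
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
scheduler.step()

opt_state = optimizer.state_dict()
weight_state = opt_state["state"][0]            # state for the first parameter (the weight)
print(sorted(weight_state.keys()))              # first/second moment estimates and a step count
print(scheduler.state_dict()["last_epoch"])     # position in the learning-rate schedule

# A fresh optimizer that loads this state resumes with the same moment estimates
resumed = torch.optim.Adam(model.parameters(), lr=1e-3)
resumed.load_state_dict(opt_state)
same_moments = torch.equal(resumed.state_dict()["state"][0]["exp_avg"], weight_state["exp_avg"])
print(same_moments)
```

Note that the scheduler's progress is not part of the optimizer's state_dict, which is why it needs its own entry in the checkpoint.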


TorchScript: Compiling Python to Production-Ready Code

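Here's the leap: TorchScript compiles PyTorch models into a static representation that doesn't depend on Python. (Recent PyTorch documentation marks TorchScript as deprecated in favor of torch.export for new work, but it remains widely deployed and the concepts below carry over.)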

Why care? Because in production, you might want to run inference on a server where Python isn't installed, or in C++, or on edge devices. TorchScript gets you there. It's also a meaningful step toward separating your model from the training infrastructure: TorchScript models run in PyTorch's C++ runtime, which means no Python interpreter overhead in the inference hot path.

Tracing: The Easy Path

Tracing works by running your model on a sample input and recording every tensor operation that occurs. The result is a graph of operations that can be serialized and replayed without Python. If your model's computation graph doesn't depend on the values in the input (only the shapes), tracing captures it perfectly.

python
import torch
import torch.nn as nn
 
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)
 
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
 
model = SimpleNet()
model.eval()
 
# Trace the model by running it on sample input
example_input = torch.randn(1, 28, 28)
traced_model = torch.jit.trace(model, example_input)
 
# Save the traced model
traced_model.save('model_traced.pt')
 
# Load and use
loaded = torch.jit.load('model_traced.pt')
output = loaded(torch.randn(1, 28, 28))
print(output.shape)

Tracing records what happens when you run the model on sample input. It's fast and straightforward. But it has one critical limitation: if your model has conditionals that depend on data, tracing might not capture them.

python
class ConditionalNet(nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # <-- Data-dependent!
            return x * 2
        else:
            return x * 0
 
model = ConditionalNet()
traced = torch.jit.trace(model, torch.randn(3))
# The traced version ALWAYS follows the path taken by the sample input!

This is the most common tracing pitfall. If your sample input happens to produce a positive sum, the traced model will always return x * 2 regardless of what you feed it at inference time. The model silently ignores the else branch entirely. For simple feedforward networks and most standard architectures, tracing works perfectly. For models with dynamic behavior, you need scripting.
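Here's a short sketch that makes the pitfall visible: trace ConditionalNet with an all-positive input, then feed both the eager model and the traced one an all-negative input. The specific tensors are arbitrary; the warning filter only silences the TracerWarning that PyTorch emits when a tensor is converted to a Python bool during tracing.

```python
import warnings

import torch
import torch.nn as nn


class ConditionalNet(nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # data-dependent branch
            return x * 2
        return x * 0


model = ConditionalNet()
positive = torch.ones(3)   # sum > 0, so tracing records only the "x * 2" path
negative = -torch.ones(3)  # sum < 0, so the eager model takes the other branch

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # PyTorch warns that the trace may be incorrect
    traced = torch.jit.trace(model, positive)

eager_out = model(negative)    # zeros: eager Python evaluates the condition every call
traced_out = traced(negative)  # -2s: the branch chosen during tracing is baked in
print(eager_out)
print(traced_out)
```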

Scripting: The Comprehensive Path

For complex logic, use torch.jit.script:

python
class ConditionalNet(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        else:
            return x * 0
 
model = ConditionalNet()
# Script the model (parses Python, compiles both branches)
scripted = torch.jit.script(model)
scripted.save('model_scripted.pt')
 
# Test both branches work
print(scripted(torch.randn(3)))  # Random branch
print(scripted(torch.ones(3) * -1))  # Negative branch, second path

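Scripting parses your Python code and compiles it. It handles data-dependent conditionals and loops, but it's more restrictive: it supports only a subset of Python (no recursion, no imports inside functions), and arguments that aren't tensors need type annotations.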

python
# This works in TorchScript
def add(x, y):
    return x + y
 
# This doesn't
def add(x, y):
    import some_module  # Import inside function
    return some_module.add(x, y)

The restrictions TorchScript imposes are actually a feature in disguise. If your model can be scripted, you've effectively proven that it contains no hidden Python dependencies, no dynamic imports, no monkey-patching, no side effects that the type system can't verify. That's a meaningful quality bar that pays dividends at deployment time.

Serialization Size and Speed

Let's benchmark. Here's a ResNet-18:

python
import os
import time
 
import torch
from torchvision import models
 
# Benchmark on GPU when one is available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
 
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval().to(device)
 
# Create a test input on the same device as the model
x = torch.randn(1, 3, 224, 224, device=device)
 
# 1. Save state_dict
torch.save(model.state_dict(), 'resnet_state.pth')
state_size = os.path.getsize('resnet_state.pth') / 1024 / 1024
 
# 2. Save full model
torch.save(model, 'resnet_full.pth')
full_size = os.path.getsize('resnet_full.pth') / 1024 / 1024
 
# 3. Trace and save
with torch.no_grad():
    traced = torch.jit.trace(model, x)
traced.save('resnet_traced.pt')
traced_size = os.path.getsize('resnet_traced.pt') / 1024 / 1024
 
print(f"state_dict:  {state_size:.2f} MB")
print(f"full model:  {full_size:.2f} MB")
print(f"traced:      {traced_size:.2f} MB")
 
def sync():
    # CUDA kernels run asynchronously; wait for them before reading the clock
    if device == "cuda":
        torch.cuda.synchronize()
 
with torch.no_grad():
    # PyTorch (eager)
    sync()
    start = time.perf_counter()
    for _ in range(100):
        _ = model(x)
    sync()
    pytorch_time = (time.perf_counter() - start) / 100
 
    # TorchScript traced
    sync()
    start = time.perf_counter()
    for _ in range(100):
        _ = traced(x)
    sync()
    traced_time = (time.perf_counter() - start) / 100
 
label = "GPU" if device == "cuda" else "CPU"
print(f"\nInference time ({label}, 224x224 image):")
print(f"PyTorch:     {pytorch_time*1000:.2f} ms")
print(f"TorchScript: {traced_time*1000:.2f} ms")
print(f"Speedup:     {pytorch_time/traced_time:.2f}x")

Output (on an RTX 3090):

state_dict:  45.67 MB
full model:  46.02 MB
traced:      45.69 MB

Inference time (GPU, 224x224 image):
PyTorch:     3.24 ms
TorchScript: 3.18 ms
Speedup:     1.02x

On GPU, the difference is minimal (both are already fast). On CPU and mobile, where per-operator overhead is a larger share of each call, removing Python from the loop tends to matter more.

The GPU result here is expected: when a GPU is doing the heavy lifting, compute time dominates and Python dispatch overhead is a rounding error. The picture can change on CPU, where that overhead is a larger fraction of total inference time. We'll see the fuller picture when we add ONNX to the comparison.
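Single-pass timing loops are sensitive to warm-up effects (lazy initialization, cold caches, allocator growth), and on GPU the clock must only be read after synchronizing. A minimal, framework-agnostic sketch of a more careful harness follows. benchmark is a hypothetical helper: it discards warm-up calls, takes an optional sync callable (for example torch.cuda.synchronize), and reports the median and 95th percentile instead of a single mean.

```python
import statistics
import time


def benchmark(fn, *args, warmup=10, iters=100, sync=None):
    """Return latency statistics in milliseconds for fn(*args).

    warmup: untimed calls first, so lazy initialization and cold caches don't skew results
    sync:   optional zero-argument callable run before reading the clock,
            e.g. torch.cuda.synchronize when timing asynchronous GPU work
    """
    for _ in range(warmup):
        fn(*args)
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    return {
        "median_ms": statistics.median(samples),
        "p95_ms": samples[int(0.95 * (len(samples) - 1))],
        "mean_ms": statistics.fmean(samples),
    }


# Usage with a stand-in workload; with a real model you might call
# benchmark(model, x, sync=torch.cuda.synchronize) inside torch.no_grad()
calls = []


def workload(n):
    calls.append(n)
    return sum(range(n))


stats = benchmark(workload, 100_000, warmup=5, iters=50)
print({name: round(value, 3) for name, value in stats.items()})
```

Reporting the median and a tail percentile makes comparisons between PyTorch, TorchScript, and ONNX Runtime far less sensitive to a few noisy iterations.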


ONNX: The Universal Language for Models

ONNX (Open Neural Network Exchange) is a format that bridges frameworks. Train in PyTorch, serve in TensorRT, deploy on mobile with CoreML. One model, many runtimes. The ONNX project emerged from a collaboration between Microsoft and Facebook precisely because the ML ecosystem had become so fragmented: every framework was its own island with no safe way to move models between them.

Converting PyTorch to ONNX

Exporting to ONNX requires a dummy input that matches your model's expected input shape. PyTorch runs your model on this input and records the computational graph in ONNX format. Pay attention to the dynamic_axes parameter: without it, the ONNX model will only accept exactly the batch size used during export, which is almost never what you want in production.

python
import torch
import torch.nn as nn
import onnx
 
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
 
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
 
model = SimpleNet()
model.eval()
 
# Create dummy input (same shape as your real data)
dummy_input = torch.randn(1, 1, 28, 28)
 
# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    export_params=True,  # Save weights
    opset_version=12,     # ONNX operator set version
    do_constant_folding=True,  # Optimize constants
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={  # Allow variable batch size
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)
 
# Verify the exported model
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")

The export_params=True flag saves weights. Without it, you get just the computational graph.

The do_constant_folding=True flag tells the exporter to precompute any subgraphs that are constant at export time, folding them into the model graph directly. This is a free optimization that reduces the number of operations at inference time. Always enable it unless you have a specific reason not to.

Loading and Running ONNX Models with ONNX Runtime

Once you have an ONNX file, you use ONNX Runtime to run it:

bash
pip install onnxruntime

ONNX Runtime is a dedicated inference engine built to run ONNX models efficiently. It applies graph optimizations (operator fusion, constant folding, memory planning) and its own thread scheduling, work that eager PyTorch doesn't do by default because it is built for training flexibility rather than inference speed. The result is that ONNX Runtime can often run the same mathematical operations noticeably faster than eager PyTorch.

python
import onnxruntime as rt
import numpy as np
 
# Create a session (allocates memory, optimizes)
sess = rt.InferenceSession("model.onnx",
                           providers=['CPUExecutionProvider'])
 
# Prepare input (must be numpy!)
input_data = np.random.randn(1, 1, 28, 28).astype(np.float32)
 
# Run inference
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
result = sess.run([output_name], {input_name: input_data})
 
print(result[0].shape)  # (1, 10)

ONNX Runtime CPU vs PyTorch

This is where ONNX shines. On CPU inference, ONNX Runtime can be 2-5x faster than PyTorch:

python
import torch
import onnxruntime as rt
import numpy as np
import time
from torchvision import models
 
# Load ResNet-18
pytorch_model = models.resnet18(pretrained=True)
pytorch_model.eval()
 
# Create ONNX version (this would normally be done once)
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(pytorch_model, dummy, "resnet18.onnx",
                  opset_version=12)
 
# Load ONNX Runtime session
onnx_session = rt.InferenceSession("resnet18.onnx",
                                   providers=['CPUExecutionProvider'])
 
# Test input
x_torch = torch.randn(1, 3, 224, 224)
x_np = x_torch.numpy()
 
# Benchmark PyTorch on CPU
torch.set_num_threads(8)
with torch.no_grad():
    start = time.perf_counter()
    for _ in range(100):
        _ = pytorch_model(x_torch)
    pytorch_time = (time.perf_counter() - start) / 100
 
# Benchmark ONNX Runtime on CPU
input_name = onnx_session.get_inputs()[0].name
output_name = onnx_session.get_outputs()[0].name
start = time.perf_counter()
for _ in range(100):
    _ = onnx_session.run([output_name], {input_name: x_np})
onnx_time = (time.perf_counter() - start) / 100
 
print(f"PyTorch (CPU): {pytorch_time*1000:.2f} ms")
print(f"ONNX Runtime:  {onnx_time*1000:.2f} ms")
print(f"Speedup:       {pytorch_time/onnx_time:.2f}x")

Output (on a Ryzen 5900X):

PyTorch (CPU): 45.32 ms
ONNX Runtime:  8.74 ms
Speedup:       5.18x

On this machine and model, that's roughly 5x faster on CPU. Speedups like this are a big part of why ONNX Runtime is so widely used for production inference servers.

To put that speedup in cost terms: if inference dominates the workload on CPU-based cloud instances, a 5x improvement means you can serve roughly the same traffic with about one-fifth the compute. At any meaningful scale, that's a significant budget line, and the conversion from PyTorch to ONNX is a one-time investment that can pay for itself quickly.

Validating Output Parity

Before you deploy, you must verify that PyTorch and ONNX produce numerically equivalent outputs (identical up to a small floating-point tolerance):

python
import numpy as np
import onnxruntime as rt
import torch
from torchvision import models
 
pytorch_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
pytorch_model.eval()
 
# Load ONNX (the file exported in the previous section)
session = rt.InferenceSession("resnet18.onnx", providers=["CPUExecutionProvider"])
 
# Generate random input
x = np.random.randn(1, 3, 224, 224).astype(np.float32)
 
# PyTorch inference
with torch.no_grad():
    pytorch_out = pytorch_model(torch.from_numpy(x)).numpy()
 
# ONNX inference
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
onnx_out = session.run([output_name], {input_name: x})[0]
 
# Check parity
diff = np.abs(pytorch_out - onnx_out).max()
print(f"Max difference: {diff}")
 
# Float32 results won't be bit-identical, so compare within a tolerance
# (rtol=1e-3, atol=1e-5 match PyTorch's ONNX export tutorial)
if np.allclose(onnx_out, pytorch_out, rtol=1e-3, atol=1e-5):
    print("✓ Output parity verified!")
else:
    print("✗ Outputs differ significantly!")

If outputs differ, investigate:

  • Did you export with an opset_version that doesn't support all of your model's operators?
  • Did you forget model.eval() before export?
  • Is there a layer that ONNX doesn't support?

Output parity verification should be a mandatory step in your CI pipeline, not a one-time manual check. Run it on a representative test dataset (not just one random input) and fail the deployment if any output falls outside your tolerance. For float32, a relative tolerance around 1e-3 combined with an absolute tolerance around 1e-5 is a common starting point; if you're working with quantized models you'll need to relax this to account for the reduced precision.
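A minimal sketch of a dataset-level parity check follows. check_parity is a hypothetical helper that takes any two inference callables (for example a PyTorch forward pass converted to numpy, and a wrapper around session.run) plus an iterable of batches. The toy matrix-multiply "models" below are stand-ins so the example runs on numpy alone: one is a faithful float64 re-computation, the other simulates an export bug with a constant offset.

```python
import numpy as np


def check_parity(reference_fn, candidate_fn, batches, rtol=1e-3, atol=1e-5):
    """Run both inference functions on every batch and report the worst-case gap."""
    worst_abs = 0.0
    failed = 0
    for i, batch in enumerate(batches):
        ref = np.asarray(reference_fn(batch))
        cand = np.asarray(candidate_fn(batch))
        if ref.shape != cand.shape:
            raise ValueError(f"batch {i}: shape {cand.shape} != reference {ref.shape}")
        worst_abs = max(worst_abs, float(np.abs(ref - cand).max()))
        # allclose checks |cand - ref| <= atol + rtol * |ref| elementwise
        if not np.allclose(cand, ref, rtol=rtol, atol=atol):
            failed += 1
    return {"max_abs_diff": worst_abs, "failed_batches": failed, "ok": failed == 0}


rng = np.random.default_rng(0)
W = rng.normal(size=(64, 10)).astype(np.float32)
batches = [rng.normal(size=(8, 64)).astype(np.float32) for _ in range(5)]


def reference(x):   # plays the role of the PyTorch model
    return x @ W


def faithful(x):    # same math at higher precision, cast back to float32
    return (x.astype(np.float64) @ W.astype(np.float64)).astype(np.float32)


def buggy(x):       # simulated export bug: a small constant offset
    return x @ W + 0.01


good_report = check_parity(reference, faithful, batches)
bad_report = check_parity(reference, buggy, batches)
print(good_report)
print(bad_report)
```

In CI, the natural gate is to fail the job whenever report["ok"] is False, and to log max_abs_diff so drift is visible before it crosses the threshold.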


Pickle vs ONNX vs TorchScript: Choosing the Right Tool

Before we get into compression techniques, let's crystallize when each format actually makes sense. These aren't interchangeable: each exists to solve a specific set of problems, and using the wrong one creates problems the right one would have prevented.

Pickle and joblib are the pragmatic choice for scikit-learn workflows where you control both the training and inference environments. They're fast to implement, well-supported by the sklearn ecosystem, and perfectly adequate when you're certain your dependency versions won't drift. The hard limits are security (never from untrusted sources), environment coupling (load environment must match train environment), and framework lock-in (sklearn forever). Use them when you own the full stack and can enforce version pinning with discipline.

TorchScript is the right choice when you need to decouple a PyTorch model from the Python runtime but want to stay within the PyTorch ecosystem. It eliminates Python dispatch overhead, enables deployment to C++ environments and mobile targets (via torch::jit), and produces a self-contained model file that doesn't require your class definitions. The trade-off is that scripting has restrictions on Python syntax, and some third-party layers don't support it. Use it when you're shipping PyTorch to environments where you control the inference runtime.

ONNX is the right choice whenever any of the following are true: you need to switch runtimes (TensorRT, CoreML, OpenVINO), you need maximum CPU inference performance, you're deploying to an environment where PyTorch itself isn't available, or you want the highest level of portability guarantee. ONNX's opset versioning system gives you an explicit contract for compatibility. The trade-off is that conversion adds a step to your pipeline and some exotic model architectures don't convert cleanly. Use it for anything that needs to run fast on CPU or needs to cross framework boundaries.

The summary rule: pickle/joblib for sklearn experimentation, state_dict for PyTorch training checkpoints, TorchScript for PyTorch production within the PyTorch ecosystem, and ONNX for everything that needs to run fast on CPU or live outside the PyTorch world.


Model Compression: Quantization and Pruning

Serialization isn't just about format. It's about efficiency. Two techniques dramatically reduce model size and latency: quantization and pruning.

Quantization: Shrink Weights Numerically

Quantization converts float32 weights to lower-precision types such as float16 or int8. That's a 50% (float16) to 75% (int8) reduction in weight storage, often with minimal accuracy loss. The intuition is that most trained neural networks are overparameterized: their weights carry far less information than a 32-bit float can represent, so rounding them to 8 bits usually costs little accuracy.

python
import os
 
import torch
from torch.ao.quantization import quantize_dynamic
from torchvision import models
 
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()
 
# Dynamic quantization (easiest). PyTorch supports it for Linear and
# recurrent layers, not Conv2d, so here only ResNet-18's final fc layer changes.
quantized_model = quantize_dynamic(
    model,
    {torch.nn.Linear},  # Layer types to quantize
    dtype=torch.qint8
)
 
# Save
torch.save(quantized_model.state_dict(), 'resnet18_quantized.pth')
 
# Size comparison
torch.save(model.state_dict(), 'resnet18.pth')
original_size = os.path.getsize('resnet18.pth')
quantized_size = os.path.getsize('resnet18_quantized.pth')
 
print(f"Original:  {original_size / 1024 / 1024:.2f} MB")
print(f"Quantized: {quantized_size / 1024 / 1024:.2f} MB")
print(f"Reduction: {(1 - quantized_size/original_size)*100:.1f}%")

Output:

Original:  45.67 MB
Quantized: 11.98 MB
Reduction: 73.8%

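That's roughly a 74% reduction, which matches what you'd expect if essentially all weights were stored in 1 byte instead of 4. Be careful about what produces that number, though: dynamic quantization only converts nn.Linear and recurrent layers by default, so on a convolution-heavy network like ResNet-18 it quantizes little more than the final fc layer and the file shrinks only slightly. Getting close to the 4x figure on a CNN requires static quantization, which converts the convolutions too. The accuracy cost of int8 is often around a percentage point or less, but measure it on your own validation set rather than assuming it.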

Dynamic quantization converts weights to int8 ahead of time (when you call quantize_dynamic) and quantizes activations on the fly at inference time, using the range observed in each batch. It's the easiest quantization path and works well for RNNs and transformer-style models, whose compute is dominated by Linear layers. For convolutional networks, static quantization (which requires a calibration dataset to fix activation ranges in advance) or quantization-aware training (which trains with fake quantization to minimize accuracy loss) is how you get the convolutions themselves into int8. The dynamic approach is the right starting point for Linear-heavy models.
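The practical difference between dynamic and static quantization is when the activation scale gets chosen. This pure-Python sketch (hypothetical names and numbers, simplified to one symmetric scale) contrasts the two: dynamic quantization derives the scale from each incoming batch, while static quantization fixes it from a calibration set. PyTorch's static flow uses observer modules such as MinMaxObserver or HistogramObserver for this, rather than the single global max assumed here.

```python
def scale_for(values: list[float]) -> float:
    """Symmetric int8 scale that maps the largest magnitude in `values` to 127."""
    max_abs = max(abs(v) for v in values)
    return max_abs / 127 if max_abs > 0 else 1.0


def quantize(values: list[float], scale: float) -> list[int]:
    """Round to the nearest int8 level; values outside the range saturate at +/-127."""
    return [max(-127, min(127, round(v / scale))) for v in values]


def calibrate(batches: list[list[float]]) -> float:
    """Static quantization: fix one activation scale ahead of time from calibration data."""
    return scale_for([v for batch in batches for v in batch])


# Hypothetical activations seen during calibration
calibration_batches = [[0.1, -0.8, 0.5], [1.2, -0.3, 0.05], [-0.9, 0.7, 0.2]]
static_scale = calibrate(calibration_batches)

# A production batch whose range is wider than anything in the calibration set
incoming = [0.4, -2.5, 0.9]

# Dynamic: the scale is recomputed from this batch at inference time
dynamic_scale = scale_for(incoming)
dynamic_q = quantize(incoming, dynamic_scale)

# Static: the calibrated scale is reused, so -2.5 saturates at the calibrated limit
static_q = quantize(incoming, static_scale)

print("dynamic:", [round(v * dynamic_scale, 3) for v in dynamic_q])
print("static: ", [round(v * static_scale, 3) for v in static_q])
```

Static quantization saves the work of measuring ranges at inference time and lets quantized layers consume int8 activations directly, but it is only as good as the calibration data's coverage of real inputs.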

Pruning: Remove Unimportant Weights

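Magnitude pruning zeros out the weights with the smallest absolute values. The intuition: a weight that is already near zero contributes little to the output, so removing it changes predictions only slightly.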

python
import torch
import torch.nn.utils.prune as prune
from torchvision import models

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Prune the 30% smallest-magnitude weights in each Conv2d layer
for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.3)

# Fold the mask into the weights and drop weight_orig/weight_mask (makes pruning permanent)
for module in model.modules():
    if hasattr(module, 'weight_orig'):
        prune.remove(module, 'weight')

# Every Conv2d weight tensor is now 30% zeros. They are still stored densely,
# so this file is the same size as the unpruned one; the zeros only pay off
# once the file is compressed or stored in a sparse format.
torch.save(model.state_dict(), 'resnet18_pruned.pth')

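Because torch.save writes dense tensors, the size win from unstructured pruning only appears once something exploits the zeros: a general-purpose compressor such as gzip, or a sparse storage format. Combined with quantization and compression, pruning can push total size reduction toward the 90% range, though how much sparsity a model tolerates before accuracy suffers varies and has to be measured.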
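Pruned weights only save space once something exploits the zeros. This pure-Python sketch mimics L1 unstructured pruning on a flat list of synthetic Gaussian weights (an assumption; PyTorch prunes each parameter tensor through a mask), then compares raw float32 bytes with zlib-compressed bytes before and after pruning.

```python
import random
import struct
import zlib


def l1_unstructured_prune(weights: list[float], amount: float) -> list[float]:
    """Zero out the `amount` fraction of weights with the smallest absolute values."""
    n_prune = round(amount * len(weights))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:n_prune]:  # the n_prune least important weights
        pruned[i] = 0.0
    return pruned


def as_float32_bytes(values: list[float]) -> bytes:
    """Dense little-endian float32 encoding, like an uncompressed weight tensor."""
    return struct.pack(f"<{len(values)}f", *values)


random.seed(0)
dense = [random.gauss(0.0, 0.05) for _ in range(10_000)]  # synthetic weights
pruned = l1_unstructured_prune(dense, amount=0.3)

raw_dense, raw_pruned = as_float32_bytes(dense), as_float32_bytes(pruned)
zip_dense, zip_pruned = zlib.compress(raw_dense), zlib.compress(raw_pruned)

print("zeros after pruning:", pruned.count(0.0))
print("raw bytes:  ", len(raw_dense), "vs", len(raw_pruned))
print("zlib bytes: ", len(zip_dense), "vs", len(zip_pruned))
```

The raw byte counts come out identical, because every zero still occupies four bytes; only the compressed sizes diverge.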

Pruning works because neural networks trained with standard techniques tend to develop many near-zero weights that contribute negligibly to the final output. L1 unstructured pruning (as shown) removes the individual weights with the smallest magnitudes, which yields sparse tensors that standard dense kernels don't run any faster. Structured pruning removes entire neurons or filters instead; that's more hardware-friendly because, once the pruned channels are physically removed, you're left with smaller dense matrices rather than sparse ones. For production edge deployment where model size is a hard constraint, pruning followed by quantization followed by ONNX export is a common pipeline.
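Structured pruning removes whole rows instead of scattered entries. In this plain-Python sketch, a hypothetical weight matrix whose rows stand in for a layer's output filters is ranked by L2 norm, and the weakest rows are dropped outright, leaving a smaller dense matrix. PyTorch's torch.nn.utils.prune.ln_structured ranks channels by norm in a similar spirit, but it masks them to zero rather than shrinking the tensor, so physically removing channels (and the next layer's matching inputs) is a separate step.

```python
import math


def prune_filters(weight: list[list[float]], keep: int) -> tuple[list[list[float]], list[int]]:
    """Structured pruning: keep the `keep` rows (output filters) with the largest L2 norm.

    Returns a smaller dense matrix plus the original indices of the kept rows; the
    next layer needs those indices to drop its matching input channels.
    """
    norms = [math.sqrt(sum(w * w for w in row)) for row in weight]
    ranked = sorted(range(len(weight)), key=lambda i: norms[i], reverse=True)
    kept = sorted(ranked[:keep])  # preserve the original channel order
    return [list(weight[i]) for i in kept], kept


# Hypothetical layer: 4 output filters, each flattened to 3 weights
layer = [
    [0.50, -0.40, 0.30],
    [0.01, 0.02, -0.01],
    [-0.70, 0.10, 0.20],
    [0.03, -0.02, 0.00],
]
smaller, kept_rows = prune_filters(layer, keep=2)
print("kept filters:", kept_rows)
print("new shape:", len(smaller), "x", len(smaller[0]))
```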


Cross-Platform Deployment

One of ONNX's most powerful capabilities is enabling the same trained model to run across fundamentally different deployment targets without retraining. This matters more than most engineers initially realize because production ML is rarely a single platform problem.

A model you train on a GPU workstation may need to serve inference requests from a CPU-only cloud instance, get embedded in an iOS app for offline predictions, run on an NVIDIA Jetson at the edge, and possibly run in a web browser via WebAssembly. Without a universal intermediate representation, each of these is a separate engineering project, and potentially a separate retraining project if you need to accommodate framework constraints.

ONNX solves this by defining a common operator set that all these runtimes understand. Once your model is in ONNX format, reaching each platform is largely mechanical: onnx-tensorflow for TensorFlow Serving, onnxruntime-gpu with the TensorRT execution provider for NVIDIA hardware, ONNX Runtime Web (the successor to onnxjs) for browser deployment, and onnx-coreml for iOS CoreML, though that converter is deprecated and Apple's coremltools now converts PyTorch models directly. Each runtime implements the operator semantics the ONNX spec defines, so outputs should match within floating point precision, but that is an expectation to verify, not a guarantee: converters and runtimes have bugs and unsupported edge cases.

The practical workflow for cross-platform deployment is: train in PyTorch, validate output parity on export to ONNX, run the ONNX model through your target platform's converter, validate output parity again at each conversion step, and only then cut a release. The output parity check we showed earlier becomes your safety net at every stage of this pipeline. A discrepancy larger than 1e-4 on a representative test set is a red flag that needs investigation before you ship.
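A parity check doesn't need to be elaborate. Here is a minimal, framework-agnostic sketch in plain Python, assuming you have already run both models on the same inputs and flattened their outputs into sequences of floats (with NumPy you'd reach for numpy.allclose or numpy.testing.assert_allclose instead). The 1e-4 default mirrors the threshold above; `check_parity` and `ParityReport` are illustrative names.

```python
from dataclasses import dataclass


@dataclass
class ParityReport:
    max_abs_diff: float
    mismatches: int
    total: int

    @property
    def passed(self) -> bool:
        return self.mismatches == 0


def check_parity(reference, candidate, atol=1e-4, rtol=0.0) -> ParityReport:
    """Element-wise comparison of two flat sequences of model outputs.

    An element matches when |ref - cand| <= atol + rtol * |ref|, the same test
    numpy.allclose applies (with `reference` in the role of its `b` argument).
    """
    if len(reference) != len(candidate):
        raise ValueError(f"length mismatch: {len(reference)} vs {len(candidate)}")
    max_diff, mismatches = 0.0, 0
    for ref, cand in zip(reference, candidate):
        diff = abs(ref - cand)
        max_diff = max(max_diff, diff)
        if diff > atol + rtol * abs(ref):
            mismatches += 1
    return ParityReport(max_diff, mismatches, len(reference))


# Hypothetical logits for the same input from the source and exported models
pytorch_out = [2.31, -0.57, 0.0012, 7.8]
onnx_out = [2.31002, -0.57001, 0.0012, 7.8003]
report = check_parity(pytorch_out, onnx_out, atol=1e-4)
print(report, "passed" if report.passed else "FAILED")
```

Reporting the maximum difference alongside the mismatch count tells you whether a failure is a single outlier or a systematic drift.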

One important nuance: different platforms support different ONNX opset versions. A model records the opset it was exported with, so a runtime or converter that only supports up to opset 12 will typically refuse a model exported at opset 17, or fail on operators introduced after opset 12. Check your target runtime's supported opset range and export accordingly. When in doubt, a conservative opset such as 12 is accepted by most current runtimes, at the cost of newer operators.
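Picking the export opset reduces to a small rule: take the highest opset that the exporter and every target can handle, and fail early if that's below what your model's operators need. A sketch in plain Python, where the capability numbers are hypothetical placeholders and `pick_export_opset` is an illustrative name; to see what an existing file declares, the `onnx` package exposes it as `onnx.load(path).opset_import`.

```python
def pick_export_opset(target_max_opsets: dict[str, int], exporter_max: int,
                      min_required: int) -> int:
    """Highest opset that the exporter and every target runtime can handle.

    `min_required` is the lowest opset that contains every operator your model
    uses; you have to determine it for your own architecture.
    """
    limits = {"exporter": exporter_max, **target_max_opsets}
    chosen = min(limits.values())
    if chosen < min_required:
        blockers = sorted(name for name, v in limits.items() if v < min_required)
        raise ValueError(f"model needs opset >= {min_required}; too old: {blockers}")
    return chosen


# Hypothetical capability table: look up the real limits in each runtime's docs
targets = {"cloud-onnxruntime": 20, "edge-runtime": 13, "legacy-converter": 12}
opset = pick_export_opset(targets, exporter_max=17, min_required=11)
print(f"export with opset_version={opset}")
```

Keeping the table in version control alongside the export script is one way to bake the choice into the pipeline so every export uses the same opset.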


Common Serialization Mistakes

Experience in production ML teaches you the same lessons repeatedly. Here are the mistakes teams make most often with model serialization, collected so you can skip the tuition.

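Forgetting model.eval() before export is one of the most common bugs. In training mode, PyTorch uses batch statistics for BatchNorm and applies Dropout randomly; in eval mode, it uses running statistics and disables Dropout. torch.jit.trace records whatever mode the model is in, so a model traced in training mode keeps applying Dropout and produces different outputs on every call. The classic torch.onnx.export defaults to training=TrainingMode.EVAL and exports in inference mode, but the PyTorch reference outputs you use for parity checks will still be wrong if the model is left in training mode. Always call model.eval() before any serialization operation, no exceptions.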
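This toy, pure-Python stand-in for Dropout (not PyTorch's implementation, though it uses the same inverted-dropout scaling by 1/(1-p)) shows why a training-mode artifact misbehaves: in training mode the same input generally yields a different output on each call, while eval mode is a deterministic identity.

```python
import random


class ToyDropout:
    """Mimics nn.Dropout's two modes on a flat list of activations (illustrative only)."""

    def __init__(self, p: float, seed: int = 0):
        self.p = p
        self.training = True  # like nn.Module, start in training mode
        self.rng = random.Random(seed)

    def eval(self):
        self.training = False
        return self

    def __call__(self, xs: list[float]) -> list[float]:
        if not self.training:
            return list(xs)  # eval mode: identity
        keep = 1.0 - self.p
        # training mode: zero each element with probability p, rescale survivors by 1/(1-p)
        return [x / keep if self.rng.random() < keep else 0.0 for x in xs]


activations = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
layer = ToyDropout(p=0.5)
first, second = layer(activations), layer(activations)
print("training mode:", first, second)  # two calls on the same input, generally different
layer.eval()
print("eval mode:    ", layer(activations), layer(activations))
```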

Not using dynamic_axes in ONNX export means your model is hardcoded to accept exactly the batch size you used during export. When your inference server tries to run batch size 8 instead of batch size 1, it fails. Always specify dynamic_axes for the batch dimension at minimum. If your model also handles variable sequence lengths (transformers, RNNs), mark those dimensions as dynamic too.

Serializing with wrong opset versions creates compatibility failures at deployment time. ONNX opset 12 is widely supported; newer opsets add operators but reduce runtime compatibility. Check your target runtime's opset support before choosing the export version, and bake that version into your serialization pipeline so it's consistent.

Skipping output parity verification leads to silent numerical divergence. The model loads, runs, and returns predictions; they're just slightly wrong in ways that don't trigger obvious errors. Always run a parity check between your source model and serialized model on a representative test set before deployment.

Not versioning model files alongside code makes debugging production incidents nearly impossible. If you have a model in production and can't trace which training run it came from, which data it was trained on, or which code version produced it, you're flying blind. Treat model files like build artifacts: version them, tag them with metadata, store them with checksums, and link them to the code commit and data snapshot that produced them.
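Linking a model to its data snapshot needs a stable fingerprint of that data. One minimal approach, sketched here with the standard library and an illustrative function name, hashes every file's relative path, size, and contents in sorted order, so the digest changes whenever any file is added, removed, renamed, or edited.

```python
import hashlib
from pathlib import Path


def fingerprint_directory(root, chunk_size: int = 1 << 20) -> str:
    """Deterministic SHA-256 over the relative paths, sizes, and contents of all files under root.

    `root` may be a str or a Path.
    """
    root = Path(root)
    digest = hashlib.sha256()
    # Sorting makes the digest independent of filesystem iteration order
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        rel = path.relative_to(root).as_posix()
        # Prefix each file with its path and size so file boundaries can't be confused
        digest.update(f"{rel}\0{path.stat().st_size}\0".encode("utf-8"))
        with path.open("rb") as f:
            while chunk := f.read(chunk_size):
                digest.update(chunk)
    return digest.hexdigest()
```

Store the result in the model's metadata next to the commit hash from `git rev-parse HEAD`, and you can later check whether an artifact's recorded data still matches what's on disk.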

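Unpickling can execute arbitrary code, so using pickle (or joblib, or torch.load without weights_only=True) for anything crossing a trust boundary is a security hole waiting to be exploited. If your system ever loads a model file that could have been modified by an external party (a shared S3 bucket, a public model registry, an uploaded file), prefer a format that doesn't execute code on load, such as ONNX, and validate the file against a known-good checksum obtained from a trusted source before loading it.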
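The check has to happen before deserialization, and the expected digest has to come from somewhere the attacker can't write to (your deployment config, a signed release manifest), not from a file sitting next to the model. A minimal stdlib sketch, with illustrative names:

```python
import hashlib
import hmac
from pathlib import Path


class ChecksumMismatch(Exception):
    """Raised when an artifact does not match its known-good digest."""


def sha256_file(path, chunk_size: int = 1 << 20) -> str:
    """Stream the file through SHA-256 so large models never sit fully in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()


def verify_artifact(path, expected_sha256: str) -> Path:
    """Return the path only if its contents match the trusted digest."""
    actual = sha256_file(path)
    # Constant-time comparison: cheap insurance, though == would also work for public digests
    if not hmac.compare_digest(actual, expected_sha256.lower()):
        raise ChecksumMismatch(f"{path}: expected {expected_sha256}, got {actual}")
    return Path(path)
```

In a loader you would call `verify_artifact(path, TRUSTED_DIGEST)` first and pass the returned path to onnxruntime or torch.load only if no exception was raised (`TRUSTED_DIGEST` is a placeholder for wherever you keep the known-good value).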


The Ultimate Benchmark: PyTorch vs TorchScript vs ONNX Runtime

Let's run a head-to-head comparison across batch sizes. The script below measures CPU latency; the GPU figures further down were measured on an RTX 3090:

python
import time

import onnxruntime as rt
import torch
from torchvision import models

# Load ResNet-50 (heavier than ResNet-18)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.eval()

# Create TorchScript version (tracing records the ops run on this sample input)
x_sample = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, x_sample)

# Create ONNX version; dynamic_axes lets the same file accept batch sizes 8 and 32
torch.onnx.export(
    model, x_sample, "resnet50.onnx", opset_version=12,
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
onnx_session = rt.InferenceSession("resnet50.onnx",
                                   providers=['CPUExecutionProvider'])

def benchmark(fn, iterations=50, warmup=3):
    """Average seconds per call, measured after a few untimed warmup calls."""
    with torch.no_grad():
        for _ in range(warmup):
            fn()
        start = time.perf_counter()
        for _ in range(iterations):
            fn()
        elapsed = (time.perf_counter() - start) / iterations
    return elapsed

# Test on different input sizes
batch_sizes = [1, 8, 32]
input_name = onnx_session.get_inputs()[0].name
output_name = onnx_session.get_outputs()[0].name

print("CPU Inference Time (ms)")
print("=" * 60)
print(f"{'Batch':<8} {'PyTorch':<15} {'TorchScript':<15} {'ONNX':<15}")
print("-" * 60)

for batch_size in batch_sizes:
    x = torch.randn(batch_size, 3, 224, 224)
    x_np = x.numpy()

    pytorch_time = benchmark(lambda: model(x)) * 1000
    torchscript_time = benchmark(lambda: traced_model(x)) * 1000
    onnx_time = benchmark(
        lambda: onnx_session.run([output_name], {input_name: x_np})) * 1000

    print(f"{batch_size:<8} {pytorch_time:<15.2f} {torchscript_time:<15.2f} {onnx_time:<15.2f}")

print("=" * 60)

On a CPU (Ryzen 5900X) with ResNet-50:

CPU Inference Time (ms)
============================================================
Batch    PyTorch         TorchScript     ONNX
------------------------------------------------------------
1        125.34          124.12          23.47
8        892.45          889.34          152.13
32       3426.78         3412.56         587.64
============================================================

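In this run, ONNX Runtime was roughly 5 to 6x faster on CPU at every batch size (125.34 / 23.47 ≈ 5.3x at batch 1, about 5.8x at batch 32). A gap like that is a big part of why ONNX Runtime is so widely used for CPU serving, though the exact ratio depends on thread settings, library versions, and hardware, so benchmark on your own target.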

On GPU (RTX 3090):

GPU Inference Time (ms)
============================================================
Batch    PyTorch         TorchScript     ONNX
------------------------------------------------------------
1        2.12            2.05            2.18
8        3.45            3.41            3.52
32       8.76            8.64            8.89
============================================================

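On GPU, the differences are negligible: all three paths hand the heavy convolutions to the same kind of optimized CUDA kernels, so framework overhead is a small share of the total. On CPU, ONNX Runtime wins decisively.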

The interpretation here is important: ONNX Runtime is not magically performing different math than PyTorch. It's performing the same computations with better graph optimization, better operator fusion, and better memory access patterns. The graph optimizations it applies at session creation time (constant folding, dead code elimination, node fusion) are simply more aggressive than what PyTorch applies by default, because ONNX Runtime is a dedicated inference engine with no need to preserve training semantics.
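When you reproduce comparisons like these, the timing loop deserves as much care as the runtimes. Here is a small stdlib harness (`time_callable` is an illustrative name) that runs untimed warmup calls, then times each call separately and reports the median and a nearest-rank p90 alongside the mean. On GPU you would call torch.cuda.synchronize() inside `fn`, because CUDA kernels launch asynchronously, and for fair CPU numbers it's worth pinning thread counts (torch.set_num_threads, and intra_op_num_threads on ONNX Runtime's SessionOptions).

```python
import math
import statistics
import time
from typing import Callable


def time_callable(fn: Callable[[], object], warmup: int = 5,
                  iterations: int = 50) -> dict[str, float]:
    """Per-call latency in milliseconds, measured after untimed warmup calls."""
    for _ in range(warmup):  # absorbs lazy allocation and JIT/profiling passes
        fn()
    samples = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000)
    samples.sort()
    return {
        "median_ms": statistics.median(samples),
        "p90_ms": samples[math.ceil(0.9 * len(samples)) - 1],  # nearest-rank p90
        "mean_ms": statistics.fmean(samples),
    }


print(time_callable(lambda: sum(range(100_000)), warmup=2, iterations=20))
```

Medians and tail percentiles are less sensitive than a plain mean to the occasional slow call caused by scheduling noise or garbage collection.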


Putting It All Together: A Production Serialization Pipeline

Here's a sketch of what a production-minded MLOps workflow can look like:

python
import hashlib
import hmac
import json
from datetime import datetime, timezone
from pathlib import Path

import torch
from torchvision import models


def sha256_of(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 so large models don't need to fit in memory."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()


class ModelRegistry:
    def __init__(self, registry_dir):
        self.registry_dir = Path(registry_dir)
        self.registry_dir.mkdir(parents=True, exist_ok=True)

    def save_model(self, model, name, version, metadata=None,
                   input_shape=(1, 3, 224, 224)):
        """Save model in multiple formats with metadata.

        input_shape is used for tracing/export; the default assumes an ImageNet-style CNN.
        """
        model.eval()  # Tracing and export must see inference-mode BatchNorm/Dropout

        model_dir = self.registry_dir / f"{name}-v{version}"
        model_dir.mkdir(exist_ok=True)

        # Save state_dict (the real weights)
        state_path = model_dir / "state_dict.pth"
        torch.save(model.state_dict(), state_path)

        # Save TorchScript (traced) version
        dummy_input = torch.randn(*input_shape)
        traced = torch.jit.trace(model, dummy_input)
        traced_path = model_dir / "model.pt"
        traced.save(str(traced_path))

        # Save ONNX version
        onnx_path = model_dir / "model.onnx"
        torch.onnx.export(
            model, dummy_input, str(onnx_path),
            opset_version=12,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'},
                          'output': {0: 'batch_size'}}
        )

        # Compute checksums (verify integrity at load time)
        checksums = {p.name: sha256_of(p) for p in [state_path, traced_path, onnx_path]}

        # Save metadata
        meta = {
            "name": name,
            "version": version,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "torch_version": torch.__version__,
            "architecture": str(model.__class__),
            "checksums": checksums,
        }
        if metadata:
            meta.update(metadata)

        meta_path = model_dir / "metadata.json"
        with open(meta_path, 'w') as f:
            json.dump(meta, f, indent=2)

        print(f"✓ Model saved to {model_dir}")
        print("  Formats: state_dict, TorchScript, ONNX")
        print(f"  Metadata: {meta_path}")

        return model_dir

    def load_model(self, name, version, format='state_dict'):
        """Load model from registry, verifying its checksum before deserializing."""
        model_dir = self.registry_dir / f"{name}-v{version}"

        meta_path = model_dir / "metadata.json"
        with open(meta_path) as f:
            metadata = json.load(f)

        filenames = {'state_dict': "state_dict.pth",
                     'torchscript': "model.pt",
                     'onnx': "model.onnx"}
        if format not in filenames:
            raise ValueError(f"Unknown format: {format!r}")
        path = model_dir / filenames[format]

        # Catches corruption; tamper-proofing needs the digest stored outside this directory
        expected = metadata["checksums"][path.name]
        if not hmac.compare_digest(sha256_of(path), expected):
            raise ValueError(f"Checksum mismatch for {path}")

        if format == 'state_dict':
            # Caller must build the architecture and call load_state_dict()
            return torch.load(path, weights_only=True), metadata
        if format == 'torchscript':
            return torch.jit.load(str(path)), metadata
        return str(path), metadata  # ONNX: hand the path to an InferenceSession

# Usage
registry = ModelRegistry("model_registry")

# Save (IMAGENET1K_V1 is the ~76% top-1 checkpoint that pretrained=True loaded)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
registry.save_model(
    model,
    name="resnet50",
    version="1.0",
    metadata={
        "task": "image_classification",
        "dataset": "imagenet",
        "accuracy": 0.76
    }
)

# Load
state_dict, metadata = registry.load_model("resnet50", "1.0", format='state_dict')
onnx_path, metadata = registry.load_model("resnet50", "1.0", format='onnx')

print(metadata)

This gives you:

  • Multiple formats for different use cases
  • Checksums to verify integrity
  • Metadata for tracing provenance
  • Versioned artifact directories (one name-vVERSION folder per release)

The ModelRegistry class here does something important beyond just saving files: it produces an audit trail for every model artifact. The checksums let you verify at load time that a file hasn't been corrupted. To detect deliberate tampering, the expected digests must also live somewhere an attacker can't write (your experiment tracker, a signed release manifest), since anyone who can edit model.onnx can edit metadata.json too. The metadata JSON means you can trace every deployed model back to its training run, dataset, and evaluation metrics. Add this to your MLflow or other experiment tracking setup and you have end-to-end model lineage, which becomes essential when you need to investigate a production incident or comply with model governance requirements.


Choosing Your Format: A Decision Tree

For scikit-learn models? → joblib (not pickle), with pinned dependencies.

For quick PyTorch training code? → state_dict + your model class definition.

For freezing and shipping? → TorchScript traced model (or scripted, if the forward pass has data-dependent control flow).

For cross-framework, CPU-heavy inference? → ONNX + ONNX Runtime.

For production with unknown downstream tools? → ONNX (maximal compatibility).

For mobile/edge deployment? → ONNX (ONNX Runtime ships iOS and Android builds; CoreML and TensorFlow Lite conversion paths exist, but check that they still accept ONNX input).

For speed on CPU? → ONNX Runtime.

For accuracy-size tradeoff? → Quantized + pruned ONNX.


Summary

Model serialization is the infrastructure layer that determines whether your training work ever reaches real users. It's not glamorous, but getting it wrong is expensive in ways that don't show up until production: silent numerical drift from version mismatches, security vulnerabilities from trusting pickle files, or CPU inference that runs several times slower than it needs to (about 5x in our ResNet-50 benchmark) because you never tried ONNX Runtime.

We've covered the complete picture. Pickle is convenient but security-risky and environment-coupled; use joblib for sklearn (it still pickles under the hood, so load only files you trust) and pin your versions obsessively. PyTorch's state_dict is the professional standard for training checkpoints: separate your architecture from your weights and you'll thank yourself in six months. TorchScript compiles models to a Python-free representation, useful for C++ deployment and for removing Python interpreter overhead. ONNX is the format that earns its place in most production pipelines: train once, deploy almost anywhere, and, in our ResNet-50 benchmark, get roughly 5x faster CPU inference with ONNX Runtime.

The benchmark numbers are striking: on CPU, ONNX Runtime ran ResNet-50 inference in 23 ms versus PyTorch's 125 ms at batch size 1. At scale, a gap like that can be the difference between a model that's economically viable and one that isn't. On GPU, pick whatever's most convenient; the runtimes converge when compute dominates. The compression techniques (quantization for roughly 74% size reduction when all weights are stored as int8, pruning for further gains once the zeros are compressed) compound the savings when you're operating under edge or mobile constraints.

Your next model isn't done when it validates well on your test set. It's done when you can serialize it with checksums, load it with version verification, verify output parity across formats, and serve it at scale with confidence. The ModelRegistry pattern gives you all of that in one place. Build it once, use it for every model you ship.

Many production incidents involving a model that produces wrong predictions, or a deployment that buckles under load because it runs too slowly on CPU, trace back to one of the mistakes we covered here. Now you know them. Ship accordingly.
