Building Neural Networks with PyTorch nn.Module

Neural networks are the engine behind virtually every modern AI breakthrough, from the language models that write code to the vision systems that identify tumors in medical scans. But understanding what they actually are, mechanically, changes how you build them. At its core, a neural network is a parameterized function: it takes an input, multiplies it through a series of learned weight matrices, applies non-linear transformations, and produces an output. The magic is in those learned weights, adjusted over many rounds of gradient updates until the function reliably maps inputs to correct outputs.
PyTorch's nn.Module system is how you express that parameterized function in code. It gives you a disciplined way to define architecture, register learnable parameters, compose layers, manage devices, and switch between training and evaluation modes. Without it, you'd be manually tracking thousands of weight tensors, remembering which ones need gradients, and writing brittle serialization code. With it, those concerns vanish into a clean, consistent API.
What makes nn.Module particularly powerful is its composability. Every layer is a module, every block is a module, and every full network is a module. You can nest them arbitrarily deep, and PyTorch traverses the tree automatically to collect parameters, move tensors to devices, and propagate gradients. This recursive structure is why you can build a ResidualBlock, wrap twelve of them inside a ResNet, and call .parameters() on the outer network to get every single weight across the entire hierarchy.
Before we dive into code, it's worth appreciating what problem nn.Module is actually solving. When you have a network with fifty layers and three million parameters, manual bookkeeping becomes impossible. You need a system that tracks parameters automatically, handles the training/evaluation mode distinction, and provides introspection tools. nn.Module is that system. By the end of this article, you'll understand how to construct networks from simple building blocks, initialize weights properly, inspect your models, and save and load them like a pro. We'll build three architectures of increasing complexity to show you exactly how this scales from toy networks to production models.
Table of Contents
- Understanding nn.Module: The Foundation
- nn.Module Architecture
- The Three Essential Methods
- __init__: Define Your Layers
- forward: Compute the Output
- parameters() and named_modules(): Introspection
- Layer Types and When to Use Them
- Common Layers You'll Actually Use
- Linear (Fully Connected)
- Activation Functions
- BatchNorm1d (Batch Normalization)
- Dropout
- Building Networks: Three Architectures
- Architecture 1: Shallow MLP (Multi-Layer Perceptron)
- Architecture 2: Deep MLP with Regularization
- Architecture 3: Residual Block Network
- Forward Pass Design
- Weight Initialization: It Matters More Than You Think
- Xavier (Glorot) Initialization
- Kaiming (He) Initialization
- Orthogonal Initialization
- Counting Parameters and Inspecting Models
- Common PyTorch Mistakes
- Saving and Loading Models: state_dict() vs Full Model
- Option 1: state_dict() (Recommended)
- Option 2: Full Model (Use with Caution)
- Moving Models to GPU
- DataParallel: Multi-GPU Training (Optional)
- Putting It All Together: A Complete Training Example
- Transitioning Between Training and Evaluation Modes
- Common Pitfalls and How to Avoid Them
- Pitfall 1: Forgetting to Call super().__init__()
- Pitfall 2: Using Python Lists Instead of nn.ModuleList
- Pitfall 3: Activations in the Wrong Place
- Pitfall 4: Not Setting model.eval() Before Evaluation
- Advanced Pattern: Conditional Layers
- Inspecting Gradients
- Key Takeaways
Understanding nn.Module: The Foundation
At its heart, nn.Module is a Python class that wraps your model. Every neural network you'll build is a subclass of nn.Module. Think of it as a blueprint for organizing layers, parameters, and the forward computation.
The pattern is always the same: subclass nn.Module, define your layers in __init__, and describe how data flows through them in forward. This two-method structure separates what your network contains from what it does, a clean design that makes networks readable, debuggable, and reusable. Here is the minimal working example you will build every architecture from:
```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

model = SimpleNet()
print(model)
```

Output:

```
SimpleNet(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
```
The print(model) output alone tells you something important: PyTorch knows exactly what layers your model contains and their configurations, without you having to manually list them. That awareness comes from the automatic registration that happens the moment you assign a layer to self.
Why does this structure matter? When you assign a layer like self.fc1 = nn.Linear(784, 128), PyTorch automatically registers that layer. This means:
- Parameters are tracked: PyTorch knows about all the weights and biases in your network
- Gradients reach the optimizer: when you call .backward(), gradients are computed for every registered parameter, and model.parameters() hands all of them to your optimizer
- Device management is easy: call .to('cuda') and every parameter (and buffer) moves to the GPU
- Composition works: you can nest modules inside modules
Let me show you why that last point matters. Imagine trying to track 500 different weight matrices manually. With nn.Module, you just nest one module inside another, and PyTorch handles the bookkeeping.
nn.Module Architecture
Understanding the internal mechanics of nn.Module makes you a better architect. When you subclass nn.Module and call super().__init__(), PyTorch initializes several internal dictionaries under the hood, including _parameters, _modules, _buffers, and a family of hook dictionaries such as _forward_hooks and _forward_pre_hooks. These are what power the automatic registration system.
When you write self.fc1 = nn.Linear(784, 128), Python's __setattr__ is intercepted by nn.Module. It checks whether the value being assigned is an nn.Parameter, an nn.Module, or a regular Python object, and routes it to the appropriate internal dictionary. Parameters end up in _parameters, submodules end up in _modules, and everything else flows through normally. This interception is the foundation of everything that follows: it is how PyTorch knows about your layers without requiring you to manually register them.
The _buffers dictionary holds tensors that need to move with the model to different devices but are not learned parameters. Batch normalization's running mean and running variance are the classic example: they track statistics across training batches but are never updated by the optimizer. When you call .to('cuda'), both parameters and buffers move together, ensuring your model is fully consistent on the target device.
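To see that routing in action, here is a minimal sketch with a hypothetical Scaler module that holds one of each kind of attribute. Peeking at _parameters, _modules, and _buffers is fine for learning, but they are private internals; in real code, prefer named_parameters(), named_modules(), and named_buffers(). The example converts the dtype rather than the device so it runs on a CPU-only machine; .to('cuda') moves buffers the same way.

```python
import torch
import torch.nn as nn

class Scaler(nn.Module):
    """Hypothetical module with one parameter, one submodule, one buffer, one plain attribute."""
    def __init__(self, num_features):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(num_features))        # -> _parameters
        self.proj = nn.Linear(num_features, num_features)           # -> _modules
        self.register_buffer("offset", torch.zeros(num_features))   # -> _buffers (never learned)
        self.note = "plain attribute"                               # ordinary instance attribute

    def forward(self, x):
        return self.proj(x * self.scale + self.offset)

m = Scaler(4)
print(list(m._parameters))  # ['scale']
print(list(m._modules))     # ['proj']
print(list(m._buffers))     # ['offset']
print([name for name, _ in m.named_parameters()])  # ['scale', 'proj.weight', 'proj.bias']
print(list(m.state_dict()))  # ['scale', 'offset', 'proj.weight', 'proj.bias']

# Buffers travel with the model: .to() converts them together with the parameters
m = m.to(torch.float64)
print(m.offset.dtype, m.proj.weight.dtype)  # torch.float64 torch.float64
```

Note that the buffer shows up in state_dict() (so it is saved and loaded) but not in named_parameters() (so the optimizer never touches it).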
Composability is the architectural payoff. Because every layer is itself an nn.Module, nesting is natural and recursive. A ResidualBlock can contain Linear layers, BatchNorm1d layers, and Dropout layers. A ResidualNetwork can contain a ModuleList of ResidualBlock instances. When PyTorch traverses the parameter tree, it walks this hierarchy depth-first, collecting every registered parameter from every nested module. This means your optimizer, your device transfers, and your state dict all work transparently across arbitrarily complex architectures, without any extra bookkeeping on your part.
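A small sketch makes the recursive traversal concrete. The Block and Net classes below are hypothetical, chosen only to create two levels of nesting; named_modules() walks the tree in the order described above, and named_parameters() reports every weight with a dotted path that mirrors that tree:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        return torch.relu(self.bn(self.fc(x)))

class Net(nn.Module):
    def __init__(self, dim=8, depth=2):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.head = nn.Linear(dim, 1)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.head(x)

net = Net()
print([name for name, _ in net.named_modules()])
# ['', 'blocks', 'blocks.0', 'blocks.0.fc', 'blocks.0.bn', 'blocks.1', 'blocks.1.fc', 'blocks.1.bn', 'head']
print([name for name, _ in net.named_parameters()])
# ['blocks.0.fc.weight', 'blocks.0.fc.bias', 'blocks.0.bn.weight', 'blocks.0.bn.bias',
#  'blocks.1.fc.weight', 'blocks.1.fc.bias', 'blocks.1.bn.weight', 'blocks.1.bn.bias',
#  'head.weight', 'head.bias']

# One call hands every nested parameter to the optimizer
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
```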
The Three Essential Methods
Every nn.Module has three critical methods you'll use constantly:
__init__: Define Your Layers
This is where you instantiate all the layers your network will use. Every sublayer becomes an attribute of self. The key insight is that __init__ is purely about declaration: you are describing the components of your network, not running them. Think of it as laying out your tools before starting work:
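```python
# Inside your nn.Module subclass
def __init__(self):
    super().__init__()
    self.layer1 = nn.Linear(10, 64)
    self.layer2 = nn.Linear(64, 32)
    self.layer3 = nn.Linear(32, 1)
```

The super().__init__() call is crucial: it sets up all the internal machinery that makes parameter tracking work. Without it, the first layer assignment fails with "AttributeError: cannot assign module before Module.__init__() call".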
forward: Compute the Output
The forward() method defines how data flows through your network. This is where you use the layers you defined. Notice that forward receives data and returns transformed data. As it runs, PyTorch's autograd engine records every operation, so it can later differentiate the whole computation automatically:
```python
# Also inside the same nn.Module subclass
def forward(self, x):
    x = self.layer1(x)
    x = torch.relu(x)
    x = self.layer2(x)
    x = torch.relu(x)
    x = self.layer3(x)
    return x
```

When you call model(x), PyTorch internally calls forward(x). You never call forward() directly; you call the model like a function.
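The reason you call model(x) rather than model.forward(x) is that nn.Module.__call__ does more than just run forward: it also runs any registered hooks (forward pre-hooks, forward hooks, and backward hooks), which tools such as profilers and model-summary libraries depend on. Autograd records the computation either way, but calling forward() directly silently skips those hooks. Always use model(x):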
```python
# model is an instance of the three-layer class sketched above
x = torch.randn(32, 10)  # batch of 32 samples, 10 features each
output = model(x)        # This calls forward() automatically (via __call__)
print(output.shape)      # torch.Size([32, 1])
```

parameters() and named_modules(): Introspection
PyTorch provides a few key methods to explore your network. These introspection tools become invaluable when debugging: you can verify that your architecture registered the layers you intended, and inspect the shapes of all parameters before committing to a training run:
```python
# Get all parameters (weights and biases)
for param in model.parameters():
    print(param.shape)
# torch.Size([64, 10])  <- layer1 weight
# torch.Size([64])      <- layer1 bias
# torch.Size([32, 64])  <- layer2 weight
# torch.Size([32])      <- layer2 bias
# torch.Size([1, 32])   <- layer3 weight
# torch.Size([1])       <- layer3 bias

# Get named parameters (useful for debugging)
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")
# layer1.weight: torch.Size([64, 10])
# layer1.bias: torch.Size([64])
# layer2.weight: torch.Size([32, 64])
# ...

# Get all submodules (the first entry, with an empty name, is the model itself)
for name, module in model.named_modules():
    print(f"{name}: {module}")
```

These methods are lifesavers when you're debugging or trying to understand what's happening inside your network.
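Hooks are the other introspection tool, and they are the concrete reason to prefer model(x) over model.forward(x). As a small sketch (a bare nn.Linear stands in for any model), the code below registers a forward hook with register_forward_hook and shows that it fires only when the module is invoked through __call__, while autograd records the computation either way:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)   # stand-in for any nn.Module
seen = []

# A forward hook is called as hook(module, inputs, output) after forward() returns
handle = model.register_forward_hook(
    lambda module, inputs, output: seen.append(tuple(output.shape)))

x = torch.randn(32, 10)
y1 = model(x)            # through __call__: the hook fires
y2 = model.forward(x)    # direct call: the hook is skipped
print(seen)              # [(32, 1)]
print(y2.requires_grad)  # True: the autograd graph was recorded anyway

handle.remove()          # detach the hook once you no longer need it
```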
Layer Types and When to Use Them
Knowing which layer to reach for, and why, separates engineers who understand networks from those who cargo-cult architectures they found online. The choice of layer type is an architectural decision that reflects assumptions about your data's structure.
nn.Linear is the universal approximator building block for tabular and flat data. It learns a full affine transformation: every output neuron connects to every input neuron with a learned weight. Use it when your features don't have spatial or sequential structure, when the relationship between feature at position 5 and feature at position 95 is potentially as important as neighboring features. Most classification heads and regression outputs end with a Linear layer regardless of what came before it.
nn.Conv2d is the workhorse for image data, and its design reflects a key assumption: nearby pixels are more related than distant pixels, and the same patterns (edges, textures, shapes) appear at different locations in the image. Convolutional layers exploit this by learning small, reusable filters that slide across the input. The result is dramatically fewer parameters than equivalent Linear layers, and much better generalization on spatial data. If your input is images, audio spectrograms, or any data with local spatial structure, reach for convolutions.
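The parameter savings are easy to check. This sketch compares a 3x3 convolution with a Linear layer that produces the same number of outputs from the flattened image; the 16x16 image size and 16 output channels are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

def num_params(module):
    return sum(p.numel() for p in module.parameters())

# Map a 3-channel 16x16 image to 16 feature maps of the same spatial size
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
# A Linear layer producing the same number of outputs from the flattened image
dense = nn.Linear(3 * 16 * 16, 16 * 16 * 16)

x = torch.randn(1, 3, 16, 16)
print(conv(x).shape)              # torch.Size([1, 16, 16, 16])
print(dense(x.flatten(1)).shape)  # torch.Size([1, 4096])
print(num_params(conv))   # 448: 16 filters x (3*3*3) weights + 16 biases
print(num_params(dense))  # 3149824: 768 x 4096 weights + 4096 biases
```

The convolution reuses the same 448 numbers at every position in the image, which is exactly the "same pattern appears in different places" assumption described above.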
nn.LSTM and nn.GRU handle sequential data where order matters and context from earlier steps influences later ones. Use them for time series, natural language, or any problem where the network needs memory of past inputs. They are more involved to work with (you deal with hidden states and with the ordering of the sequence and batch dimensions yourself), but they remain a reasonable choice for sequence tasks where the compute cost of a transformer isn't justified.
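Here is a minimal sketch of what "managing hidden states" looks like with nn.LSTM; the sizes (8 input features, 32 hidden units, 2 layers) are arbitrary. Note that batch_first=True changes the layout of the input and output, but not of the returned hidden and cell states:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=32, num_layers=2, batch_first=True)
x = torch.randn(4, 15, 8)  # (batch, seq_len, features) because batch_first=True

# With no initial state passed, h0 and c0 default to zeros
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([4, 15, 32]): top layer's hidden state at every step
print(h_n.shape)     # torch.Size([2, 4, 32]): final hidden state, one per layer
print(c_n.shape)     # torch.Size([2, 4, 32]): final cell state, one per layer

# Carry the state into the next chunk of the same sequences (e.g. streaming input)
next_chunk = torch.randn(4, 5, 8)
output2, (h_n, c_n) = lstm(next_chunk, (h_n, c_n))
print(output2.shape)  # torch.Size([4, 5, 32])
```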
Normalization and regularization layers like nn.BatchNorm1d and nn.Dropout are not really "layer types" in the architectural sense; they are training tools that you insert into existing architectures. Dropout prevents co-adaptation between neurons by randomly silencing them during training. Batch normalization stabilizes the distribution of activations across layers, allowing you to use higher learning rates and reducing sensitivity to initialization. Both are sensible defaults for deep MLPs like the ones in this article, but not universal ones: batch norm's statistics become noisy with very small batches, and very small networks may not benefit enough to justify them.
Common Layers You'll Actually Use
Let's be practical. Here are the layers that show up in most of the networks you'll build:
Linear (Fully Connected)
The workhorse of dense networks. The math inside is simple, a matrix multiplication plus an optional bias vector, and that generality makes Linear layers the default starting point for any new problem until you have evidence that specialized structure will help:
```python
layer = nn.Linear(in_features=784, out_features=128, bias=True)
# Input: (batch_size, 784)
# Output: (batch_size, 128)
# Parameters: 784*128 + 128 = 100,480
```

Activation Functions
These introduce non-linearity. No non-linearity = your network is just a series of matrix multiplications (useless). The reason non-linearity matters is mathematical: a composition of linear functions is itself linear. Without activation functions, stacking one hundred layers is equivalent to a single matrix multiply. Activations break that linearity, giving your network the capacity to learn curved decision boundaries:
```python
# ReLU (Rectified Linear Unit) - most common
relu = nn.ReLU()
output = relu(torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]))
# tensor([0., 0., 0., 1., 2.])

# Sigmoid - output between 0 and 1
sigmoid = nn.Sigmoid()
output = sigmoid(torch.tensor([0.0]))
# tensor([0.5000])

# Tanh - output between -1 and 1
tanh = nn.Tanh()
output = tanh(torch.tensor([0.0]))
# tensor([0.])
```

ReLU dominates modern networks because it's simple, fast, and largely avoids the vanishing gradient problem: its gradient is exactly 1 for every positive input, whereas sigmoid and tanh saturate and their gradients shrink toward zero for large inputs. Sigmoid and Tanh are still useful in specific contexts (like LSTM gates), but ReLU is your default.
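The "stacked linear layers collapse into one" argument can be verified numerically. This sketch folds three activation-free Linear layers (sizes chosen arbitrarily) into a single weight matrix A and bias c, then checks that the collapsed map gives the same outputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Three Linear layers with no activation functions in between
stack = nn.Sequential(nn.Linear(5, 16), nn.Linear(16, 16), nn.Linear(16, 3))

with torch.no_grad():
    # Fold the layers into a single affine map y = x @ A.T + c
    A = torch.eye(5)
    c = torch.zeros(5)
    for layer in stack:
        A = layer.weight @ A               # compose the weight matrices
        c = layer.weight @ c + layer.bias  # push the accumulated bias through the layer

    x = torch.randn(8, 5)
    collapsed_matches = torch.allclose(stack(x), x @ A.T + c, atol=1e-5)

print(A.shape, c.shape)   # torch.Size([3, 5]) torch.Size([3])
print(collapsed_matches)  # True: three layers behaved exactly like one
```

Insert an nn.ReLU() between the layers and no single (A, c) pair can reproduce the network anymore; that is the extra capacity activations buy you.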
BatchNorm1d (Batch Normalization)
Stabilizes training by normalizing activations. The original motivation for batch norm is that if the distribution of activations shifts dramatically from layer to layer during training (what the original paper called internal covariate shift), the later layers are constantly chasing a moving target. Batch norm standardizes each feature to roughly zero mean and unit variance over the batch, then applies a learnable scale and shift, letting deeper layers train more stably:
```python
# For fully-connected layers
bn = nn.BatchNorm1d(num_features=128)
# Input: (batch_size, 128)
# Output: (batch_size, 128) - normalized

# Typical usage inside forward(), with self.fc1 and self.bn1 defined in __init__
x = self.fc1(x)
x = self.bn1(x)
x = torch.relu(x)
```

Batch norm has learnable parameters (a per-feature scale and shift), so it's not just a preprocessing step: it is part of what your network learns.
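Batch norm's two modes are worth seeing side by side. In this sketch (feature scale and batch size are arbitrary), training mode normalizes with the current batch's statistics and nudges the running estimates toward them by momentum=0.1, while eval mode normalizes with those running estimates instead:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)            # momentum=0.1 by default
x = torch.randn(64, 3) * 5 + 10   # 3 features with mean ~10 and std ~5

bn.train()
y_train = bn(x)
# Training mode normalizes with this batch's own mean and variance
print(y_train.mean(dim=0))        # ~0 for every feature
print(bn.running_mean)            # moved 10% of the way from 0 toward the batch mean, so ~1.0

bn.eval()
y_eval = bn(x)
# Eval mode normalizes with the running estimates instead; after a single
# training batch they are still far from the data's real statistics
print(y_eval.mean(dim=0))         # far from 0
```

After enough training batches the running estimates converge toward the data's statistics, and eval-mode outputs become properly normalized without depending on which samples share a batch.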
Dropout
Prevents overfitting by randomly zeroing activations during training. The underlying idea is ensemble learning: by randomly disabling neurons during training, you are effectively training a different network on each batch. At inference time, the full network approximates the average of all those subnetworks, which tends to generalize better than any single one:
```python
dropout = nn.Dropout(p=0.5)  # Each activation is dropped with probability 0.5
# During training (model.train()): zeros each element with probability p
#   and scales the survivors by 1 / (1 - p), keeping the expected value unchanged
# During evaluation (model.eval()): passes inputs through unchanged

# Typical usage inside forward()
x = torch.relu(self.fc1(x))
x = self.dropout(x)
```

The key point: dropout is only active in training mode. When you call model.eval(), dropout turns off. Batch norm changes behavior too: in eval mode it uses its running statistics instead of the current batch's statistics.
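The rescaling of surviving activations is easy to miss but easy to observe. Feeding a tensor of ones through nn.Dropout(p=0.5) in this sketch shows that survivors become 2.0, so the average stays close to 1.0, and that eval mode is an exact identity:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(10_000)

dropout.train()
y = dropout(x)
print(y.unique())               # tensor([0., 2.]): survivors scaled by 1 / (1 - 0.5)
print((y == 0).float().mean())  # close to 0.5, but random rather than exact
print(y.mean())                 # close to 1.0: the expected value is preserved

dropout.eval()
print(torch.equal(dropout(x), x))  # True: identity at inference
```

Because the scaling happens during training, nothing needs to be rescaled at inference time.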
Building Networks: Three Architectures
Let's move from theory to practice with three increasingly complex networks. Each teaches you something important about how nn.Module scales and how architectural decisions compound as your networks grow deeper.
Architecture 1: Shallow MLP (Multi-Layer Perceptron)
Simple, fast, good for learning the basics. This two-layer network is the "hello world" of deep learning, simple enough that you can trace the exact computation by hand, but already illustrating the core nn.Module pattern of declaring layers in __init__ and connecting them in forward:
```python
class ShallowMLP(nn.Module):
    def __init__(self, input_size=784, hidden_size=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

# Instantiate and test
model = ShallowMLP()
x = torch.randn(32, 784)  # stand-in for a batch of 32 flattened 28x28 MNIST images
output = model(x)
print(output.shape)  # torch.Size([32, 10])
```

This is a bare-bones network: two Linear layers with a ReLU in between, and no batch norm or dropout. It'll work for simple datasets but won't scale well.
Architecture 2: Deep MLP with Regularization
Add depth, batch norm, and dropout for better generalization. The key design insight here is using nn.Sequential to build the hidden block pattern programmatically: rather than hard-coding three nearly-identical layer groups, we generate them in a loop and let Python do the repetitive work. This is an important habit as networks grow: write code that describes the pattern, not the individual instances:
```python
class DeepMLP(nn.Module):
    # A tuple default avoids Python's shared-mutable-default-argument pitfall
    def __init__(self, input_size=784, hidden_sizes=(512, 256, 128),
                 num_classes=10, dropout_rate=0.2):
        super().__init__()
        layers = []
        prev_size = input_size
        # Hidden layers: Linear -> BatchNorm -> ReLU -> Dropout
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            prev_size = hidden_size
        # Output layer (no activation, no batch norm): returns raw logits
        layers.append(nn.Linear(prev_size, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# Test it
model = DeepMLP()
x = torch.randn(32, 784)
output = model(x)
print(output.shape)  # torch.Size([32, 10])
print(model)
```

Notice we used nn.Sequential: it's a convenient container that chains layers together automatically. The forward pass just pipes data through each layer in order.
Let's look at what we built. The output of print(model) is worth studying carefully: it confirms that every layer was registered, shows the configuration of each one, and makes the BatchNorm and Dropout layers visible so you know the network has exactly the regularization structure you intended:
```
DeepMLP(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.2, inplace=False)
    (8): Linear(in_features=256, out_features=128, bias=True)
    (9): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.2, inplace=False)
    (12): Linear(in_features=128, out_features=10, bias=True)
  )
)
```
Each hidden layer gets batch norm and dropout. The output layer doesn't. Why? Because it should emit raw logits: you typically let the loss function (such as CrossEntropyLoss) handle softmax, rather than doing it in the model itself. Batch norm and dropout belong in the hidden layers.
Architecture 3: Residual Block Network
Now let's get fancy. Residual connections (skip connections) let you stack many layers without gradient degradation. The residual idea, introduced in the original ResNet paper, is elegant: instead of asking each block to learn the full desired transformation H(x), we ask it to learn only the residual F(x) = H(x) - x. The skip connection adds back the original input so the effective output is F(x) + x. This small change makes training dramatically more stable at depth: gradients flow directly back through the skip connection, bypassing the transformation entirely if needed:
```python
class ResidualBlock(nn.Module):
    def __init__(self, in_features, hidden_features, dropout_rate=0.2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.bn2 = nn.BatchNorm1d(in_features)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        identity = x  # Save the input for the skip connection
        out = self.fc1(x)
        out = self.bn1(out)
        out = torch.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)  # Maps back to in_features so shapes match for the addition
        out = self.bn2(out)
        out += identity      # Add the skip connection
        out = torch.relu(out)
        return out


class ResidualNetwork(nn.Module):
    def __init__(self, input_size=784, num_blocks=4, block_features=256,
                 num_classes=10, dropout_rate=0.2):
        super().__init__()
        self.input_layer = nn.Linear(input_size, block_features)
        self.bn_input = nn.BatchNorm1d(block_features)
        # Stack residual blocks
        self.blocks = nn.ModuleList([
            ResidualBlock(block_features, block_features // 2, dropout_rate)
            for _ in range(num_blocks)
        ])
        self.output_layer = nn.Linear(block_features, num_classes)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.bn_input(x)
        x = torch.relu(x)
        for block in self.blocks:
            x = block(x)
        x = self.output_layer(x)
        return x

# Test it
model = ResidualNetwork()
x = torch.randn(32, 784)
output = model(x)
print(output.shape)  # torch.Size([32, 10])
```

The skip connection (out += identity) is the magic. It gives gradients a direct path back through the addition, so they don't have to pass through every transformation on the way and are far less likely to vanish in deep networks. That's why you can stack many residual blocks and still train them.
Notice we used nn.ModuleList, it's like a Python list, but PyTorch registers all the modules inside it for parameter tracking. Never use a regular Python list for modules; PyTorch won't see the parameters.
Forward Pass Design
The forward method is where you make architectural decisions that fundamentally determine what your network can learn. It is not just plumbing: it defines the computation graph that autograd will differentiate, and its structure determines how gradients flow backward through your model.
The simplest forward pass is a linear chain: each layer feeds into the next, and data flows in one direction. This is what nn.Sequential gives you, and it is perfectly sufficient for most feedforward architectures. But the moment you need branching, when different parts of your input need separate processing before being combined, or when you want skip connections, you need a custom forward method. The residual block above is the canonical example: identity = x branches the computation, and out += identity merges it back.
A subtlety that catches beginners: the order of operations in forward matters for how batch normalization and dropout interact. The conventional order is Linear -> BatchNorm -> Activation -> Dropout. The intuition is that batch norm works best on the raw linear output before non-linearity squashes the distribution, and dropout should come after activation so you are zeroing post-activation values rather than pre-activation ones. Swapping these orders does not break training, but the conventional order tends to perform better in practice.
Dynamic forward passes, where the computation graph changes based on input, are fully supported in PyTorch because autograd traces the actual execution, not a pre-compiled static graph. This means you can write if statements in forward that branch based on runtime values, loop over variable-length sequences, and even use different layers depending on training versus evaluation mode. This flexibility is one of the reasons PyTorch won significant adoption from researchers: your Python code is your model, with no translation layer between you and the underlying computation.
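Here is a minimal sketch of that flexibility, using a hypothetical RepeatedBlock whose depth is chosen per call and whose behavior depends on training versus evaluation mode. Plain Python control flow decides which operations run, and autograd differentiates whatever actually ran:

```python
import torch
import torch.nn as nn

class RepeatedBlock(nn.Module):
    """Hypothetical module whose graph depends on a runtime argument and on the mode."""
    def __init__(self, dim):
        super().__init__()
        self.layer = nn.Linear(dim, dim)

    def forward(self, x, num_steps):
        for _ in range(num_steps):              # ordinary loop: depth chosen per call
            x = torch.tanh(self.layer(x))
        if self.training:                       # ordinary branch on the module's mode
            x = x + 0.01 * torch.randn_like(x)  # e.g. small noise only during training
        return x

block = RepeatedBlock(4)
x = torch.randn(2, 4)
for steps in (1, 3):
    block.zero_grad()
    block(x, steps).sum().backward()  # autograd follows the graph that was actually built
    print(steps, block.layer.weight.grad.norm().item())

block.eval()
print(torch.equal(block(x, 2), block(x, 2)))  # True: no noise in eval mode
```

With num_steps=3 the shared layer is used three times, and its gradient accumulates contributions from all three uses automatically.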
When designing forward, think about what your network needs to learn and whether the data flow supports it. If you have multiple input modalities, say, tabular features and text embeddings, process them through separate sub-networks and concatenate before the final classification head. If your problem has structure (spatial, sequential, hierarchical), let that structure inform the topology of your forward pass. The code is flexible; use it to encode your domain knowledge about the problem.
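As a sketch of the multi-modal case, here is a hypothetical TwoInputClassifier that takes tabular features and a precomputed text embedding (all sizes are illustrative assumptions), processes each through its own branch, and concatenates the results before the classification head:

```python
import torch
import torch.nn as nn

class TwoInputClassifier(nn.Module):
    """Hypothetical two-branch model: tabular features plus a precomputed text embedding."""
    def __init__(self, num_tabular=20, embed_dim=64, hidden=32, num_classes=3):
        super().__init__()
        self.tabular_branch = nn.Sequential(nn.Linear(num_tabular, hidden), nn.ReLU())
        self.text_branch = nn.Sequential(nn.Linear(embed_dim, hidden), nn.ReLU())
        self.head = nn.Linear(2 * hidden, num_classes)

    def forward(self, tabular, text_embedding):
        a = self.tabular_branch(tabular)
        b = self.text_branch(text_embedding)
        combined = torch.cat([a, b], dim=1)  # (batch, 2 * hidden)
        return self.head(combined)           # raw logits

model = TwoInputClassifier()
logits = model(torch.randn(8, 20), torch.randn(8, 64))
print(logits.shape)  # torch.Size([8, 3])
```

Because forward takes two arguments, this topology cannot be expressed as a single nn.Sequential; a custom forward is the natural place for it.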
Weight Initialization: It Matters More Than You Think
How you randomize the initial weights matters a lot for training speed and final performance. Poor initialization can make activations saturate or explode before training even gets going. The goal is to ensure that at the start of training, activations throughout the network have reasonable magnitudes: neither so large that they saturate non-linearities, nor so small that gradients vanish before reaching early layers.
Xavier (Glorot) Initialization
Good for networks with sigmoid/tanh activations. Xavier initialization was derived mathematically to keep the variance of activations roughly constant from layer to layer when using symmetric activations like tanh, the key insight being that variance should be preserved in both the forward and backward passes:
```python
def init_xavier(module):
    # model.apply() calls this once for every submodule; only Linear layers are re-initialized
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = DeepMLP()
model.apply(init_xavier)
```

Xavier keeps activations in a healthy range by scaling the weight variance to each layer's number of inputs (fan_in) and outputs (fan_out):

Var(W) = 2 / (fan_in + fan_out)
Kaiming (He) Initialization
Better for networks with ReLU activations (most modern networks). Kaiming initialization accounts for the fact that ReLU is not symmetric: it zeroes every negative input, which for zero-centered pre-activations means roughly half of all activations. Xavier's derivation assumed symmetric activations, so it under-scales weights for ReLU networks. Kaiming corrects this:
```python
def init_kaiming(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = DeepMLP()
model.apply(init_kaiming)
```

Concretely, in its default fan_in mode Kaiming initialization targets Var(W) = 2 / fan_in: twice the 1 / fan_in that would preserve activation variance through a purely linear layer, with the extra factor of 2 compensating for the activations ReLU zeroes out.
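Both variance targets can be checked empirically. This sketch initializes one weight matrix (the 512-in, 256-out shape is arbitrary) with each scheme and compares the sample variance with the formulas above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in, fan_out = 512, 256
w = torch.empty(fan_out, fan_in)  # nn.Linear stores weights as (out_features, in_features)

nn.init.xavier_uniform_(w)
xavier_var = w.var().item()
print(xavier_var, 2 / (fan_in + fan_out))  # both about 0.0026

nn.init.kaiming_uniform_(w, nonlinearity='relu')  # default mode='fan_in'
kaiming_var = w.var().item()
print(kaiming_var, 2 / fan_in)  # both about 0.0039
```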
Orthogonal Initialization
Useful for RNNs and other recurrent architectures. Orthogonal matrices preserve the norm of vectors they multiply, which prevents the repeated application of the same weight matrix (as happens in RNNs) from causing exponential growth or decay in hidden states:

```python
# module is a layer passed in by model.apply(), as in the init functions above
nn.init.orthogonal_(module.weight, gain=1.0)
```

For most feedforward networks, just use Kaiming. It's the right default.
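A quick sketch confirms the norm-preservation property and shows one common way to apply it to a recurrent layer: initializing only the hidden-to-hidden weights (the ones named weight_hh_* in PyTorch's RNN modules) orthogonally. Sizes are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
W = torch.empty(64, 64)
nn.init.orthogonal_(W)
print(torch.allclose(W.T @ W, torch.eye(64), atol=1e-4))  # True: W is orthogonal

# Apply W 100 times, as an RNN does with its hidden-to-hidden matrix
h = torch.randn(64)
start = h.norm()
for _ in range(100):
    h = W @ h
ratio = (h.norm() / start).item()
print(ratio)  # about 1.0: the norm neither explodes nor vanishes

# Typical RNN usage: orthogonal init for the recurrent weights only
rnn = nn.RNN(input_size=16, hidden_size=64)
for name, param in rnn.named_parameters():
    if name.startswith("weight_hh"):
        nn.init.orthogonal_(param)
```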
Counting Parameters and Inspecting Models
You'll often need to know how many parameters your model has. More parameters means more capacity, but also more overfitting risk. The count is also a practical constraint: a model with 10 billion parameters needs about 40 GB just to store its weights as 32-bit floats, more than a typical consumer GPU holds, before you even account for gradients and optimizer state. Getting a handle on parameter count early in architectural design prevents you from building something impractical.
```python
def count_parameters(model):
    # Count only trainable parameters (frozen ones have requires_grad=False)
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model = ResidualNetwork()
print(f"Total parameters: {count_parameters(model):,}")
# Total parameters: 470,794
```

For a nicer summary, use the torchsummary library (install it with pip install torchsummary). Its output goes further than a raw parameter count: it shows the output shape at every layer, which is invaluable for catching shape mismatches before you commit to a training run:
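```python
from torchsummary import summary

model = ResidualNetwork()
# input_size excludes the batch dimension; device="cpu" because the model lives on the CPU
summary(model, input_size=(784,), device="cpu")
```

Output (abridged: rows 9 to 26 repeat the same six-row pattern for the other three residual blocks, and the memory estimates printed at the end are omitted):

```
================================================================
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                  [-1, 256]         200,960
       BatchNorm1d-2                  [-1, 256]             512
            Linear-3                  [-1, 128]          32,896
       BatchNorm1d-4                  [-1, 128]             256
           Dropout-5                  [-1, 128]               0
            Linear-6                  [-1, 256]          33,024
       BatchNorm1d-7                  [-1, 256]             512
     ResidualBlock-8                  [-1, 256]               0
                 ...
           Linear-27                   [-1, 10]           2,570
================================================================
Total params: 470,794
Trainable params: 470,794
Non-trainable params: 0
================================================================
```

torchsummary lists each leaf layer inside every residual block; the ResidualBlock rows themselves show 0 because their parameters are already counted in their children's rows.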
This shows you exactly what's happening layer by layer. It's invaluable for debugging and understanding your network's capacity.
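If you only need the numbers, a few lines of plain PyTorch give a per-layer breakdown without any extra library. The parameter_breakdown helper below is hypothetical, and the small nn.Sequential is a stand-in whose first layers mirror ResidualNetwork's input stage, so the counts can be compared with the rows above:

```python
import torch.nn as nn

def parameter_breakdown(model):
    """Hypothetical helper: trainable parameters per top-level child module."""
    return {name: sum(p.numel() for p in child.parameters() if p.requires_grad)
            for name, child in model.named_children()}

model = nn.Sequential(
    nn.Linear(784, 256),  # 784*256 weights + 256 biases
    nn.BatchNorm1d(256),  # 256 scales + 256 shifts
    nn.ReLU(),            # no parameters
    nn.Linear(256, 10),   # 256*10 weights + 10 biases
)
before = parameter_breakdown(model)
print(before)  # {'0': 200960, '1': 512, '2': 0, '3': 2570}

# Freezing a layer removes it from the trainable count
for p in model[0].parameters():
    p.requires_grad = False
after = parameter_breakdown(model)
print(after)   # {'0': 0, '1': 512, '2': 0, '3': 2570}
```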
Common PyTorch Mistakes
Everyone makes these mistakes. Knowing them in advance saves you hours of debugging sessions that end with a facepalm.
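The most insidious mistake is forgetting model.eval() before evaluation. If you run inference with the model still in training mode, batch normalization normalizes with statistics computed from the current batch instead of the running statistics it accumulated during training, so each prediction depends on which other samples happen to share its batch. With small batches those statistics are mostly noise (a batch of a single sample makes BatchNorm1d raise an error outright). Every training-mode forward pass also updates the running statistics, quietly contaminating them with evaluation data, and dropout keeps zeroing activations. Your loss looks fine during the training loop but mysteriously degrades when you evaluate on the test set, and because the code usually runs without errors, the bug is easy to miss. Always wrap evaluation in model.eval() and torch.no_grad().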
A related mistake is forgetting torch.no_grad() during inference. It doesn't change the outputs, but without it PyTorch records a computation graph for every forward pass and keeps the intermediate activations alive for a backward pass you will never run. On large models or long evaluation loops, that extra memory can exhaust your GPU and noticeably slow evaluation down.
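One way to make both habits automatic is a small inference helper. This predict function is a hypothetical sketch, not a PyTorch API: it switches to eval mode, disables graph recording, and restores whatever mode the model was in before, even if forward raises:

```python
import torch
import torch.nn as nn

def predict(model, inputs):
    """Hypothetical helper: eval mode + no_grad, then restore the previous mode."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            return model(inputs)
    finally:
        model.train(was_training)

model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 2))
out = predict(model, torch.randn(1, 4))  # a batch of one is fine in eval mode
print(out.requires_grad)  # False: no graph was recorded
print(model.training)     # True: back in training mode for the next epoch
```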
Using Python lists instead of nn.ModuleList is the next trap. It is natural to write self.layers = [nn.Linear(64, 64) for _ in range(10)] when you want a list of layers. But a regular Python list is invisible to nn.Module's parameter registration system. None of those ten linear layers will appear in model.parameters(), so the optimizer never updates them, model.to(device) never moves them, and state_dict() never saves them: they still run in forward, but with their random initial weights forever. Replace any Python list of modules with nn.ModuleList and any Python dict of modules with nn.ModuleDict.
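The difference shows up immediately in a parameter count. In this sketch the two hypothetical classes differ only in the container; forward methods are omitted because registration happens in __init__:

```python
import torch.nn as nn

class WithList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(64, 64) for _ in range(10)]  # plain list: NOT registered

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(10)])

plain_count = sum(p.numel() for p in WithList().parameters())
module_list_count = sum(p.numel() for p in WithModuleList().parameters())
print(plain_count)        # 0
print(module_list_count)  # 41600 = 10 * (64*64 + 64)
```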
Applying activation functions in the wrong place causes subtle problems. Linear -> BatchNorm -> ReLU is the conventional order (it is where the original batch norm paper placed normalization) and the safer default compared with Linear -> ReLU -> BatchNorm. A more serious mistake is applying a final sigmoid or softmax inside forward when your loss function expects raw logits: BCEWithLogitsLoss applies a sigmoid internally and CrossEntropyLoss applies log-softmax internally, so an extra activation in forward means the non-linearity is applied twice. The squashed outputs confine the loss to a narrow range, weaken the gradients, and bypass the numerically stable formulation these losses are built on. Read your loss function documentation before deciding what comes last in forward.
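Here is the double-softmax mistake in numbers. The sketch builds very confident, correct logits for a 5-class problem (values chosen for illustration) and compares the loss on raw logits with the loss after an extra softmax:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

loss_fn = nn.CrossEntropyLoss()
targets = torch.tensor([0, 3, 1, 4])

# Very confident, correct raw logits: a large score on each target class
logits = torch.zeros(4, 5)
logits[torch.arange(4), targets] = 20.0

right = loss_fn(logits, targets)                    # log-softmax applied once, inside the loss
wrong = loss_fn(F.softmax(logits, dim=1), targets)  # softmax in forward, then again in the loss
print(right.item())  # essentially 0: perfect predictions cost almost nothing
print(wrong.item())  # about 0.9: inputs in [0, 1] can never look confident to the loss

# CrossEntropyLoss on logits is exactly NLL on log-softmax outputs
print(torch.allclose(right, F.nll_loss(F.log_softmax(logits, dim=1), targets)))  # True
```

Even with perfect predictions, the doubled version keeps reporting a sizable loss, so training keeps pushing on a model that is already right.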
Finally, shape mismatches are the most common beginner error and the one with the least helpful error messages. PyTorch will tell you that tensors have incompatible shapes, but not always which layer caused the mismatch. The fix is methodical: add print(x.shape) at each step in forward during development, or use torchsummary before training to verify shapes propagate correctly. Do not guess: trace the shapes explicitly until you have the architecture working, then remove the debug prints.
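An alternative to sprinkling print statements is to let forward hooks do the tracing. The trace_shapes function below is a hypothetical debugging helper: it records the output shape of every leaf module during one forward pass and always removes its hooks afterwards:

```python
import torch
import torch.nn as nn

def trace_shapes(model, example_input):
    """Hypothetical debugging helper: record the output shape of every leaf module."""
    records, handles = [], []
    for name, module in model.named_modules():
        if not list(module.children()):  # leaf modules only (Linear, ReLU, ...)
            handles.append(module.register_forward_hook(
                # name=name binds the current loop value to this particular hook
                lambda mod, inputs, output, name=name:
                    records.append((name, type(mod).__name__, tuple(output.shape)))))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for handle in handles:  # clean up even if forward raises
            handle.remove()
    return records

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
shapes = trace_shapes(model, torch.randn(2, 784))
print(shapes)
# [('0', 'Linear', (2, 128)), ('1', 'ReLU', (2, 128)), ('2', 'Linear', (2, 10))]
```

Because it needs no edits to forward, you can run it against any model before training and delete nothing afterwards.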
Saving and Loading Models: state_dict() vs Full Model
You've trained a model for hours and it's performing great. Now you need to save it. PyTorch gives you two options, pick wisely.
Option 1: state_dict() (Recommended)
Save only the weights and biases. The principle here is separation of concerns: your Python code defines the architecture, and the checkpoint file stores only the learned values. This makes checkpoints portable across code refactors, smaller on disk, and safer to share with collaborators who might have slightly different Python environments:
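# Save only the learned state
torch.save(model.state_dict(), 'model_weights.pth')
# Load: rebuild the architecture from code, then copy the saved values in
model = DeepMLP()  # Create a fresh model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
Why is this better? Because you always have the code that defines your model architecture, the checkpoint only needs to store the learned values, and load_state_dict() refuses to load if any key is missing or unexpected or any shape disagrees, so mismatches surface immediately.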
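To see what a state_dict actually holds, the sketch below round-trips one through an in-memory buffer. It uses a small stand-in nn.Sequential rather than the article's DeepMLP, and io.BytesIO in place of a file path:

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Stand-in architecture; in practice this would be your model class.
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))


model = make_model()
model(torch.randn(16, 4))  # one training-mode pass updates the batch norm buffers

# The state_dict holds parameters and buffers such as running statistics.
print(list(model.state_dict().keys()))

buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

restored = make_model()  # fresh, randomly initialized copy of the architecture
restored.load_state_dict(torch.load(buffer, weights_only=True))

# Every tensor, including the batch norm buffers, now matches the original.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

Note the 1.running_mean, 1.running_var, and 1.num_batches_tracked entries: without them, a reloaded batch norm layer would evaluate with default statistics.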
Option 2: Full Model (Use with Caution)
Save the entire model object. This approach runs the whole module through Python's pickle mechanism. Pickle stores the weights together with a reference to your class (its module path and name), not the class code itself, so loading requires the exact same class to be importable from the exact same location. One refactor later and your saved model is unloadable:
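# Save the entire pickled object
torch.save(model, 'full_model.pth')
# Load (the class must be importable under the same module path)
# PyTorch 2.6+ defaults to weights_only=True, which rejects full-model files
model = torch.load('full_model.pth', weights_only=False)
This works, but it's fragile. If you rename the class, move it to another module, or change its structure, loading can fail because pickle can no longer reconstruct the object. Unpickling can also execute arbitrary code, so never load a full-model file from a source you don't trust. Avoid this approach unless you have a specific reason.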
Moving Models to GPU
Training on a GPU is usually much faster than on a CPU for all but the smallest models. PyTorch makes the switch simple:
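device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move model to GPU (Module.to() moves parameters in place and returns the module)
model = model.to(device)
# Now move your data too (Tensor.to() returns a new tensor, so reassign it)
x = torch.randn(32, 784).to(device)
output = model(x)
That's it. Every parameter and buffer moves to the device you specify. Input tensors are not moved for you, and the model and its inputs must live on the same device. If you later need to move back to CPU:
model = model.to('cpu')
DataParallel: Multi-GPU Training (Optional)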
If you have multiple GPUs on one machine, nn.DataParallel is the quickest way to use them. The tradeoff is overhead: on every forward pass it replicates the model onto each GPU, scatters the batch across them, and gathers the outputs back on the primary GPU, which becomes a bottleneck. PyTorch's own documentation recommends DistributedDataParallel instead, even on a single machine, but it requires more setup and deserves a dedicated article:
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)
DataParallel splits each batch across your available GPUs and gathers the results. The training loop itself doesn't change; the wrapper handles the distribution internally.
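One practical consequence of wrapping: DataParallel keeps your network in its .module attribute, so every key in the wrapper's state_dict gains a "module." prefix, and a checkpoint saved from the wrapper will not load directly into the bare class. The sketch below shows this with a tiny stand-in model; constructing the wrapper does not itself require a GPU:

```python
import torch.nn as nn

# Stand-in model for illustration.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
wrapped = nn.DataParallel(model)

# Keys look like 'module.0.weight', 'module.0.bias', ...
print(list(wrapped.state_dict().keys()))

# Saving the inner module's state_dict gives keys that match the bare model.
clean_state = wrapped.module.state_dict()
print(list(clean_state.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']

fresh = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
fresh.load_state_dict(clean_state)  # loads without key mismatches
```

Saving wrapped.module.state_dict() keeps checkpoints interchangeable between single-GPU and multi-GPU runs.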
Putting It All Together: A Complete Training Example
Let's wire everything together to show how nn.Module fits into a real training loop. Notice how much of the loop is not about the network architecture at all: it is about moving data, calling the right mode-switching methods, and following the zero_grad-backward-step sequence correctly. Once you have that skeleton down, swapping in a different nn.Module subclass is trivial:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
# Create a simple synthetic dataset (random inputs and labels)
X = torch.randn(1000, 784)
y = torch.randint(0, 10, (1000,))
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DeepMLP().to(device)
# Initialize weights
def init_kaiming(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
model.apply(init_kaiming)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()  # Expects raw logits from the model
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()  # Set to training mode (batch norm, dropout active)
    total_loss = 0
    for x_batch, y_batch in loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        # Forward pass
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
# Save the model
torch.save(model.state_dict(), 'trained_model.pth')
# Evaluation mode (batch norm uses running stats, dropout is off)
model.eval()
with torch.no_grad():
    test_x = torch.randn(100, 784).to(device)
    test_output = model(test_x)
    print(test_output.shape)
See how nn.Module simplifies everything? Parameter tracking, device management, batch norm behavior switching: it all just works.
Transitioning Between Training and Evaluation Modes
This deserves its own section because it trips up a lot of people. Batch norm and dropout behave differently depending on whether you're training or evaluating. PyTorch switches these behaviors for you, but only when you explicitly set the mode.
# Training mode: batch norm uses batch statistics, dropout is active
model.train()
output = model(x)
# Evaluation mode: batch norm uses running statistics, dropout is off
model.eval()
with torch.no_grad():  # Disables gradient tracking (faster, saves memory)
    output = model(x)
Why torch.no_grad()? During evaluation you don't need gradients; you're just making predictions. Skipping gradient tracking saves memory and runs faster. Note that model.eval() does not turn off gradient tracking on its own, so evaluation code usually wants both.
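The two switches are independent, which a quick check with a stand-in model makes visible: eval() changes layer behavior, while torch.no_grad() decides whether autograd records anything.

```python
import torch
import torch.nn as nn

# Stand-in model for illustration.
model = nn.Sequential(nn.Linear(8, 4), nn.Dropout(0.5))
x = torch.randn(2, 8)

model.eval()
out_eval = model(x)
print(out_eval.requires_grad)  # True: eval() alone still builds the autograd graph

with torch.no_grad():
    out_no_grad = model(x)
print(out_no_grad.requires_grad)  # False: no graph is recorded

# Same values either way, because dropout is already off in eval mode.
print(torch.equal(out_eval, out_no_grad))  # True
```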
Here's a subtle point that matters: in training mode, batch norm tracks running statistics, exponential moving averages of the mean and variance it sees in each batch. When you switch to evaluation mode, it uses these running statistics instead of computing new statistics from the current batch. This is crucial for evaluation, because statistics computed from a small test batch are noisy and make each prediction depend on whatever else happens to be in the batch.
# Example showing the difference
import torch
import torch.nn as nn
model = nn.Sequential(
    nn.Linear(10, 64),
    nn.BatchNorm1d(64),
    nn.ReLU()
)
# During training, batch norm updates its running stats on every forward pass
model.train()
for step in range(5):
    x = torch.randn(32, 10)
    output = model(x)
# Now in evaluation, it uses the accumulated running stats
model.eval()
x_test = torch.randn(1, 10)
output = model(x_test)  # Uses the running statistics, not statistics from 1 sample
If you forget to call model.eval(), you'll get inconsistent results because batch norm normalizes with statistics from the current (often small) batch instead of the accumulated running stats. With a batch of a single sample, BatchNorm1d in training mode raises an error outright.
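The running statistics follow a simple update rule. With the default momentum of 0.1, each training-mode forward pass sets running_mean to (1 - momentum) * running_mean + momentum * batch_mean, and running_var follows the same rule using the unbiased batch variance. The sketch below checks that rule on a freshly constructed nn.BatchNorm1d, whose running mean starts at 0 and running variance at 1, and then shows how the two modes treat a single sample:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)  # momentum defaults to 0.1
x = torch.randn(32, 4)

bn.train()
bn(x)

# Expected running stats after one step; Tensor.var is unbiased by default.
expected_mean = 0.9 * torch.zeros(4) + 0.1 * x.mean(dim=0)
expected_var = 0.9 * torch.ones(4) + 0.1 * x.var(dim=0)
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))    # True

# Training mode cannot normalize one sample: there is no batch variance.
try:
    bn(torch.randn(1, 4))
    single_sample_ok = True
except ValueError as err:
    single_sample_ok = False
    print(err)

# Evaluation mode uses the stored running stats, so one sample is fine.
bn.eval()
print(bn(torch.randn(1, 4)).shape)  # torch.Size([1, 4])
```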
Common Pitfalls and How to Avoid Them
Pitfall 1: Forgetting to Call super().__init__()
# WRONG
class BadNet(nn.Module):
    def __init__(self):
        # Forgot super().__init__()!
        self.fc1 = nn.Linear(10, 5)
# RIGHT
class GoodNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
Without super().__init__(), nn.Module never creates the internal registries that track parameters and submodules. PyTorch notices as soon as you assign the first layer, so BadNet() fails with an AttributeError before you ever get a model to train.
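To see the failure for yourself, the sketch below redefines both classes so it runs standalone, constructs each one, and captures the error BadNet raises:

```python
import torch.nn as nn


class BadNet(nn.Module):
    def __init__(self):
        # super().__init__() deliberately omitted
        self.fc1 = nn.Linear(10, 5)


class GoodNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)


try:
    BadNet()
    error_message = None
except AttributeError as err:
    error_message = str(err)

# The message says the module was assigned before Module.__init__() ran.
print(error_message)
print(sum(p.numel() for p in GoodNet().parameters()))  # 55 = 10 * 5 + 5
```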
Pitfall 2: Using Python Lists Instead of nn.ModuleList
# WRONG (inside __init__): PyTorch doesn't see these modules
self.layers = [nn.Linear(10, 10) for _ in range(3)]
# RIGHT: nn.ModuleList registers every module
self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
When you use a regular Python list, nn.Module's registration completely misses it. The parameters inside those modules are invisible to optimizers, are not moved by model.to(device), and are left out of state_dict().
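The same rule applies to dictionaries: a plain dict of modules is invisible, while nn.ModuleDict registers every entry under its key. A minimal sketch using a hypothetical MultiHeadNet whose output head is chosen by name at call time:

```python
import torch
import torch.nn as nn


class MultiHeadNet(nn.Module):
    """Hypothetical example: a shared trunk plus named task heads."""

    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(16, 32)
        # nn.ModuleDict registers each head, so its parameters are tracked.
        self.heads = nn.ModuleDict({
            "classify": nn.Linear(32, 10),
            "regress": nn.Linear(32, 1),
        })

    def forward(self, x, task):
        return self.heads[task](torch.relu(self.trunk(x)))


model = MultiHeadNet()
x = torch.randn(4, 16)
print(model(x, "classify").shape)  # torch.Size([4, 10])
print(model(x, "regress").shape)   # torch.Size([4, 1])
print([name for name, _ in model.named_parameters()])
```

The parameter names (heads.classify.weight, heads.regress.bias, and so on) follow the dictionary keys, which keeps checkpoints readable.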
Pitfall 3: Activations in the Wrong Place
# OPTION A: register ReLU as a submodule in __init__
def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(10, 64)
    self.relu = nn.ReLU()  # Works, but ReLU has no parameters to register
# OPTION B: use torch.nn.functional for parameter-free operations
import torch.nn.functional as F
def forward(self, x):
    x = self.fc1(x)
    x = F.relu(x)  # Stateless, no parameters tracked
    return x
Both approaches work. The functional form keeps parameter-free operations out of __init__; the module form is what you need when the activation has to be a stage inside nn.Sequential, and it shows up when you print the model.
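A short check confirms that the two forms are numerically identical and that nn.ReLU contributes nothing to the parameter list:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(5, 3)

module_out = nn.ReLU()(x)   # module form
functional_out = F.relu(x)  # functional form
print(torch.equal(module_out, functional_out))  # True

# nn.ReLU holds no parameters, so registering it adds nothing to the optimizer.
print(len(list(nn.ReLU().parameters())))  # 0
```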
Pitfall 4: Not Setting model.eval() Before Evaluation
# WRONG: training-mode behavior on test data
model.train()  # or just forgetting to change mode
predictions = model(test_data)
# RIGHT
model.eval()
with torch.no_grad():
    predictions = model(test_data)
Forgetting the switch leaves dropout active and makes batch norm normalize with the test batch's own statistics, while also updating its running averages with test data. Predictions become noisy and batch-dependent, and your validation metrics can look worse than the model really is.
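Dropout makes the mode difference concrete. In training mode, nn.Dropout(p) zeroes each element with probability p and scales the survivors by 1 / (1 - p); in evaluation mode it passes the input through unchanged. A minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
train_out = drop(x)
# Each element is either dropped (0.0) or scaled by 1 / (1 - 0.5) = 2.0.
print(torch.unique(train_out))  # tensor([0., 2.])

drop.eval()
eval_out = drop(x)
print(torch.equal(eval_out, x))  # True: dropout is a no-op in eval mode
```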
Advanced Pattern: Conditional Layers
Sometimes you want layers to be conditionally active based on a hyperparameter. Here's a useful pattern. It is especially handy during ablation studies: the systematic process of removing components one at a time to understand what actually contributes to your model's performance. With this pattern, you can run experiments that isolate batch norm and dropout by changing constructor arguments rather than maintaining separate class definitions:
class ConditionalNet(nn.Module):
    def __init__(self, use_batch_norm=True, use_dropout=True, dropout_rate=0.5):
        super().__init__()
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout
        self.fc1 = nn.Linear(784, 256)
        if self.use_batch_norm:
            self.bn1 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(dropout_rate) if self.use_dropout else None
        self.fc2 = nn.Linear(256, 10)
    def forward(self, x):
        x = self.fc1(x)
        if self.use_batch_norm:
            x = self.bn1(x)
        x = torch.relu(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.fc2(x)
        return x
Keep in mind that bn1 only exists when use_batch_norm is True, so checkpoints saved from the two variants have different state_dict keys and are not interchangeable.
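One possible refinement, offered as an alternative rather than as the pattern above: replace disabled components with nn.Identity, which returns its input unchanged, so forward needs no if statements. The sketch below uses a hypothetical IdentityConditionalNet with the same layer sizes and reports the parameter count for each of the four ablation settings:

```python
import itertools

import torch
import torch.nn as nn


class IdentityConditionalNet(nn.Module):
    """Hypothetical variant of ConditionalNet using nn.Identity placeholders."""

    def __init__(self, use_batch_norm=True, use_dropout=True, dropout_rate=0.5):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256) if use_batch_norm else nn.Identity()
        self.dropout = nn.Dropout(dropout_rate) if use_dropout else nn.Identity()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        return self.fc2(x)


# Build every ablation setting and record its parameter count.
param_counts = {}
for use_bn, use_drop in itertools.product([True, False], repeat=2):
    model = IdentityConditionalNet(use_batch_norm=use_bn, use_dropout=use_drop)
    out = model(torch.randn(8, 784))
    param_counts[(use_bn, use_drop)] = sum(p.numel() for p in model.parameters())
    print(f"batch_norm={use_bn}, dropout={use_drop}: "
          f"output {tuple(out.shape)}, {param_counts[(use_bn, use_drop)]} params")
```

Dropout adds no parameters, and batch norm adds 512 (a weight and a bias per feature), so only the use_batch_norm flag changes the count.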
Inspecting Gradients
Before a long training run, it's worth verifying that gradients flow where you expect. You can inspect them right after a single backward pass. Gradient inspection is one of the most useful debugging techniques available: it shows whether some layers are receiving zero gradient (for example, from dead neurons) or enormous gradient (exploding gradients) before you spend hours training:
model = ShallowMLP()
x = torch.randn(32, 784)
y = torch.randint(0, 10, (32,))
output = model(x)
loss = nn.CrossEntropyLoss()(output, y)
loss.backward()
# Check gradient statistics for every parameter
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: grad mean={param.grad.mean():.4f}, grad std={param.grad.std():.4f}")
Healthy gradients have reasonable magnitudes, neither exploding (huge values) nor vanishing (values near zero). If you see either, suspect your initialization or architecture first.
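Means can hide problems because positive and negative gradients cancel, so per-parameter norms are often a more telling summary. The sketch below uses a stand-in two-layer model in place of ShallowMLP, and a hypothetical helper, gradient_report, whose low and high thresholds are arbitrary illustrative values rather than established cutoffs:

```python
import torch
import torch.nn as nn


def gradient_report(model, low=1e-7, high=1e3):
    """Hypothetical helper: return {name: grad_norm} and flag suspicious norms.

    The low/high thresholds are illustrative assumptions; tune them per model.
    """
    norms = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            print(f"{name}: no gradient (unused in forward, or frozen?)")
            continue
        norm = param.grad.norm().item()
        norms[name] = norm
        flag = ""
        if norm < low:
            flag = "  <-- possibly vanishing"
        elif norm > high:
            flag = "  <-- possibly exploding"
        print(f"{name}: grad norm={norm:.4e}{flag}")
    return norms


# Stand-in model and random data for illustration.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
x = torch.randn(32, 784)
y = torch.randint(0, 10, (32,))

loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()
norms = gradient_report(model)
```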
Key Takeaways
- Always subclass nn.Module for any network you build. It gives you parameter tracking, device management, and proper gradient flow.
- __init__ defines layers; forward() uses them. This separation keeps your code clean and composable.
- Use nn.Sequential for simple chains and custom forward passes for anything more complex.
- Batch norm and dropout behave differently during training. Calling model.eval() turns dropout off and switches batch norm to its running statistics.
- Initialize weights smartly. Kaiming initialization is your default choice for ReLU networks.
- Save state_dict(), not the full model. It's smaller, more portable, and less error-prone.
- Nesting modules works beautifully. Use nn.ModuleList for lists of submodules, not regular Python lists.
- Always call model.eval() before evaluation, and use torch.no_grad() to skip unnecessary gradient computation.
- Remember super().__init__(); it's not optional. Without it, assigning the first layer fails and your network can't even be constructed.
- Inspect gradients and parameters during development. You'll catch bugs early and understand what's happening inside your model.
You now have the full picture of nn.Module, from the internal mechanics of parameter registration to the architectural decisions that determine what your network can learn. We have covered the three-method pattern that every network follows, the six layer types that cover most real architectures, the right initialization scheme for your activation function, gradient inspection for catching training pathologies early, and the exact mistakes that cost engineers hours of debugging time.
The consistent thread through all of it is that nn.Module is a framework for expressing intent clearly. When you write self.blocks = nn.ModuleList([ResidualBlock(...) for _ in range(num_blocks)]), you are not just organizing code; you are making your architectural thinking explicit and legible to anyone who reads it, including yourself six months from now. That clarity pays compounding dividends as your models grow more complex.
Next up: putting these networks to work with loss functions, optimizers, and proper training loops. Understanding nn.Module was the foundation; now we get to watch it actually learn.