Introduction
PyTorch is the go-to choice for researchers and practitioners for building deep learning models due to its flexibility, intuitive Pythonic API design, and ease of use.
It takes a programmer just three steps to create a deep learning model in PyTorch:
- First, we define a model class inherited from PyTorch’s nn.Module class.

- Moving on, we declare all the network components (layers, dropout, batch norm, etc.) in the __init__() method.

- Finally, we define the forward pass of the neural network in the forward() method, as sketched below.
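Putting these three steps together, a minimal model definition looks something like this (the layer sizes here are purely illustrative):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Step 2: declare the network components
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Step 3: define the forward pass
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)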

That’s it!
Once we have defined the network, we can proceed with training the model by declaring the optimizer, loss function, etc., without having to define the backward pass explicitly.
More specifically, one can define the training loop as demonstrated below and train the model easily:
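The exact details vary by task, but a bare-bones loop has roughly this shape (model, optimizer, criterion, and trainloader are assumed to be defined already):

for epoch in range(5):                          # e.g., 5 epochs
    for inputs, labels in trainloader:
        optimizer.zero_grad()                   # reset gradients from the previous step
        outputs = model(inputs)                 # forward pass
        loss = criterion(outputs, labels)       # compute the loss
        loss.backward()                         # backward pass (handled by autograd)
        optimizer.step()                        # update the model parameters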

Issues with PyTorch
As we saw above, defining the network was so simple and elegant, wasn’t it?
However, as our models grow more complex and larger, several challenges arise when using PyTorch:
#1) Managing training loops
With complex models, manually managing the training loop in PyTorch can become tedious.
This includes iterating over the dataset, performing forward and backward passes, and updating the model parameters. This loop is, of course, quite standardized (we saw its general shape above), but it still has to be written, reviewed, and maintained in every project.

#2) Logging
Logging is crucial for monitoring the training process and analyzing model performance.
PyTorch itself provides little built-in support for experiment logging, requiring users to implement their own logging solutions or integrate external logging frameworks.
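For example, even the low-level TensorBoard writer that ships with PyTorch still has to be wired up by hand; the metric below is a dummy value just to show the mechanics:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/example")        # log directory; any path works
for step in range(100):
    loss = 1.0 / (step + 1)                   # placeholder metric; in practice it comes from the training loop
    writer.add_scalar("train/loss", loss, step)
writer.close()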
#3) Handling distributed training
As models grow larger, training them on multiple GPUs or across multiple machines becomes necessary to reduce training time.

PyTorch provides support for distributed training, but the implementation can be complex, involving setting up processes, synchronizing gradients, and handling communication between processes.
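For a rough idea of the moving parts, a minimal DistributedDataParallel setup (assuming one process per GPU, launched with a utility such as torchrun) looks something like this, and that is before adding a DistributedSampler for the data, per-rank logging, or checkpointing:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")       # join the process group; env vars are set by the launcher
local_rank = int(os.environ["LOCAL_RANK"])    # provided by torchrun
torch.cuda.set_device(local_rank)

model = nn.Linear(28*28, 10).to(local_rank)   # a single layer stands in for any nn.Module here
model = DDP(model, device_ids=[local_rank])   # gradients are synchronized across processes automatically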
#4) Debugging in a distributed setting
Debugging distributed training can be challenging due to the complexity of the setup and the potential for issues to arise from communication between processes.
#5) Mixed-precision training
Mixed-precision training, which involves using lower precision (e.g., half-precision floating-point numbers) for certain parts of the training process, can help reduce memory usage and speed up training.
PyTorch supports mixed-precision training, but managing the precision of different operations manually is pretty challenging.
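For reference, a bare-bones manual mixed-precision loop with torch.cuda.amp looks roughly like this (model, optimizer, criterion, and trainloader are assumed to be set up as usual):

import torch

scaler = torch.cuda.amp.GradScaler()              # scales the loss to avoid fp16 underflow

for inputs, labels in trainloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():               # forward pass runs in mixed precision
        outputs = model(inputs)
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()                 # backward pass on the scaled loss
    scaler.step(optimizer)                        # unscales gradients, then updates parameters
    scaler.update()                               # adjusts the scale factor for the next iteration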
#6) Running models on TPUs
PyTorch natively supports running models on GPUs, but running models on TPUs (Tensor Processing Units) requires additional setup and configuration.
From the above discussion, it’s clear that PyTorch doesn’t provide out-of-the-box solutions for many important tasks, leading to boilerplate code and increased chances of errors.
Of course, these challenges may not be a major concern for all types of models, especially small or simple ones.
For small-scale projects, the overhead of managing training loops, logging, and distributed training is often manageable, and using PyTorch directly works just fine. However, as models grow in complexity and size, these challenges become more pronounced.
PyTorch Lightning
PyTorch Lightning resolves each of the above-discussed challenges with PyTorch.
You can think of PyTorch Lightning as a lightweight wrapper around PyTorch that abstracts away the boilerplate code we typically write with PyTorch and makes the training process more streamlined and readable.
Just like Keras is a wrapper around TensorFlow, PyTorch Lightning is a wrapper around PyTorch, but one that makes training models far more efficient than writing everything by hand.

Since Lightning organizes our PyTorch code rather than replacing it, one can use ANY PyTorch model as a PyTorch Lightning model.
As the library is an optimized wrapper around PyTorch, the developers claim to reduce the repeated (boilerplate) code by 70-80%, which minimizes the surface area for bugs and lets us focus on delivering value instead of engineering.
Moreover, as we shall see ahead, with PyTorch Lightning, we can define our model and training logic in a clear and concise manner, which lets us focus more on the research and less on the implementation details.
In fact, the utility is pretty evident from its popularity: its GitHub repo has over 26k stars.

Revisiting the challenges with PyTorch discussed above, here’s how PyTorch Lightning addresses them:
- Managing training loops: PyTorch Lightning simplifies this process by providing a high-level abstraction for defining the training loop, reducing the amount of boilerplate code required.
- Logging: PyTorch Lightning integrates with popular logging frameworks like TensorBoard and Comet, making it easier to log training metrics and visualize them in real-time.
- Handling distributed training: PyTorch Lightning simplifies distributed training by providing a unified interface. This abstracts away the complexity of the underlying implementation.
- Debugging in a distributed setting: PyTorch Lightning provides tools and utilities to facilitate debugging in a distributed setting, making it easier to identify and resolve issues.
- Mixed-precision training: PyTorch Lightning simplifies mixed-precision training by providing utilities to automatically handle the precision of operations based on user-defined settings.
- Running models on TPUs: PyTorch Lightning supports running models on TPUs, abstracting away the complexity of the underlying TPU architecture and allowing users to focus on their model implementation.
Along with that, one of the best things about PyTorch Lightning is that it has a minimal API. In most cases, the LightningModule and Trainer classes are the only two APIs one must learn because the rest is just organized PyTorch.
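To give a feel for how minimal this surface is, here is a rough sketch of what those two APIs look like in use (the class name, layer sizes, and optimizer choice are illustrative, and trainloader stands for any PyTorch DataLoader; we will build a complete example shortly):

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(28*28, 128), nn.ReLU(), nn.Linear(128, 10))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.net(x.view(x.size(0), -1)), y)
        return loss                                   # Lightning runs backward() and optimizer.step() for us

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

trainer = pl.Trainer(max_epochs=5)                    # the Trainer owns the training loop
trainer.fit(LitModel(), trainloader)                  # trainloader: any PyTorch DataLoader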
If none of these things is clear yet, don’t worry. Let’s get into a complete walkthrough of using PyTorch Lightning.
Now that we understand what PyTorch Lightning is and the motivation to use it over PyTorch, let’s get into more details about its implementation and how PyTorch Lightning works.
More specifically:
- We shall begin with a standard PyTorch code, and learn how to convert that into PyTorch Lightning code.
- Next, we shall look at how we use the Trainer() class from PyTorch Lightning to simplify model training and define various methods for training, validation, testing and predicting. Here, we shall also learn how to log model training and integrate various performance metrics during training.
- Finally, we shall deep dive into the additional utilities offered by PyTorch Lightning, like mixed-precision training, callbacks, and code profiling for optimization.
Let’s begin!
PyTorch to PyTorch Lightning
In this section, let’s build a simple neural network on the MNIST dataset using PyTorch. Then, we will see how we can convert that code to a PyTorch Lightning code.
PyTorch Model
Here are the traditional steps to building a model in PyTorch:
Step 1) Import required packages and libraries
First, we import the required packages from PyTorch:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
Step 2) Load the dataset
Next, we load the MNIST dataset (train and test) and create their respective PyTorch dataloaders.
# 1) Data transformer
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# 2) Create train dataset and dataloader
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# 3) Create test dataset and dataloader
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
Step 3) Define the PyTorch Model
Moving on, we define a simple feedforward neural network architecture. This is demonstrated below:
class PyTorchNet(nn.Module):
    def __init__(self):
        super(PyTorchNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 128)
        self.fc4 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
Step 4) Initialize the model and define the loss function and optimizer
Moving on, we shall initialize the model and define the optimizer and the loss function used to train it, the CrossEntropyLoss.
# Initialize model
model = PyTorchNet()

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define loss function
criterion = nn.CrossEntropyLoss()
Step 5) Define the evaluation method
To evaluate the model after every epoch, let’s define an evaluate() method that will iterate over the examples in the testloader and compute the accuracy. This is demonstrated below:
def evaluate(model):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total
Step 6) Train the model
Now, we will train the PyTorch model.
for epoch in range(5):
    model.train()
    running_loss = 0.0
    for data in trainloader:
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    accuracy = evaluate(model)
    print(f"Epoch {epoch + 1}, "
          f"Loss: {running_loss / len(trainloader)}, "
          f"Accuracy: {accuracy * 100:.2f}%")
With this, we are done with the PyTorch model.
Now, if we go back to the above code, there’s too much boilerplate code here.
Simply put, boilerplate refers to the repetitive, standardized sections of code that are necessary for the program to work but are not unique to the model we are training; we would write almost the same code in most other projects too.
For instance, the evaluate() method and the training loop contribute to the boilerplate code here.

While these boilerplate sections are essential for training a neural network, they can be cumbersome to write and maintain, and they are pretty repetitive as well.