You are training your latest AI model, anxiously watching as the loss steadily decreases when suddenly — boom! Your logs are flooded with NaNs (Not a Number) — your model is irreparably corrupted and you’re left staring at your screen in despair. To make matters worse, the NaNs don’t appear consistently. Sometimes your model trains just fine; other times, it fails inexplicably. Sometimes it will crash immediately, sometimes after many days of training.
NaNs in Deep Learning workloads are amongst the most frustrating issues to encounter. And because they often appear sporadically — triggered by a specific combination of model state, input data, and stochastic factors — they can be incredibly difficult to reproduce and debug.
Given the considerable cost of training AI models and the potential waste caused by NaN failures, it is recommended to have dedicated tools for capturing and analyzing NaN occurrences. In a previous post, we discussed the challenge of debugging NaNs in a TensorFlow training workload. We proposed an efficient scheme for capturing and reproducing NaNs and shared a sample TensorFlow implementation. In this post, we adopt and demonstrate a similar mechanism for debugging NaNs in PyTorch workloads. The general scheme is as follows:
On each training step:
- Save a copy of the training input batch.
- Check the gradients for NaN values. If any appear, save a checkpoint with the current model weights before the model is corrupted. Also, save the input batch and, if necessary, the stochastic state. Discontinue the training job.
- Reproduce and debug the NaN occurrence by loading the saved experiment state.
Although this scheme can be easily implemented in native PyTorch, we will take the opportunity to demonstrate some of the conveniences of PyTorch Lightning, a powerful open-source framework designed to streamline the development of machine learning (ML) models. Built on PyTorch, Lightning abstracts away many of the boilerplate components of an ML experiment, such as training loops, data distribution, logging, and more, enabling developers to focus on the core logic of their models.
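For reference, here is a minimal sketch of what the scheme might look like in a plain PyTorch training loop. The model, optimizer, and dataloader are placeholders, and the checkpoint contents are kept deliberately simple:

import torch

def train(model, optimizer, dataloader, ckpt_path="nan_checkpoint.pt"):
    for step, batch in enumerate(dataloader):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()

        # check the gradients for NaN values before the optimizer step
        if any(torch.isnan(p.grad).any() for p in model.parameters()
               if p.grad is not None):
            # save the uncorrupted weights together with the offending batch
            torch.save({"model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "batch": batch,
                        "step": step}, ckpt_path)
            # discontinue the training job
            break

        optimizer.step()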
To implement our NaN capturing scheme, we will use Lightning’s callback interface — a dedicated structure that enables inserting custom logic at specific points during the flow of execution.
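For those unfamiliar with the interface, a Lightning callback is simply a class that derives from pl.Callback and overrides the hooks it needs. The toy example below (not part of our solution) prints a message at the start of every 100th training batch:

import lightning.pytorch as pl

class PrintProgress(pl.Callback):
    # invoked by the Trainer at the start of every training batch
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if batch_idx % 100 == 0:
            print(f"starting batch {batch_idx}")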
Importantly, please do not view our choice of Lightning or any other tool or technique that we mention as an endorsement of its use. The code that we will share is intended for demonstrative purposes — please do not rely on its correctness or optimality.
Many thanks to Rom Maltser for his contributions to this post.
NaNCapture Callback
To implement our NaN capturing solution, we create a NaNCapture Lightning callback. The constructor receives a directory path for storing/loading checkpoints and sets up the NaNCapture state. We also define utilities for checking for NaNs, storing checkpoints, and halting the training job.
import os
import torch
from copy import deepcopy
import lightning.pytorch as pl


class NaNCapture(pl.Callback):

    def __init__(self, dirpath: str):
        # path to checkpoint
        self.dirpath = dirpath

        # update to True when NaN is identified
        self.nan_captured = False

        # stores a copy of the last batch
        self.last_batch = None
        self.batch_idx = None

    @staticmethod
    def contains_nan(tensor):
        return torch.isnan(tensor).any().item()
        # alternatively, check that all entries are finite
        # return not torch.isfinite(tensor).all().item()

    @staticmethod
    def halt_training(trainer):
        trainer.should_stop = True
        # communicate stop command to all other ranks
        trainer.strategy.reduce_boolean_decision(trainer.should_stop,
                                                 all=False)

    def save_ckpt(self, trainer):
        os.makedirs(self.dirpath, exist_ok=True)
        # include trainer.global_rank to avoid conflict
        filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
        full_path = os.path.join(self.dirpath, filename)
        print(f"saving ckpt to {full_path}")
        trainer.save_checkpoint(full_path, False)
Callback Function: on_train_batch_start
We begin by implementing the on_train_batch_start hook to store a copy of each input batch. In case of a NaN event, this batch will be stored in the checkpoint.
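A minimal version of this hook might look as follows (the complete version, which also captures the stochastic state, appears later in this post):

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if not self.nan_captured:
            # keep a copy of the current batch so it can be saved on failure
            self.last_batch = deepcopy(batch)
            self.batch_idx = batch_idx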
Callback Function: on_before_optimizer_step
Next we implement the on_before_optimizer_step hook. Here, we check for NaN entries in all of the gradient tensors. If found, we store a checkpoint with the uncorrupted model weights and halt the training.
Python"> def on_before_optimizer_step(self, trainer, pl_module, optimizer):
if not self.nan_captured:
# Check if gradients contain NaN
grads = [p.grad.view(-1) for p in pl_module.parameters()
if p.grad is not None]
all_grads = torch.cat(grads)
if self.contains_nan(all_grads):
print("nan found")
self.save_ckpt(trainer)
self.halt_training(trainer)
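Note that concatenating all of the gradients into a single tensor allocates additional memory. A possible variation (not part of the original callback) is to check each gradient tensor separately and stop at the first NaN found:

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if not self.nan_captured:
            for p in pl_module.parameters():
                # check each gradient tensor individually
                if p.grad is not None and torch.isnan(p.grad).any():
                    print("nan found")
                    self.save_ckpt(trainer)
                    self.halt_training(trainer)
                    break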
Capturing the Training State
To enable reproducibility, we include the NaNCapture state in the checkpoint by appending it to the training state dictionary. Lightning provides dedicated utilities for saving and loading a callback state:
    def state_dict(self):
        d = {"nan_captured": self.nan_captured}
        if self.nan_captured:
            d["last_batch"] = self.last_batch
        return d

    def load_state_dict(self, state_dict):
        self.nan_captured = state_dict.get("nan_captured", False)
        if self.nan_captured:
            self.last_batch = state_dict["last_batch"]
Reproducing the NaN Occurrence
We have described how our NaNCapture callback can be used to store the training state that resulted in a NaN, but how do we reload this state in order to reproduce the issue and debug it? To accomplish this, we leverage Lightning’s dedicated data loading class, LightningDataModule.
DataModule Function: on_before_batch_transfer
In the code block below, we extend the LightningDataModule class to allow injecting a fixed training input batch, by overriding the on_before_batch_transfer hook:
from lightning.pytorch import LightningDataModule


class InjectableDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()
        self.cached_batch = None

    def set_custom_batch(self, batch):
        self.cached_batch = batch

    def on_before_batch_transfer(self, batch, dataloader_idx):
        if self.cached_batch is not None:
            return self.cached_batch
        return batch
Callback Function: on_train_start
The final step is modifying the on_train_start hook of our NaNCapture callback to inject the stored training batch into the LightningDataModule.
    def on_train_start(self, trainer, pl_module):
        if self.nan_captured:
            datamodule = trainer.datamodule
            datamodule.set_custom_batch(self.last_batch)
In the next section we will demonstrate the end-to-end solution using a toy example.
Toy Example
To test our new callback, we create a resnet50-based image classification model with a loss function deliberately designed to trigger NaN occurrences.
Instead of using the standard cross-entropy loss, we compute binary_cross_entropy_with_logits for each class independently and divide the result by the number of samples belonging to that class. Inevitably, we will encounter a batch in which one or more classes are missing, leading to a division by zero. The resulting non-finite loss produces NaN values in the gradients, corrupting the model.
The implementation below follows Lightning’s introductory tutorial.
import lightning.pytorch as pl
import torch
import torchvision
import torch.nn.functional as F

num_classes = 20


# define a lightning module
class ResnetModel(pl.LightningModule):
    def __init__(self):
        """Initializes a new instance of the ResnetModel class."""
        super().__init__()
        self.model = torchvision.models.resnet50(num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        outputs = self(x)
        # uncomment for the default loss
        # return F.cross_entropy(outputs, y)

        # calculate binary_cross_entropy for each class individually
        losses = []
        for c in range(num_classes):
            count = torch.count_nonzero(y == c)
            masked = torch.where(y == c, 1., 0.)
            loss = F.binary_cross_entropy_with_logits(
                outputs[..., c],
                masked,
                reduction='sum'
            )
            mean_loss = loss / count  # division by zero could result in a non-finite value
            losses.append(mean_loss)
        total_loss = torch.stack(losses).mean()
        return total_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
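To see how a missing class corrupts training, consider the following standalone snippet (illustrative values only): when no sample belongs to a class, the per-class count is zero, the division produces an infinite loss, and the non-finite values propagate into the gradients (deeper inside a real network they typically turn into NaNs):

import torch
import torch.nn.functional as F

logits = torch.randn(4, requires_grad=True)  # model outputs for class c
targets = torch.zeros(4)                     # no sample belongs to class c
count = torch.count_nonzero(targets)         # count == 0

loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='sum')
mean_loss = loss / count                     # positive value / 0 -> inf
mean_loss.backward()

print(mean_loss)    # inf
print(logits.grad)  # non-finite gradients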
We define a synthetic dataset and encapsulate it in our InjectableDataModule class:
import os
import random
from torch.utils.data import Dataset, DataLoader

batch_size = 128
num_steps = 800


# A dataset with random images and labels
class FakeDataset(Dataset):
    def __len__(self):
        return batch_size * num_steps

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(random.randint(0, num_classes - 1),
                             dtype=torch.int64)
        return rand_image, label


# define a lightning datamodule
class FakeDataModule(InjectableDataModule):
    def train_dataloader(self):
        dataset = FakeDataset()
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=os.cpu_count(),
            pin_memory=True
        )
Finally, we initialize a Lightning Trainer with our NaNCapture callback and call trainer.fit with our Lightning module and Lightning DataModule.
import time

if __name__ == "__main__":

    # Initialize a lightning module
    lit_module = ResnetModel()

    # Initialize a DataModule
    datamodule = FakeDataModule()

    # Train the model
    ckpt_dir = "./ckpt_dir"
    trainer = pl.Trainer(
        max_epochs=1,
        callbacks=[NaNCapture(ckpt_dir)]
    )

    ckpt_path = None
    # check if a NaN checkpoint exists
    if os.path.isdir(ckpt_dir):
        dir_contents = [os.path.join(ckpt_dir, f)
                        for f in os.listdir(ckpt_dir)]
        ckpts = [f for f in dir_contents
                 if os.path.isfile(f) and f.endswith('.ckpt')]
        if ckpts:
            ckpt_path = ckpts[0]

    t0 = time.perf_counter()
    trainer.fit(lit_module, datamodule, ckpt_path=ckpt_path)
    print(f"total runtime: {time.perf_counter() - t0}")
After a number of training steps, a NaN event will occur. At this point, a checkpoint containing the full training state is saved and the training job is halted.
When the script is run again, the exact state that caused the NaN will be reloaded, allowing us to easily reproduce the issue and debug its root cause.
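Once the failing state is reloaded, a debugging aid worth considering (beyond the scope of the capture scheme itself) is PyTorch's anomaly detection mode, which attempts to identify the operation whose backward pass produced the non-finite value. It adds significant overhead, so you would typically enable it only for the reproduction run, for example:

import torch

# report the forward op whose backward pass produced a NaN;
# enable only for the reproduction run, as it slows training considerably
torch.autograd.set_detect_anomaly(True)

# alternatively, PyTorch Lightning exposes the same setting via the Trainer:
# trainer = pl.Trainer(max_epochs=1, detect_anomaly=True, ...)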
Performance Overhead
To assess the impact of our NaNCapture callback on runtime performance, we modified our experiment to use CrossEntropyLoss (to avoid NaNs) and measured the average throughput when running with and without the NaNCapture callback. The experiments were conducted on an NVIDIA L40S GPU, with a PyTorch 2.5.1 Docker image.

For our toy model, the NaNCapture callback adds a minimal 1.5% overhead to the runtime performance — a small price to pay for the valuable debugging capabilities it provides.
Naturally, the actual overhead will depend on the specifics of the model and runtime environment.
How to Handle Stochasticity
The solution we have described thus far will succeed in reproducing the training state provided that the model does not include any randomness. However, introducing stochasticity into the model definition is often critical for convergence. A common example of a stochastic layer is torch.nn.Dropout.
You may find that your NaN event depends on the precise state of randomness when the failure occurred. Consequently, we would like to enhance our NaNCapture callback to capture and restore the random state at the point of failure. The random state is determined by a number of libraries, including PyTorch, NumPy, and Python's built-in random module. In the code block below, we attempt to capture the full state of randomness:
import os
import torch
import random
import numpy as np
from copy import deepcopy
import lightning.pytorch as pl


class NaNCapture(pl.Callback):

    def __init__(self, dirpath: str):
        # path to checkpoint
        self.dirpath = dirpath

        # update to True when NaN is identified
        self.nan_captured = False

        # stores a copy of the last batch
        self.last_batch = None
        self.batch_idx = None

        # rng state
        self.rng_state = {
            "torch": None,
            "torch_cuda": None,
            "numpy": None,
            "random": None
        }

    @staticmethod
    def contains_nan(tensor):
        return torch.isnan(tensor).any().item()
        # alternatively, check that all entries are finite
        # return not torch.isfinite(tensor).all().item()

    @staticmethod
    def halt_training(trainer):
        trainer.should_stop = True
        # communicate stop command to all other ranks
        trainer.strategy.reduce_boolean_decision(trainer.should_stop,
                                                 all=False)

    def save_ckpt(self, trainer):
        os.makedirs(self.dirpath, exist_ok=True)
        # include trainer.global_rank to avoid conflict
        filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
        full_path = os.path.join(self.dirpath, filename)
        print(f"saving ckpt to {full_path}")
        trainer.save_checkpoint(full_path, False)

    def on_train_start(self, trainer, pl_module):
        if self.nan_captured:
            # inject the stored batch
            datamodule = trainer.datamodule
            datamodule.set_custom_batch(self.last_batch)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self.nan_captured:
            # restore the random state
            torch.random.set_rng_state(self.rng_state["torch"])
            torch.cuda.set_rng_state_all(self.rng_state["torch_cuda"])
            np.random.set_state(self.rng_state["numpy"])
            random.setstate(self.rng_state["random"])
        else:
            # capture the current batch
            self.last_batch = deepcopy(batch)
            self.batch_idx = batch_idx

            # capture the current random state
            self.rng_state["torch"] = torch.random.get_rng_state()
            self.rng_state["torch_cuda"] = torch.cuda.get_rng_state_all()
            self.rng_state["numpy"] = np.random.get_state()
            self.rng_state["random"] = random.getstate()

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if not self.nan_captured:
            # check if the gradients contain NaN
            grads = [p.grad.view(-1) for p in pl_module.parameters()
                     if p.grad is not None]
            all_grads = torch.cat(grads)
            if self.contains_nan(all_grads):
                print("nan found")
                self.save_ckpt(trainer)
                self.halt_training(trainer)

    def state_dict(self):
        d = {"nan_captured": self.nan_captured}
        if self.nan_captured:
            d["last_batch"] = self.last_batch
            d["rng_state"] = self.rng_state
        return d

    def load_state_dict(self, state_dict):
        self.nan_captured = state_dict.get("nan_captured", False)
        if self.nan_captured:
            self.last_batch = state_dict["last_batch"]
            self.rng_state = state_dict["rng_state"]
Importantly, setting the random state may not guarantee full reproducibility. The GPU owes its power to its massive parallelism, and in some GPU operations multiple threads may read or write concurrently to the same memory locations, resulting in nondeterminism. PyTorch allows some control over this via torch.use_deterministic_algorithms, but this may impact the runtime performance. Additionally, there is a possibility that the NaN event will not be reproduced once this configuration setting is changed. Please see the PyTorch documentation on reproducibility for more details.
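If you do choose to experiment with deterministic execution, a sketch of the relevant settings is shown below (consult the PyTorch reproducibility documentation for the full details and caveats):

import os
import torch

# required by cuBLAS when deterministic algorithms are enabled
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# error out when an op that has no deterministic implementation is used
torch.use_deterministic_algorithms(True)

# cuDNN-specific settings
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False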
Summary
Encountering NaN failures is one of the most discouraging events that can happen in machine learning development. These errors not only waste valuable computation and development resources, but often indicate fundamental issues in the model architecture or experiment design. Due to their sporadic, sometimes elusive nature, debugging NaN failures can be a nightmare.
This post introduced a proactive approach for capturing and reproducing NaN errors using a dedicated Lightning callback. The solution we shared is a proposal which can be modified and extended for your specific use case.
While this solution may not address every possible NaN scenario, it significantly reduces debugging time when applicable, potentially saving developers countless hours of frustration and wasted effort.