1. Debugging

Training neural networks is powerful but full of potential pitfalls. Here's a comprehensive guide to common training problems and how to debug and fix them.

1. 🔍 Common Problems

1.1 📉 Training Loss Issues

Training loss not decreasing.

🔍 Cause:

  • Learning rate too high or too low
  • Model underfitting
  • Poor initialization

🛠️ Fix:

  • Try different learning rates (use learning rate schedulers or sweeps)
  • Use a deeper or wider network
  • Use better weight initialization (e.g., Xavier or He; see the sketch below)

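A minimal sketch of the last two fixes (better initialization plus a learning-rate schedule), assuming a plain feed-forward classifier; the layer sizes, learning rate, and schedule values are illustrative placeholders rather than tuned recommendations:

import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

# He (Kaiming) init suits ReLU activations; Xavier suits tanh/sigmoid layers.
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        nn.init.zeros_(m.bias)

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# Halve the learning rate every 10 epochs; call scheduler.step() once per epoch.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
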
1.2 🧠 Model Overfitting

Model overfitting (training loss ↓, val loss ↑).

🔍 Cause:

  • Too complex model
  • Not enough data
  • No regularization

🛠️ Fix:

  • Add regularization: L2 (weight decay), dropout, batch norm (see the sketch below)
  • Use data augmentation
  • Reduce network size
  • Early stopping

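A minimal sketch of the regularization fix, combining dropout with L2 weight decay; the architecture, dropout probability, and decay value are placeholders chosen for illustration:

import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(p=0.5),   # randomly zeroes half of the activations during training
    nn.Linear(256, 10),
)

# weight_decay applies an L2 penalty to the parameters in the optimizer step.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Remember: model.train() enables dropout, model.eval() disables it for validation.
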
1.3 🐒 Unstable Training

Very slow or unstable training.

🔍 Cause:

  • Learning rate too low or exploding gradients
  • Poor normalization
  • Large input values

🛠️ Fix:

  • Normalize input features (e.g., z-score)
  • Use batch normalization
  • Use optimizers like Adam
  • Gradient clipping if exploding gradients occur (see the sketch below)

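A minimal sketch of z-score normalization and gradient clipping in one training step; the model, random data, and max_norm value are placeholders (in practice, compute the mean/std once on the training set rather than per batch):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(32, 10) * 100 + 5                  # raw features on a large scale
x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-8)    # z-score normalize each feature
y = torch.randn(32, 1)

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
# Rescale gradients so their global norm is at most 1.0 before the update.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
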
1.4 🔍 Validation Accuracy Stuck

Validation accuracy stuck / poor generalization.

🔍 Cause:

  • Underfitting or dataset bias
  • Noisy labels
  • Inappropriate architecture

🛠️ Fix:

  • Increase model capacity
  • Clean or relabel data
  • Try a more suitable model (e.g., CNN for images, RNN/Transformer for sequences; see the sketch below)

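As an example of "a more suitable model", here is a minimal CNN sketch for 28x28 grayscale images; the channel counts and layer layout are illustrative only:

import torch
import torch.nn as nn

# Two conv/pool stages, then a linear classifier over the flattened feature map.
cnn = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # -> 16 x 14 x 14
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # -> 32 x 7 x 7
    nn.Flatten(),
    nn.Linear(32 * 7 * 7, 10),
)

print(cnn(torch.randn(8, 1, 28, 28)).shape)   # torch.Size([8, 10])
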
1.5 📊 Unstable or NaN Loss

Loss becomes NaN or unstable during training.

🔍 Cause:

  • Exploding gradients
  • Invalid operations such as division by zero or log(0)
  • Bad learning rate or activation

🛠️ Fix:

  • Check for log(0) or large softmax logits
  • Use torch.nn.CrossEntropyLoss on raw logits instead of manually combining softmax with NLLLoss (see the sketch below)
  • Switch from ReLU to LeakyReLU or ELU
  • Clip gradients

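A minimal sketch contrasting a hand-rolled softmax + log + NLLLoss pipeline with the built-in CrossEntropyLoss; the extreme logits are constructed deliberately to trigger the underflow:

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[200.0, -200.0]])   # extreme logits from an unstable model
target = torch.tensor([1])

# Unstable: softmax underflows to 0 for the target class, so log(0) = -inf.
probs = F.softmax(logits, dim=1)
print(F.nll_loss(torch.log(probs), target))    # tensor(inf)

# Stable: CrossEntropyLoss applies log-softmax (log-sum-exp) to the raw logits.
print(nn.CrossEntropyLoss()(logits, target))   # tensor(400.) -- large but finite
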
1.6 ⚖️ Imbalanced Dataset

Imbalanced dataset (e.g., 95% of one class).

🔍 Cause:

  • Class imbalance → biased predictions

🛠️ Fix:

  • Use class weights in the loss function (see the sketch below)
  • Try oversampling the minority class or undersampling the majority
  • Use metrics like F1-score or AUC instead of accuracy

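A minimal sketch of class-weighted CrossEntropyLoss for a 95/5 binary split; the counts and the random batch are placeholders (torch.utils.data.WeightedRandomSampler is the resampling alternative):

import torch
import torch.nn as nn

# Inverse-frequency weights: the 5% minority class gets a much larger weight.
class_counts = torch.tensor([950.0, 50.0])
weights = class_counts.sum() / (len(class_counts) * class_counts)   # -> [0.53, 10.0]

criterion = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(16, 2)
targets = torch.randint(0, 2, (16,))
loss = criterion(logits, targets)   # mistakes on the minority class now cost ~19x more
print(loss.item())
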
1.7 👁️‍🗨️ Model Doesn't Learn

Model doesn't learn at all (flat loss/acc).

🔍 Cause:

  • Wrong labels
  • Model too small
  • Data pipeline bugs (e.g., all zeros)

🛠️ Fix:

  • Check labels (plot samples with labels!)
  • Overfit on a small batch first to test correctness (see the sketch below)
  • Print/log outputs to verify model is changing

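A minimal, self-contained sketch of the overfit-one-batch check; the toy model and random batch stand in for your own. If the loss does not drive toward zero, suspect the data pipeline, labels, loss wiring, or model code:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

xb = torch.randn(16, 20)            # one small, fixed batch
yb = torch.randint(0, 2, (16,))

for step in range(500):
    optimizer.zero_grad()
    loss = criterion(model(xb), yb)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"step {step}: loss {loss.item():.4f}")   # should approach ~0
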
2. 🛠️ Debugging Strategies

2.1 ✅ What to Check?

| Check | Tip |
| --- | --- |
| Loss curve behavior | Plot training/val loss |
| Gradients | Check for zero or exploding grads |
| Model outputs | Print early predictions |
| Data | Visualize input samples and labels |
| Overfit 1 batch | Your model should fit it 100% |
| Logging | Use TensorBoard / wandb for tracking |

2.2 🔄 Tools That Help

  • TensorBoard / wandb – for tracking metrics visually
  • PyTorch hooks – to inspect gradients and activations (sketched below)
  • Gradient checking – for custom backward code
  • Unit tests for each component (model, loss, data loader)

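A minimal sketch of a forward hook plus a post-backward gradient-norm check; the toy model is a placeholder, and the printed statistics are only a starting point for inspection:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))

# Forward hook: log activation statistics of the first Linear layer.
def log_activations(module, inputs, output):
    print(f"activation mean={output.mean().item():.3f}, std={output.std().item():.3f}")

handle = model[0].register_forward_hook(log_activations)

x = torch.randn(4, 10)
loss = model(x).sum()
loss.backward()

# After backward: all-zero grads hint at a detached graph or dead units;
# very large norms hint at exploding gradients.
for name, p in model.named_parameters():
    print(name, p.grad.norm().item())

handle.remove()   # remove the hook when done inspecting
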
2.3 ✅ NN Debug Checklist

1. 📦 Before Training

  • Is the dataset loaded correctly?
  • Are inputs normalized (e.g., mean=0, std=1 or [0,1] range)?
  • Are labels correct and balanced?
  • Is the model architecture suitable for the task?
  • Can the model overfit a tiny batch?

2. ⚙️ During Training

  • Is the loss decreasing steadily?
  • Are gradients finite (no NaNs) and not exploding?
  • Are learning rate and optimizer appropriate (Adam, SGD, etc.)?
  • Are training and validation loss both tracked?
  • Are you tracking accuracy/F1 (not just loss)?

3. 🔍 Debugging Signs

| Symptom | Likely Cause | Suggested Fix |
| --- | --- | --- |
| Loss is flat | Data/label issue, learning rate too low | Check input/labels, increase LR |
| Loss = NaN | Exploding gradients, bad numerics | Clip gradients, switch activation/loss |
| High training, low val acc | Overfitting | Add dropout, regularization, more data |
| Training loss doesn't decrease | Model too small, poor init | Increase capacity, try different init |

2.4 📋 PyTorch Template

PyTorch Template with Logging (via TensorBoard)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Data
transform = transforms.Compose([transforms.ToTensor()])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Model
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten
        x = self.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Logging
writer = SummaryWriter()

# Training loop
for epoch in range(5):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.4f}")
            writer.add_scalar("Loss/train", running_loss / 100, epoch * len(trainloader) + i)
            running_loss = 0.0

writer.close()

🔎 What this does:

  • Logs training loss to TensorBoard
  • Helps identify loss trends, spikes, plateaus (see below for how to open the dashboard)
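To open the dashboard, run tensorboard --logdir runs from the directory containing the script (SummaryWriter writes to a runs/ folder by default) and browse to http://localhost:6006.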