1. Debugging
Training neural networks is powerful but full of potential pitfalls. Here's a comprehensive guide to common training problems and how to debug and fix them.
1. Common Problems
1.1 Training Loss Issues
Training loss not decreasing.
Cause:
- Learning rate too high or too low
- Model underfitting
- Poor initialization
Fix:
- Try different learning rates (use learning rate schedulers or sweeps)
- Use a deeper or wider network
- Use better weight initialization (e.g., Xavier, He)
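A minimal sketch of the last fix, assuming a plain feed-forward model; the `init_weights` helper and layer sizes are illustrative:

```python
import torch.nn as nn

def init_weights(module):
    # He (Kaiming) initialization, a common default for ReLU networks
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
        nn.init.zeros_(module.bias)

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
model.apply(init_weights)  # applies init_weights to every submodule
```

Swapping `kaiming_normal_` for `xavier_uniform_` gives Xavier initialization instead.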
1.2 Model Overfitting
Model overfitting (training loss ↓, val loss ↑).
Cause:
- Model too complex
- Not enough data
- No regularization
Fix:
- Add regularization: L2 (weight decay), dropout, batch norm (see the sketch after this list)
- Use data augmentation
- Reduce network size
- Early stopping
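As referenced above, a minimal sketch of the regularization fix, assuming a small feed-forward classifier: dropout inside the model and L2 regularization via the optimizer's `weight_decay`:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(p=0.5),   # randomly zeroes 50% of activations during training
    nn.Linear(256, 10),
)

# weight_decay adds an L2 penalty on the weights
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
```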
1.3 Unstable Training
Very slow or unstable training.
Cause:
- Learning rate too low or exploding gradients
- Poor normalization
- Large input values
Fix:
- Normalize input features (e.g., z-score)
- Use batch normalization
- Use optimizers like Adam
- Gradient clipping if exploding gradients occur
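A minimal sketch of the gradient-clipping fix: clip between `backward()` and `step()`. The `max_norm=1.0` value, the tiny model, and the dummy data are illustrative:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 10)           # dummy inputs
y = torch.randint(0, 2, (8,))    # dummy labels

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
# Rescale gradients so their global norm does not exceed 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```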
1.4 Validation Accuracy Stuck
Validation accuracy stuck / poor generalization.
Cause:
- Underfitting or dataset bias
- Noisy labels
- Inappropriate architecture
Fix:
- Increase model capacity
- Clean or relabel data
- Try a more suitable model (e.g., CNN for images, RNN/Transformer for sequences)
1.5 Unstable or NaN Loss
Loss becomes NaN or unstable during training.
Cause:
- Exploding gradients
- Division by 0 (e.g., log(0))
- Bad learning rate or activation
Fix:
- Check for `log(0)` or large softmax logits
- Use `torch.nn.CrossEntropyLoss` instead of manually combining softmax + NLLLoss
- Switch from ReLU to LeakyReLU or ELU
- Clip gradients
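For the `CrossEntropyLoss` point, a minimal sketch with dummy data: pass raw logits to the loss, which combines log-softmax and NLL in one numerically stable step:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 10)             # raw model outputs, no softmax applied
targets = torch.randint(0, 10, (4,))    # integer class labels

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)       # stable log-softmax + NLL internally
```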
1.6 Imbalanced Dataset
Imbalanced dataset (e.g., 95% of one class).
Cause:
- Class imbalance → biased predictions
Fix:
- Use class weights in the loss function (see the sketch below)
- Try oversampling the minority class or undersampling the majority
- Use metrics like F1-score or AUC instead of accuracy
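A minimal sketch of the class-weight fix, assuming a two-class problem with a 95/5 split; the weights shown are illustrative, and weighting inversely to class frequency is a common starting point:

```python
import torch
import torch.nn as nn

# Give the rare class a larger weight so its errors count more in the loss
class_weights = torch.tensor([1.0, 19.0])   # class 0: 95%, class 1: 5%
criterion = nn.CrossEntropyLoss(weight=class_weights)
```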
1.7 Model Doesn't Learn
Model doesn't learn at all (flat loss/accuracy).
Cause:
- Wrong labels
- Model too small
- Data pipeline bugs (e.g., all zeros)
Fix:
- Check labels (plot samples with labels!)
- Overfit on a small batch first to test correctness (see the sketch below)
- Print/log outputs to verify model is changing
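A minimal sketch of the overfit-a-small-batch check, using a tiny model and dummy data: if the loss does not drop close to zero, suspect the labels, the data pipeline, or the loss wiring rather than the architecture.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One fixed batch; a healthy setup should memorize it completely
x = torch.randn(16, 20)
y = torch.randint(0, 3, (16,))

for step in range(500):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"step {step}: loss {loss.item():.4f}")   # should approach 0
```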
2. Debugging Strategies
2.1 What to Check?
| Check | Tip |
|---|---|
| Loss curve behavior | Plot training/val loss |
| Gradients | Check for zero or exploding grads |
| Model outputs | Print early predictions |
| Data | Visualize input samples and labels |
| Overfit 1 batch | Your model should fit it 100% |
| Logging | Use TensorBoard / wandb for tracking |
2.2 Tools That Help
- TensorBoard / wandb – for tracking metrics visually
- PyTorch hooks – to inspect gradients and activations
- Gradient checking – for custom backward code
- Unit tests for each component (model, loss, data loader)
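As a sketch of the hooks point above: register a hook on each parameter so its gradient norm is printed after every backward pass (the `report_grad` helper and tiny model are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)

def report_grad(name):
    # The returned hook runs once this parameter's gradient is computed
    def hook(grad):
        print(f"{name}: grad norm = {grad.norm().item():.4f}")
    return hook

for name, param in model.named_parameters():
    param.register_hook(report_grad(name))

loss = model(torch.randn(4, 10)).sum()
loss.backward()   # triggers the hooks
```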
2.3 NN Debug Checklist
1. Before Training
- Is the dataset loaded correctly?
- Are inputs normalized (e.g., mean=0, std=1 or [0,1] range)? See the sketch after this checklist.
- Are labels correct and balanced?
- Is the model architecture suitable for the task?
- Can the model overfit a tiny batch?
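For the normalization item above, a minimal sketch using torchvision transforms; 0.1307/0.3081 are the commonly quoted MNIST mean and std, so substitute your own dataset's statistics:

```python
from torchvision import transforms

# ToTensor scales pixels to [0, 1]; Normalize then z-scores them
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])
```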
2. During Training
- Is the loss decreasing steadily?
- Are gradients not NaN or exploding?
- Are learning rate and optimizer appropriate (Adam, SGD, etc.)?
- Are training and validation loss both tracked?
- Are you tracking accuracy/F1 (not just loss)?
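For the last item, a minimal sketch of computing batch accuracy from logits alongside the loss; F1 would typically come from a library such as torchmetrics or scikit-learn:

```python
import torch

logits = torch.randn(32, 10)             # model outputs for one batch
labels = torch.randint(0, 10, (32,))

preds = logits.argmax(dim=1)
accuracy = (preds == labels).float().mean().item()
print(f"batch accuracy: {accuracy:.2%}")
```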
3. Debugging Signs

| Symptom | Likely Cause | Suggested Fix |
|---|---|---|
| Loss is flat | Data/label issue, learning rate too low | Check input/labels, increase LR |
| Loss = NaN | Exploding gradients, bad numerics | Clip gradients, switch activation/loss |
| High training, low val acc | Overfitting | Add dropout, regularization, more data |
| Training loss doesn't decrease | Model too small, poor init | Increase capacity, try different init |
2.4 PyTorch Template
PyTorch Template with Logging (via TensorBoard)
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Data
transform = transforms.Compose([transforms.ToTensor()])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Model
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten
        x = self.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Logging
writer = SummaryWriter()

# Training loop
for epoch in range(5):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.4f}")
            writer.add_scalar("Loss/train", running_loss / 100, epoch * len(trainloader) + i)
            running_loss = 0.0

writer.close()
```
What this does:
- Logs training loss to TensorBoard
- Helps identify loss trends, spikes, plateaus