2. WandB

WandB (short for Weights & Biases) is a platform for tracking machine learning experiments, visualizing metrics, and managing models. It helps data scientists and machine learning engineers with experiment management, collaboration, and reporting; in short, it is an MLOps tool.

1. What is WandB?

WandB is a cloud-based platform that helps you:

  1. Track Experiments: Log hyperparameters, metrics, model weights, and even datasets.
  2. Visualize Results: Plot learning curves, confusion matrices, and other metrics.
  3. Collaborate: Share experiment results with teammates, compare experiments, and monitor ongoing runs.
  4. Hyperparameter Tuning: Run automated searches with built-in Sweeps, or integrate with libraries like Optuna and Ray Tune.
  5. Version Control: Store datasets, models, and code versions for reproducibility.

2. Setting Up WandB

1. Install WandB

To get started, install the wandb Python package:

Terminal window
pip install wandb

2. Sign Up & Get API Key

  • Sign up on wandb.ai
  • After signing up, get your API key from https://wandb.ai/authorize.

3. Log In Using API Key

Run the following command in your terminal to log in:

Terminal window
wandb login <your_api_key>
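
If you need a non-interactive login (for example on a CI machine), WandB also reads the API key from the WANDB_API_KEY environment variable; a minimal sketch, assuming a Unix-like shell:

Terminal window
export WANDB_API_KEY=<your_api_key>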

3. Basic Workflow

3.1 Initialize WandB in Your Script

Add wandb.init() at the start of your script to start tracking an experiment. Here’s how you can do that:

import wandb

# Initialize a new run (one run = one experiment)
wandb.init(project="my-first-project", entity="my-username")

  • project: The project name (any name you like; it groups runs in your WandB dashboard).
  • entity: Your username or team name on WandB (can be left out for personal projects).
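
wandb.init() also accepts further optional arguments such as name and tags; a short sketch, where the run name and tags are placeholder values:

import wandb

wandb.init(
    project="my-first-project",
    name="baseline-run",         # optional human-readable run name (placeholder)
    tags=["baseline", "mnist"],  # optional tags for filtering runs (placeholders)
)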

3.2 Log Hyperparameters

You can log hyperparameters as key-value pairs by using wandb.config. This is a great way to keep track of different experiment settings.

wandb.init(project="my-first-project", entity="my-username")
# Log hyperparameters
wandb.config.batch_size = 32
wandb.config.lr = 0.001
wandb.config.epochs = 10
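
Equivalently, you can pass all hyperparameters as a single dict when the run is created, which keeps the settings in one place:

import wandb

config = {"batch_size": 32, "lr": 0.001, "epochs": 10}
wandb.init(project="my-first-project", entity="my-username", config=config)

# Values are read back through wandb.config
print(wandb.config.lr)  # 0.001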

3.3 Log Metrics During Training

Log metrics like loss, accuracy, and other performance indicators during training. You can use wandb.log() to log these metrics after each epoch.

for epoch in range(epochs):
    # Your training code here...
    train_loss = 0.1    # Replace with actual training loss
    val_loss = 0.08     # Replace with actual validation loss
    val_accuracy = 0.9  # Replace with actual validation accuracy

    # Log metrics to WandB
    wandb.log({"train_loss": train_loss, "val_loss": val_loss, "val_accuracy": val_accuracy})

WandB will automatically track your metrics and display them in a neat dashboard.
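
By default, wandb.log() advances an internal step counter on every call. If you would rather plot against the epoch number, you can pass the step explicitly; a small sketch:

# Use the epoch number as the x-axis value for this metric
wandb.log({"val_accuracy": val_accuracy}, step=epoch)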

3.4 Log Model Weights

wandb.watch() hooks into a PyTorch model and periodically logs its gradients and parameters, which is useful for debugging training. (For TensorFlow/Keras, use the WandbCallback from wandb.integration.keras instead.)

# PyTorch only
wandb.watch(model, log="all", log_freq=100)  # logs weights, gradients, and the model graph
  • log="all": Logs both weights and gradients (other options: "gradients", "parameters").
  • log_freq=100: Logs every 100th batch.

3.5 Save and Log Your Model

Once training finishes, save the model to disk and upload the file with wandb.save() (the example assumes PyTorch):

torch.save(model.state_dict(), "model.pth")  # Write the weights to disk first
wandb.save("model.pth")                      # Then upload the file to WandB
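
For stricter versioning, WandB also supports Artifacts, which store files with full version history. A minimal sketch (the artifact name "my-model" is a placeholder):

import wandb

artifact = wandb.Artifact("my-model", type="model")  # "my-model" is a placeholder name
artifact.add_file("model.pth")                       # attach the saved weights
wandb.log_artifact(artifact)                         # upload and version the artifact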

3.6 Visualizing Results

  • WandB Dashboard: After running the script, your metrics, hyperparameters, and graphs will appear on your WandB dashboard. This allows you to visualize:

    • Loss curves: track training and validation loss over time.
    • Hyperparameter comparison: compare different runs of the same model with different hyperparameters.
    • Model weights/gradients: track how weights change during training.
  • Custom Visualizations: You can create your own custom plots (e.g., confusion matrix, feature importance) and log them using wandb.log().

Example:

wandb.log({"confusion_matrix": wandb.plot.confusion_matrix(probs, targets)})

4. Advanced Features

4.1 Hyperparameter Optimization

You can use WandB Sweeps to automatically search for the best hyperparameters. Here’s a quick overview of how to set it up:

  • Define a sweep configuration (the search space for hyperparameters), e.g. in sweep_config.yaml:

method: random # or 'grid', 'bayes'
name: my-sweep
metric:
  name: val_loss
  goal: minimize
parameters:
  batch_size:
    values: [16, 32, 64]
  lr:
    min: 0.0001
    max: 0.1

Note: min/max ranges require the random or bayes method; grid search needs an explicit values list.

  • Launch the sweep from your terminal (this prints a sweep ID), then start an agent to run trials:

Terminal window
wandb sweep sweep_config.yaml
wandb agent <sweep_id>

  • Alternatively, create and run the sweep from Python (here sweep_config is the same configuration as a dict):

sweep_id = wandb.sweep(sweep_config, project="my-sweep-project")
wandb.agent(sweep_id, function=train, count=10)

This will run 10 experiments, optimizing the batch_size and lr parameters based on validation loss.
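
The train function handed to wandb.agent is your own training code; each trial must call wandb.init() and read its assigned hyperparameters from wandb.config. A minimal sketch with a placeholder loss:

import wandb

def train():
    wandb.init()  # the agent injects this trial's hyperparameters
    config = wandb.config

    for epoch in range(5):
        # Placeholder loss; substitute your real training and validation code
        val_loss = 1.0 / (epoch + 1) + config.lr * config.batch_size / 64
        wandb.log({"val_loss": val_loss})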

4.2 Collaboration and Sharing

WandB allows you to share results with teammates and easily collaborate. You can:

  • Share links to your project and results.
  • Compare runs side-by-side (e.g., compare different models or hyperparameters).
  • Comment and annotate experiments for collaboration.

5. Example

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import wandb

# Initialize WandB
wandb.init(project="my-pytorch-project", entity="my-username")

# Data and model setup
transform = transforms.Compose([transforms.ToTensor()])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

wandb.watch(model, log="all")

# Training loop
for epoch in range(5):
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Log loss every 100 steps
        if i % 100 == 99:
            wandb.log({"train_loss": loss.item()})

    # Log metrics after each epoch
    wandb.log({"epoch": epoch, "train_loss": loss.item()})

torch.save(model.state_dict(), "model.pth")  # Write the trained weights to disk
wandb.save("model.pth")                      # Upload the file to WandB
wandb.finish()                               # Close the run

5.1 Useful WandB Commands

  • wandb login: Log in to your WandB account from the terminal.
  • wandb.init(): Start a new run (in Python).
  • wandb.log(): Log metrics and custom plots.
  • wandb.save(): Upload files such as model weights.
  • wandb.watch(model): Track model gradients and weights (PyTorch).
  • wandb.finish(): End the current run.