2. WandB
WandB (short for Weights & Biases) is a tool for tracking machine learning experiments, visualizing metrics, and managing models. The platform helps data scientists and machine learning engineers with experiment management, collaboration, and reporting; in short, it is an MLOps tool.
1. What is WandB?
WandB is a cloud-based platform that helps you:
- Track Experiments: Log hyperparameters, metrics, model weights, and even datasets.
- Visualize Results: Plot learning curves, confusion matrices, and other metrics.
- Collaborate: Share experiment results with teammates, compare experiments, and monitor ongoing runs.
- Tune Hyperparameters: Integrate with popular libraries like Optuna, Ray Tune, and others for automated hyperparameter optimization.
- Version Control: Store datasets, models, and code versions for reproducibility.
2. Setting Up WandB
1. Install WandB
To get started, install the wandb Python package:
pip install wandb
2. Sign Up & Get API Key
Create a free account at https://wandb.ai, then copy your API key from https://wandb.ai/authorize.
3. Log In Using API Key
Run the following command in your terminal to log in:
wandb login <your_api_key>
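If you prefer not to paste the key on the command line (for example in CI jobs or notebooks), you can also authenticate from Python with wandb.login(). Below is a minimal sketch that assumes your key is stored in the WANDB_API_KEY environment variable:

import os
import wandb

# wandb.login() also picks up the WANDB_API_KEY environment variable
# automatically if no key is passed explicitly
wandb.login(key=os.environ.get("WANDB_API_KEY"))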
3. Basic Workflow
3.1 Initialize WandB in Your Script
Add wandb.init() at the start of your script to start tracking an experiment. Here's how you can do that:
import wandb

# Initialize a new run (experiment) inside the given project
wandb.init(project="my-first-project", entity="my-username")
- project: Your project name (any name works; it will show up in your WandB dashboard).
- entity: Your username or team name on WandB (can be left out for personal projects).
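The run object returned by wandb.init() can also be used as a context manager, which guarantees the run is finished cleanly even if your script raises an exception. A minimal sketch (the logged metric is a placeholder):

import wandb

# Using the run as a context manager calls wandb.finish() automatically
with wandb.init(project="my-first-project") as run:
    run.log({"loss": 0.5})  # placeholder metric; equivalent to wandb.log(...) inside the block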
3.2 Log Hyperparameters
You can log hyperparameters as key-value pairs using wandb.config. This is a great way to keep track of different experiment settings.
wandb.init(project="my-first-project", entity="my-username")

# Log hyperparameters
wandb.config.batch_size = 32
wandb.config.lr = 0.001
wandb.config.epochs = 10
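Equivalently, you can pass all hyperparameters to wandb.init() as a dictionary, which keeps them in one place. A minimal sketch using the same values:

import wandb

# Pass all hyperparameters up front instead of setting them one by one
wandb.init(
    project="my-first-project",
    config={"batch_size": 32, "lr": 0.001, "epochs": 10},
)

# Read them back anywhere in the script
batch_size = wandb.config.batch_size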
3.3 Log Metrics During Training
Log metrics like loss, accuracy, and other performance indicators during training. You can use wandb.log() to log these metrics after each epoch.
for epoch in range(epochs):
    # Your training code here...
    train_loss = 0.1     # Replace with actual loss
    val_loss = 0.08      # Replace with actual validation loss
    val_accuracy = 0.9   # Replace with actual validation accuracy

    # Log metrics to WandB
    wandb.log({"train_loss": train_loss, "val_loss": val_loss, "val_accuracy": val_accuracy})
WandB will automatically track your metrics and display them in a neat dashboard.
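If you log different metrics at different points in the loop, you can keep them aligned on a shared x-axis by passing an explicit step to wandb.log(). A small sketch (the loop and metric values are placeholders, and it assumes wandb.init() has already been called):

# Assumes wandb.init() has already been called
for epoch in range(10):
    train_loss = 0.1  # placeholder value
    val_loss = 0.08   # placeholder value
    # Passing step= pins both metrics to the same x-axis position
    wandb.log({"train_loss": train_loss, "val_loss": val_loss}, step=epoch)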
3.4 Log Model Weights
You can have WandB track your model's weights and gradients by calling wandb.watch() on the model. This is useful for monitoring how parameters change during training, as well as for visualizations.
# If using PyTorch
wandb.watch(model, log="all", log_freq=100)  # logs weights, gradients, and model architecture

Note that wandb.watch() hooks into PyTorch models. For TensorFlow/Keras, use the wandb.keras.WandbCallback in model.fit() instead of wandb.watch().
- log="all": Logs both weights and gradients.
- log_freq=100: Logs every 100th step.
3.5 Save and Log Your Model
Once you finish training, you can save and log the model to WandB to store it for future use.
wandb.save("model.pth") # Save and upload your model
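For versioned model storage, WandB also offers Artifacts. The sketch below assumes the trained weights have already been written to model.pth and uploads the file as a versioned "model" artifact (the artifact name is arbitrary):

import wandb

run = wandb.init(project="my-first-project")

# Package the saved weights as a versioned artifact
artifact = wandb.Artifact(name="my-model", type="model")
artifact.add_file("model.pth")
run.log_artifact(artifact)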
3.6 Visualizing Results
- WandB Dashboard: After running the script, your metrics, hyperparameters, and graphs will appear on your WandB dashboard. This allows you to visualize:
  - Loss curves: track training and validation loss over time.
  - Hyperparameter comparison: compare different runs of the same model with different hyperparameters.
  - Model weights/gradients: track how weights change during training.
- Custom Visualizations: You can create your own custom plots (e.g., confusion matrix, feature importance) and log them using wandb.log().
Example:
wandb.log({"confusion_matrix": wandb.plot.confusion_matrix(probs=probs, y_true=targets)})
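You can log most other plot types the same way; for example, a matplotlib figure can be logged as an image. A minimal sketch (the figure contents are placeholders, and it assumes wandb.init() has already been called):

import matplotlib.pyplot as plt
import wandb

# Assumes wandb.init() has already been called
fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0.9, 0.5, 0.2])  # placeholder data
ax.set_title("Validation loss")

# wandb.Image renders the figure in the dashboard
wandb.log({"val_loss_plot": wandb.Image(fig)})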
4. Advanced Features
4.1 Hyperparameter Optimization
You can use WandB Sweeps to automatically search for the best hyperparameters. Here’s a quick overview of how to set it up:
- Define a sweep configuration (search space for hyperparameters):
method: random  # or 'grid', 'bayes'; grid requires discrete values for every parameter
name: my-sweep
metric:
  name: val_loss
  goal: minimize
parameters:
  batch_size:
    values: [16, 32, 64]
  lr:
    max: 0.1
    min: 0.0001
- Launch the sweep from your terminal (the command prints a sweep ID that you can pass to wandb agent):
wandb sweep sweep_config.yaml
- Alternatively, create the sweep and run the agent directly from Python:
# sweep_config is the configuration above as a Python dict; train is your training function
sweep_id = wandb.sweep(sweep_config, project="my-sweep-project")
wandb.agent(sweep_id, function=train, count=10)
This will run 10 experiments, optimizing the batch_size and lr parameters based on validation loss.
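Inside a sweep, the function passed to wandb.agent is expected to call wandb.init() itself and read the hyperparameters chosen for that trial from wandb.config. A minimal sketch of such a function (the metric value is a placeholder):

import wandb

def train():
    # The sweep controller starts a run for each trial and fills
    # wandb.config with that trial's hyperparameters
    with wandb.init() as run:
        lr = wandb.config.lr
        batch_size = wandb.config.batch_size
        # ... build the model and train with these values ...
        run.log({"val_loss": 0.05})  # placeholder: report the sweep metric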
4.2 Collaboration and Sharing
WandB allows you to share results with teammates and easily collaborate. You can:
- Share links to your project and results.
- Compare runs side-by-side (e.g., compare different models or hyperparameters).
- Comment and annotate experiments for collaboration.
5. Example
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import wandb

# Initialize WandB
wandb.init(project="my-pytorch-project", entity="my-username")

# Data and model setup
transform = transforms.Compose([transforms.ToTensor()])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

wandb.watch(model, log="all")

# Training loop
for epoch in range(5):
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Log loss every 100 steps
        if i % 100 == 99:
            wandb.log({"train_loss": loss.item()})

    # Log metrics after each epoch
    wandb.log({"epoch": epoch, "train_loss": loss.item()})

torch.save(model.state_dict(), "model.pth")  # Write the trained weights to disk
wandb.save("model.pth")  # Upload the saved model to WandB
wandb.finish()  # Close the run
5.1 Useful WandB Commands
- wandb login: Log into your WandB account from the terminal.
- wandb.init(): Initialize a new run.
- wandb.save(): Save files such as models and logs.
- wandb.finish(): End the current run.
- wandb.watch(model): Track model gradients and weights.