Optimization: Weights & Biases
Hyperparameter tuning without tracking is guesswork. Weights & Biases (wandb) logs every run automatically, so experiments stay reproducible and easy to compare.
Why Track Experiments?
Without tracking:
- "Wait, which learning rate worked best?"
- "Did I already try that configuration?"
- "What were the settings for that good run?"
With tracking:
- Every experiment logged automatically
- Compare runs side-by-side
- Share results with your team
Setting Up wandb
import wandb

# Initialize (do this once per project)
wandb.init(
    project="arena-mnist",
    config={
        "learning_rate": 0.001,
        "batch_size": 64,
        "epochs": 10,
        "architecture": "ResNet34",
    },
)

# Access config
config = wandb.config
print(f"Training with lr={config.learning_rate}")
Logging Metrics
for epoch in range(config.epochs):
    for batch_idx, (x, y) in enumerate(train_loader):
        loss = train_step(model, x, y)

        # Log training metrics
        wandb.log({
            "train_loss": loss.item(),
            "epoch": epoch,
            "step": epoch * len(train_loader) + batch_idx,
        })

    # Log validation metrics after each epoch
    val_loss, val_acc = evaluate(model, val_loader)
    wandb.log({
        "val_loss": val_loss,
        "val_accuracy": val_acc,
        "epoch": epoch,
    })

# Finish the run
wandb.finish()
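Logging every batch is fine for MNIST-sized runs, but for larger models it adds overhead and cluttered charts. A common variation is to log training metrics only every N steps and pass an explicit step so the x-axis stays consistent. A sketch reusing the loop above (log_every is an arbitrary choice):

log_every = 50  # arbitrary logging interval

for epoch in range(config.epochs):
    for batch_idx, (x, y) in enumerate(train_loader):
        loss = train_step(model, x, y)
        global_step = epoch * len(train_loader) + batch_idx

        # Log only every `log_every` batches, with an explicit global step
        if global_step % log_every == 0:
            wandb.log({"train_loss": loss.item()}, step=global_step)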
Using Config Objects
A dataclass keeps hyperparameters cleanly separated from the training code:
from dataclasses import dataclass

@dataclass
class TrainConfig:
    learning_rate: float = 0.001
    batch_size: int = 64
    epochs: int = 10
    weight_decay: float = 0.01
    momentum: float = 0.9
    optimizer: str = "adam"

def train(config: TrainConfig):
    wandb.init(
        project="arena-mnist",
        config=vars(config),  # Convert dataclass to dict
    )
    # ... training code using config ...
    wandb.finish()

# Run with different configs
train(TrainConfig(learning_rate=0.001))
train(TrainConfig(learning_rate=0.01, optimizer="sgd"))
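The same dataclass also works inside a sweep (next section), where only the swept fields arrive via wandb.config. A sketch of merging sweep values into the defaults, assuming the sweep's parameter names match the TrainConfig field names:

from dataclasses import fields

def config_from_wandb() -> TrainConfig:
    # Start from the dataclass defaults, overriding any field the sweep supplies
    swept = dict(wandb.config)  # wandb.config is dict-like
    overrides = {f.name: swept[f.name] for f in fields(TrainConfig) if f.name in swept}
    return TrainConfig(**overrides)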
Hyperparameter Sweeps
Automated search over configurations:
sweep_config = {
    "method": "bayes",  # or "grid", "random"
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "learning_rate": {
            "min": 1e-5,
            "max": 1e-2,
            "distribution": "log_uniform_values",
        },
        "batch_size": {"values": [32, 64, 128]},
        "weight_decay": {
            "min": 0.001,
            "max": 0.1,
            "distribution": "log_uniform_values",
        },
    },
}

# Create sweep
sweep_id = wandb.sweep(sweep_config, project="arena-mnist")

# Define the function the agent calls once per trial
def train_sweep():
    wandb.init()  # The sweep controller supplies the config
    config = wandb.config

    # Train with the sampled hyperparameters
    model = train_model(
        lr=config.learning_rate,
        batch_size=config.batch_size,
        weight_decay=config.weight_decay,
    )

# Run agent for 20 trials
wandb.agent(sweep_id, function=train_sweep, count=20)
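Bayesian and random sweeps often spend much of their budget on runs that are clearly bad after a few epochs. wandb sweeps support early termination via Hyperband; a sketch extending the sweep_config above (min_iter is an arbitrary choice):

sweep_config_with_stopping = {
    **sweep_config,  # reuse the metric and parameters defined above
    "early_terminate": {
        "type": "hyperband",  # stop low-performing runs early
        "min_iter": 3,        # let every run train for at least 3 iterations first
    },
}

sweep_id = wandb.sweep(sweep_config_with_stopping, project="arena-mnist")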
Logging Images and Media
# Log images (one wandb.Image per example)
images = x[:8]  # First 8 images from the batch
wandb.log({
    "input_images": [wandb.Image(img) for img in images],
})

# Log the distribution of predictions
predictions = model(images)
wandb.log({
    "predictions": wandb.Histogram(predictions.detach().cpu().numpy()),
})

# Log gradients and parameters as the model trains
wandb.watch(model, log="all")
Comparing Runs
In the wandb dashboard:
- Select multiple runs
- Click "Compare"
- View:
  - Loss curves overlaid
  - Config differences highlighted
  - Final metrics table
# You can also do this programmatically
api = wandb.Api()
runs = api.runs("your-username/arena-mnist")
for run in runs:
    print(f"{run.name}: {run.summary.get('val_accuracy', 'N/A')}")
Best Practices
1. Use meaningful run names:
wandb.init(name=f"lr{config.lr}_bs{config.batch_size}")
2. Log everything you might need:
wandb.log({
    "train_loss": loss,
    "learning_rate": scheduler.get_last_lr()[0],
    "gradient_norm": grad_norm,
})
3. Save artifacts (models, datasets) so you can load them back later (see the sketch after this list):
# Save model checkpoint
artifact = wandb.Artifact(f"model-{epoch}", type="model")
artifact.add_file("checkpoint.pt")
wandb.log_artifact(artifact)
4. Tag runs for organization:
wandb.init(tags=["baseline", "resnet34", "cifar10"])
Capstone Connection
Tracking RLHF experiments:
When training with human feedback, tracking is essential:
wandb.log({
    "reward_mean": rewards.mean(),
    "reward_std": rewards.std(),
    "kl_divergence": kl_div,  # From the reference model
    "sycophancy_score": measure_sycophancy(responses),
    "helpfulness_score": measure_helpfulness(responses),
})
You can track how sycophancy evolves during training (a sweep sketch for the last question follows this list):
- Does it increase with more RLHF?
- Does it correlate with reward?
- Do certain hyperparameters make it worse?
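One way to attack that last question empirically is to sweep the RLHF-specific hyperparameters with sycophancy as the tracked metric. A sketch, assuming a hypothetical rlhf_train() loop that logs sycophancy_score as in the snippet above (kl_coef, the value ranges, and the arena-rlhf project name are all illustrative):

sycophancy_sweep = {
    "method": "random",
    "metric": {"name": "sycophancy_score", "goal": "minimize"},
    "parameters": {
        # Weight on the KL penalty against the reference model
        "kl_coef": {
            "min": 0.01,
            "max": 1.0,
            "distribution": "log_uniform_values",
        },
        "learning_rate": {"values": [1e-6, 5e-6, 1e-5]},
    },
}

def rlhf_sweep_run():
    wandb.init()
    rlhf_train(  # hypothetical RLHF training loop
        kl_coef=wandb.config.kl_coef,
        lr=wandb.config.learning_rate,
    )

sweep_id = wandb.sweep(sycophancy_sweep, project="arena-rlhf")
wandb.agent(sweep_id, function=rlhf_sweep_run, count=10)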
🎓 Tyla's Exercise
Why does wandb use "log_uniform" distribution for learning rate search instead of "uniform"?
What's the difference between grid search, random search, and Bayesian optimization? When would you use each?
Design a sweep configuration to find optimal hyperparameters for a transformer language model.
💻 Aaliyah's Exercise
Set up a complete experiment:
def run_experiment():
    """
    1. Initialize wandb with a config
    2. Train ResNet on CIFAR-10
    3. Log: train loss, val loss, val accuracy per epoch
    4. Log: learning rate schedule
    5. Log: sample predictions as images
    6. Save final model as artifact

    Run 3 experiments with different learning rates.
    Compare them in the wandb dashboard.
    """
    pass
📚 Maneesha's Reflection
Experiment tracking can be seen as "keeping a lab notebook for ML." What makes a good lab notebook?
The reproducibility crisis in ML research is well-documented. How does tooling like wandb help or hurt?
"If you didn't log it, it didn't happen." Discuss this in the context of both ML research and human learning.