Optimization: Weights & Biases

Hyperparameter tuning without tracking is guesswork. Weights & Biases makes experiments reproducible.


Why Track Experiments?

Without tracking:

  • "Which run had the good validation accuracy, the one from Tuesday?"
  • Hyperparameters scattered across notebooks, terminal history, and memory
  • Results you cannot reproduce because you no longer know what produced them

With tracking:

  • Every run's config and metrics recorded automatically
  • Runs compared side by side and shared with a link
  • Experiments you (or anyone else) can reproduce later


Setting Up wandb

import wandb

# Initialize a run (do this once per training run, not once per project)
wandb.init(
    project="arena-mnist",
    config={
        "learning_rate": 0.001,
        "batch_size": 64,
        "epochs": 10,
        "architecture": "ResNet34",
    }
)

# Access config
config = wandb.config
print(f"Training with lr={config.learning_rate}")

Logging Metrics

for epoch in range(config.epochs):
    for batch_idx, (x, y) in enumerate(train_loader):
        loss = train_step(model, x, y)

        # Log training metrics
        wandb.log({
            "train_loss": loss.item(),
            "epoch": epoch,
            "step": epoch * len(train_loader) + batch_idx,
        })

    # Log validation metrics after each epoch
    val_loss, val_acc = evaluate(model, val_loader)
    wandb.log({
        "val_loss": val_loss,
        "val_accuracy": val_acc,
        "epoch": epoch,
    })

# Finish the run
wandb.finish()
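
By default, every call to wandb.log advances wandb's internal step counter. If you would rather control the x-axis yourself, you can pass step explicitly instead of logging it as an ordinary metric. A hedged sketch reusing the loop variables from the code above:

# Call once, right after wandb.init: plot validation metrics against epoch in the dashboard
wandb.define_metric("val_loss", step_metric="epoch")
wandb.define_metric("val_accuracy", step_metric="epoch")

# Inside the inner training loop: pass the step explicitly
global_step = epoch * len(train_loader) + batch_idx
wandb.log({"train_loss": loss.item()}, step=global_step)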

Using Config Objects

Clean separation of hyperparameters:

from dataclasses import dataclass

@dataclass
class TrainConfig:
    learning_rate: float = 0.001
    batch_size: int = 64
    epochs: int = 10
    weight_decay: float = 0.01
    momentum: float = 0.9
    optimizer: str = "adam"

def train(config: TrainConfig):
    wandb.init(
        project="arena-mnist",
        config=vars(config)  # Convert dataclass to dict
    )

    # ... training code using config ...

    wandb.finish()

# Run with different configs
train(TrainConfig(learning_rate=0.001))
train(TrainConfig(learning_rate=0.01, optimizer="sgd"))
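
The dataclass also pairs nicely with a small command-line wrapper, so each experiment can be launched from the shell. This is a hypothetical sketch (parse_config and the --flags are not part of wandb; every TrainConfig field above has a default, so type(f.default) is safe here):

import argparse
from dataclasses import fields

def parse_config() -> TrainConfig:
    parser = argparse.ArgumentParser()
    for f in fields(TrainConfig):
        # Expose each field as a flag, e.g. --learning_rate 0.01 --optimizer sgd
        parser.add_argument(f"--{f.name}", type=type(f.default), default=f.default)
    return TrainConfig(**vars(parser.parse_args()))

if __name__ == "__main__":
    train(parse_config())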

Hyperparameter Sweeps

Automated search over configurations:

sweep_config = {
    "method": "bayes",  # or "grid", "random"
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "learning_rate": {
            "min": 1e-5,
            "max": 1e-2,
            "distribution": "log_uniform_values"
        },
        "batch_size": {"values": [32, 64, 128]},
        "weight_decay": {
            "min": 0.001,
            "max": 0.1,
            "distribution": "log_uniform_values"
        },
    }
}
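
# Optional (a sketch, not required by the sweep above): Hyperband early termination
# stops runs that are clearly underperforming after a few reported epochs.
sweep_config["early_terminate"] = {"type": "hyperband", "min_iter": 3}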

# Create sweep
sweep_id = wandb.sweep(sweep_config, project="arena-mnist")

# Run agent
def train_sweep():
    wandb.init()  # the agent injects this run's hyperparameters into wandb.config
    config = wandb.config

    # Train with the sweep's hyperparameters
    model = train_model(
        lr=config.learning_rate,
        batch_size=config.batch_size,
        weight_decay=config.weight_decay
    )

    # The sweep optimizes "val_accuracy", so that metric must actually be logged
    val_loss, val_acc = evaluate(model, val_loader)
    wandb.log({"val_accuracy": val_acc})
    wandb.finish()

wandb.agent(sweep_id, function=train_sweep, count=20)

Logging Images and Media

# Log images (wandb.Image wraps one image at a time, so wrap each image in the batch)
images = x[:8]  # First 8 images from the batch
wandb.log({
    "input_images": [wandb.Image(img) for img in images],
})

# Log predictions
predictions = model(images)
wandb.log({
    "predictions": wandb.Histogram(predictions.detach().cpu().numpy()),
})

# Watch the model: periodically logs gradient and parameter histograms
wandb.watch(model, log="all")
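
For per-example inspection, a wandb.Table pairs each image with its predicted and true label, so you can browse mistakes in the dashboard. A hedged sketch reusing images, predictions, and the labels y from the snippets above:

table = wandb.Table(columns=["image", "predicted", "label"])
for img, pred, label in zip(images, predictions.argmax(dim=-1), y[:8]):
    table.add_data(wandb.Image(img), pred.item(), label.item())
wandb.log({"prediction_table": table})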

Comparing Runs

In the wandb dashboard:

  1. Select multiple runs
  2. Click "Compare"
  3. View:
    • Loss curves overlaid
    • Config differences highlighted
    • Final metrics table

# You can also compare runs programmatically via the public API
api = wandb.Api()
runs = api.runs("your-username/arena-mnist")

for run in runs:
    print(f"{run.name}: {run.summary.get('val_accuracy', 'N/A')}")

Best Practices

1. Use meaningful run names:

wandb.init(name=f"lr{config.learning_rate}_bs{config.batch_size}")

2. Log everything you might need:

# Compute the gradient norm after loss.backward() so gradients are populated
grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5

wandb.log({
    "train_loss": loss.item(),
    "learning_rate": scheduler.get_last_lr()[0],
    "gradient_norm": grad_norm,
})

3. Save artifacts (models, datasets); a retrieval sketch follows this list:

# Save model checkpoint
artifact = wandb.Artifact(f"model-{epoch}", type="model")
artifact.add_file("checkpoint.pt")
wandb.log_artifact(artifact)

4. Tag runs for organization:

wandb.init(tags=["baseline", "resnet34", "cifar10"])
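
To reload the checkpoint saved in point 3 inside a later run, pull the artifact back down. A hedged sketch (the artifact name must match the f"model-{epoch}" pattern you logged, and model is the network you are restoring into):

import torch
import wandb

run = wandb.init(project="arena-mnist")
artifact = run.use_artifact("model-5:latest", type="model")  # e.g. the checkpoint from epoch 5
checkpoint_dir = artifact.download()
model.load_state_dict(torch.load(f"{checkpoint_dir}/checkpoint.pt"))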

Capstone Connection

Tracking RLHF experiments:

When training with human feedback, tracking is essential:

wandb.log({
    "reward_mean": rewards.mean(),
    "reward_std": rewards.std(),
    "kl_divergence": kl_div,  # From reference model
    "sycophancy_score": measure_sycophancy(responses),
    "helpfulness_score": measure_helpfulness(responses),
})

You can then watch how sycophancy evolves during training by plotting sycophancy_score against training step and comparing the curve across runs.
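
A hedged sketch of pulling that curve back out through the public API once a run has finished (the project path and run id below are placeholders):

api = wandb.Api()
run = api.run("your-username/rlhf-capstone/abc123")  # entity/project/run_id
history = run.history(keys=["sycophancy_score"])     # returns a pandas DataFrame
print(history.head())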


🎓 Tyla's Exercise

  1. Why does the sweep configuration above use a "log_uniform_values" distribution for the learning rate instead of "uniform"?

  2. What's the difference between grid search, random search, and Bayesian optimization? When would you use each?

  3. Design a sweep configuration to find optimal hyperparameters for a transformer language model.


💻 Aaliyah's Exercise

Set up a complete experiment:

def run_experiment():
    """
    1. Initialize wandb with a config
    2. Train ResNet on CIFAR-10
    3. Log: train loss, val loss, val accuracy per epoch
    4. Log: learning rate schedule
    5. Log: sample predictions as images
    6. Save final model as artifact

    Run 3 experiments with different learning rates.
    Compare them in the wandb dashboard.
    """
    pass

📚 Maneesha's Reflection

  1. Experiment tracking can be seen as "keeping a lab notebook for ML." What makes a good lab notebook?

  2. The reproducibility crisis in ML research is well-documented. How does tooling like wandb help or hurt?

  3. "If you didn't log it, it didn't happen." Discuss this in the context of both ML research and human learning.