Ray Tracing: Batched Operations

Processing items one at a time is slow. Processing a whole batch in a single tensor operation is fast.

This chapter teaches you to eliminate loops by thinking in whole-tensor operations.


The Performance Problem

# SLOW: Loop over rays
results = []
for ray in rays:
    results.append(intersect_ray_1d(ray, segment))
# FAST: Process all rays at once
results = intersect_rays_batched(rays, segment)

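On a GPU, the batched version is often orders of magnitude faster (the exact speedup depends on hardware and problem size) because:

  1. GPUs execute many operations in parallel
  2. One large operation replaces thousands of tiny ones, and memory is accessed contiguously
  3. Python loop overhead is eliminated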
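You can see the gap on your own machine by timing both styles on the same computation. The sketch below is an illustration, not the chapter's intersection code: it uses a 2D cross product between each ray direction and one fixed segment direction as a stand-in workload, with arbitrary sizes. It runs on the CPU, so it mostly measures Python and per-op dispatch overhead; the numbers you see will vary.

```python
import time

import torch as t


def cross_loop(dirs: t.Tensor, seg_dir: t.Tensor) -> t.Tensor:
    # One Python iteration (and several tiny tensor ops) per ray
    out = []
    for d in dirs:
        out.append(d[0] * seg_dir[1] - d[1] * seg_dir[0])
    return t.stack(out)


def cross_batched(dirs: t.Tensor, seg_dir: t.Tensor) -> t.Tensor:
    # One vectorized expression over all rays; seg_dir is shared by every row
    return dirs[:, 0] * seg_dir[1] - dirs[:, 1] * seg_dir[0]


dirs = t.randn(10_000, 2)       # stand-in ray directions (x, y)
seg_dir = t.tensor([0.0, 1.0])  # stand-in segment direction

start = time.perf_counter()
slow = cross_loop(dirs, seg_dir)
loop_s = time.perf_counter() - start

start = time.perf_counter()
fast = cross_batched(dirs, seg_dir)
batched_s = time.perf_counter() - start

# Same answer, very different cost
assert t.allclose(slow, fast)
print(f"loop: {loop_s:.4f}s  batched: {batched_s:.6f}s")
```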

Broadcasting for Batched Intersection

Recall our intersection equation:

$$\begin{pmatrix} D_x & (L_1 - L_2)_x \\ D_y & (L_1 - L_2)_y \end{pmatrix} \begin{pmatrix} u \\ v \end{pmatrix} = \begin{pmatrix} (L_1 - O)_x \\ (L_1 - O)_y \end{pmatrix}$$
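For a single ray, this system is one 2×2 solve. A minimal sketch, assuming (as in the batched code in this chapter) that a ray is stored as `[O, D]` and a segment as `[L1, L2]`, each point with (x, y, z) coordinates of which only x and y are used. The name `intersect_ray_1d_sketch` is hypothetical and may differ from the earlier chapter's version:

```python
import torch as t


def intersect_ray_1d_sketch(ray: t.Tensor, segment: t.Tensor) -> bool:
    """ray: (2, 3) = [O, D]; segment: (2, 3) = [L1, L2]. Uses x and y only."""
    O, D = ray[0, :2], ray[1, :2]
    L1, L2 = segment[0, :2], segment[1, :2]

    # Matrix columns are D and L1 - L2; the right-hand side is L1 - O
    A = t.stack([D, L1 - L2], dim=-1)  # (2, 2)
    b = L1 - O                         # (2,)

    # A parallel ray and segment give a singular matrix: treat it as a miss
    if A.det().abs() < 1e-8:
        return False

    u, v = t.linalg.solve(A, b)
    # u >= 0: hit is in front of the origin; 0 <= v <= 1: hit lies on the segment
    return bool((u >= 0) & (v >= 0) & (v <= 1))


ray = t.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])       # origin, direction +x
segment = t.tensor([[2.0, -1.0, 0.0], [2.0, 1.0, 0.0]])  # vertical segment at x = 2
print(intersect_ray_1d_sketch(ray, segment))  # True (u = 2, v = 0.5)
```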

For many rays against one segment:

import torch as t

def intersect_rays_batched(
    rays: t.Tensor,     # (n_rays, 2, 3)
    segment: t.Tensor   # (2, 3)
) -> t.Tensor:
    """Returns (n_rays,) boolean tensor."""

    # Extract 2D coordinates
    rays_2d = rays[:, :, :2]      # (n_rays, 2, 2)
    segment_2d = segment[:, :2]   # (2, 2)

    O = rays_2d[:, 0]   # (n_rays, 2)
    D = rays_2d[:, 1]   # (n_rays, 2)
    L1 = segment_2d[0]  # (2,)
    L2 = segment_2d[1]  # (2,)

    # Build batched matrix: (n_rays, 2, 2)
    # Column 0: D (different per ray)
    # Column 1: L1 - L2 (same for all rays; expand makes a view, not a copy)
    diff = (L1 - L2).unsqueeze(0).expand(len(rays), -1)  # (n_rays, 2)
    A = t.stack([D, diff], dim=-1)  # (n_rays, 2, 2)

    # Build batched RHS: (n_rays, 2)
    b = L1 - O  # Broadcasts: (2,) - (n_rays, 2) = (n_rays, 2)

    # One singular matrix (a ray parallel to the segment) makes t.linalg.solve
    # raise for the whole batch, so try/except cannot recover per-ray results.
    # Instead, flag singular systems and swap in the identity before solving.
    is_singular = A.det().abs() < 1e-8  # (n_rays,)
    eye = t.eye(2, dtype=A.dtype, device=A.device)
    A = t.where(is_singular[:, None, None], eye, A)

    # Solve all systems at once
    sol = t.linalg.solve(A, b)  # (n_rays, 2)

    u, v = sol[:, 0], sol[:, 1]
    # Parallel rays count as misses
    return ~is_singular & (u >= 0) & (v >= 0) & (v <= 1)
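Why not wrap the batched solve in `try`/`except` and fix up the bad rays afterwards? Because `t.linalg.solve` treats the batch as one call: a single singular matrix makes the whole call raise, leaving no partial result to keep. The sketch below demonstrates this on CPU with two hand-picked matrices (illustrative values), then shows a masking workaround: detect singular systems via the determinant, replace them with the identity, solve, and discard those rows.

```python
import torch as t

# Two 2x2 systems: the first is solvable, the second is singular (parallel rows)
A = t.tensor([[[1.0, 0.0], [0.0, -2.0]],
              [[1.0, 2.0], [2.0, 4.0]]], dtype=t.float64)
b = t.tensor([[2.0, -1.0], [1.0, 1.0]], dtype=t.float64)

# One bad matrix makes the entire batched call fail
try:
    t.linalg.solve(A, b)
    raised = False
except RuntimeError:  # torch.linalg.LinAlgError is a subclass of RuntimeError
    raised = True

# Workaround: flag singular systems, replace them with the identity, solve, mask
is_singular = A.det().abs() < 1e-8
A_safe = t.where(is_singular[:, None, None], t.eye(2, dtype=A.dtype), A)
sol = t.linalg.solve(A_safe, b)  # row 1 is meaningless and must be ignored
valid = ~is_singular

print(raised, sol[0].tolist(), valid.tolist())  # True [2.0, 0.5] [True, False]
```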

Using einops for Clarity

The einops library makes batched operations readable:

import einops

# Swap the last two axes (a transpose, e.g. before a batched matrix multiply)
# From (batch, points, dims) to (batch, dims, points)
rays_T = einops.rearrange(rays, 'b p d -> b d p')

# Repeat segment for each ray
segment_batched = einops.repeat(
    segment, 'p d -> b p d', b=len(rays)
)

# Reduce across dimension
means = einops.reduce(rays, 'b p d -> b d', 'mean')

The pattern: Describe what you want, not how to get it.


Many Rays × Many Segments

What if we have multiple segments too?

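def intersect_all(
    rays: t.Tensor,      # (n_rays, 2, 3)
    segments: t.Tensor   # (n_segments, 2, 3)
) -> t.Tensor:
    """Returns (n_rays, n_segments) boolean tensor."""

    # Add dimensions for broadcasting
    rays_exp = rays[:, None, :, :]       # (n_rays, 1, 2, 3)
    segs_exp = segments[None, :, :, :]   # (1, n_segments, 2, 3)

    # Now both broadcast to (n_rays, n_segments, ...); keep only (x, y)
    O, D = rays_exp[..., 0, :2], rays_exp[..., 1, :2]    # (n_rays, 1, 2)
    L1, L2 = segs_exp[..., 0, :2], segs_exp[..., 1, :2]  # (1, n_segments, 2)
    D, diff = t.broadcast_tensors(D, L1 - L2)  # each (n_rays, n_segments, 2)
    A = t.stack([D, diff], dim=-1)             # (n_rays, n_segments, 2, 2)
    b = L1 - O                                 # (n_rays, n_segments, 2)

    # Swap singular (parallel) systems for the identity, then mask them out
    is_singular = A.det().abs() < 1e-8         # (n_rays, n_segments)
    A = t.where(is_singular[..., None, None], t.eye(2, dtype=A.dtype, device=A.device), A)
    sol = t.linalg.solve(A, b)                 # (n_rays, n_segments, 2)
    return ~is_singular & (sol[..., 0] >= 0) & (sol[..., 1] >= 0) & (sol[..., 1] <= 1)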

This is the power of broadcasting: express "all pairs" without nested loops.
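The same `None`-indexing trick works for any all-pairs computation. A minimal sketch with made-up 1D values: inserting a size-1 axis in different positions turns two vectors into a full table of pairwise differences, identical to what a nested loop produces.

```python
import torch as t

a = t.tensor([1.0, 2.0, 3.0])   # e.g. one value per ray
b = t.tensor([10.0, 20.0])      # e.g. one value per segment

# (3, 1) - (1, 2) broadcasts to (3, 2): entry [i, j] is a[i] - b[j]
pairs = a[:, None] - b[None, :]

# Same result with explicit nested loops
looped = t.tensor([[ai - bj for bj in b.tolist()] for ai in a.tolist()])
assert t.equal(pairs, looped)

print(pairs.tolist())  # [[-9.0, -19.0], [-8.0, -18.0], [-7.0, -17.0]]
```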


Logical Reductions

After computing intersections, you often need to aggregate:

# Does each ray hit ANY segment?
hits_any = intersections.any(dim=1)  # (n_rays,)

# Does each ray hit ALL segments? (unlikely but possible)
hits_all = intersections.all(dim=1)  # (n_rays,)

# How many segments does each ray hit?
hit_count = intersections.sum(dim=1)  # (n_rays,)

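# Which segment does each ray hit first? (needs distance, not just bool)
# Mask misses with +inf first, or argmin may pick a segment the ray never hits;
# rays that hit nothing still get index 0, so check hits_any as well
first_hit = distances.masked_fill(~intersections, float("inf")).argmin(dim=1)  # (n_rays,)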
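Here is how these reductions behave on a concrete case. The values are made up for illustration; note that ray 0 has its smallest distance to segment 1, which it does not hit, so masking changes the answer. Marking rays with no hit as `-1` is a convention chosen for this sketch.

```python
import torch as t

# Toy results for 3 rays x 3 segments (illustrative values)
intersections = t.tensor([[True, False, True],
                          [False, False, False],
                          [True, True, True]])
distances = t.tensor([[2.0, 0.3, 0.5],
                      [3.0, 4.0, 5.0],
                      [1.5, 0.2, 0.9]])

hits_any = intersections.any(dim=1)   # does each ray hit anything?
hits_all = intersections.all(dim=1)   # does each ray hit everything?
hit_count = intersections.sum(dim=1)  # how many segments per ray?

# Ignore segments a ray misses, then take the nearest remaining one
masked = distances.masked_fill(~intersections, float("inf"))
first_hit = masked.argmin(dim=1)
# Rays that hit nothing get -1 (sketch convention)
first_hit = t.where(hits_any, first_hit, t.full_like(first_hit, -1))

print(hits_any.tolist(), hits_all.tolist(), hit_count.tolist(), first_hit.tolist())
# [True, False, True] [False, False, True] [2, 0, 3] [2, -1, 1]
```

Without the mask, `distances.argmin(dim=1)` would report segment 1 for ray 0, a segment that ray never touches.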

Rendering a 1D Image

Now we can render! Each pixel corresponds to a ray. The pixel is bright if the ray hits an object.

def render_1d(
    rays: t.Tensor,      # (n_pixels, 2, 3)
    segments: t.Tensor   # (n_segments, 2, 3)
) -> t.Tensor:
    """
    Returns: (n_pixels,) float tensor
    1.0 if ray hits any segment, 0.0 otherwise
    """
    intersections = intersect_all(rays, segments)  # (n_pixels, n_segments)
    return intersections.any(dim=1).float()
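`render_1d` needs one ray per pixel. A minimal sketch of a ray generator, assuming a camera at the origin looking along +x, with ray directions spread evenly in y over `[-y_limit, y_limit]` and z = 0. This camera setup is an assumption, not something this chapter specifies, and `make_rays_1d` with its parameters is a hypothetical helper:

```python
import torch as t


def make_rays_1d(num_pixels: int, y_limit: float) -> t.Tensor:
    """Hypothetical helper. Returns (num_pixels, 2, 3): [origin, direction] per pixel."""
    rays = t.zeros(num_pixels, 2, 3)
    rays[:, 1, 0] = 1.0                                        # direction x = 1
    rays[:, 1, 1] = t.linspace(-y_limit, y_limit, num_pixels)  # direction y varies
    # Origins (rays[:, 0]) stay at (0, 0, 0); direction z stays 0
    return rays


rays = make_rays_1d(num_pixels=5, y_limit=1.0)
print(rays.shape)              # torch.Size([5, 2, 3])
print(rays[:, 1, 1].tolist())  # [-1.0, -0.5, 0.0, 0.5, 1.0]
```

Passing such a tensor to `render_1d` together with a `(n_segments, 2, 3)` tensor of segments yields one brightness value per pixel.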

Capstone Connection

Batched operations in attention:

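# Attention computes ALL query-key pairs at once
# residual: (batch, seq, d_model); W_Q, W_K: (heads, d_model, d_head); biases omitted
Q = einops.einsum(residual, W_Q, "batch seq d_model, heads d_model d_head -> batch seq heads d_head")
K = einops.einsum(residual, W_K, "batch seq d_model, heads d_model d_head -> batch seq heads d_head")

# Compute all attention scores simultaneously
# This is like "does each query intersect each key?"
# (full attention also divides by sqrt(d_head) before the softmax)
scores = einops.einsum(
    Q, K,
    "batch seq_q heads d, batch seq_k heads d -> batch heads seq_q seq_k"
)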

After a softmax over the key axis, these scores become the attention pattern: a heatmap of soft "intersections" between queries and keys.
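To convince yourself that one einsum really covers every query-key pair, compare it against explicit loops on small random tensors. This sketch uses `torch.einsum` with single-letter axis names rather than `einops.einsum`; both express the same contraction. The shapes are arbitrary.

```python
import torch as t

batch, seq, heads, d_head = 2, 4, 3, 5
Q = t.randn(batch, seq, heads, d_head)
K = t.randn(batch, seq, heads, d_head)

# All query-key dot products in one call: (batch, heads, seq_q, seq_k)
scores = t.einsum("bqhd,bkhd->bhqk", Q, K)

# Reference: one dot product per (batch, head, query, key) via nested loops
ref = t.zeros(batch, heads, seq, seq)
for b in range(batch):
    for h in range(heads):
        for q in range(seq):
            for k in range(seq):
                ref[b, h, q, k] = Q[b, q, h] @ K[b, k, h]

assert t.allclose(scores, ref, atol=1e-5)
print(scores.shape)  # torch.Size([2, 3, 4, 4])
```

Four nested loops collapse into a single call, the same move as replacing the ray-by-segment loops with broadcasting.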


🎓 Tyla's Exercise

  1. In intersect_all, what is the memory complexity? If n_rays = 1000 and n_segments = 1000, how many bytes for the intermediate tensors?

  2. Why do we use [:, None, :, :] instead of unsqueeze? Are they equivalent?

  3. Prove that (a & b).any() equals a.any() and b.any() — or find a counterexample.


💻 Aaliyah's Exercise

Implement a full 2D renderer:

def make_rays_2d(
    num_pixels_y: int,
    num_pixels_z: int,
    y_limit: float,
    z_limit: float
) -> t.Tensor:
    """
    Returns: (num_pixels_y * num_pixels_z, 2, 3)
    Rays through a 2D grid of pixels.

    Hint: Use torch.meshgrid or einops to create the grid.
    """
    pass

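def render_2d(rays: t.Tensor, triangles: t.Tensor) -> t.Tensor:
    """
    rays: (n_pixels, 2, 3)
    triangles: (n_triangles, 3, 3) - each triangle has 3 vertices

    Returns: (n_pixels,) float tensor, 1.0 where the ray hits any triangle.
    Reshape it to the image's (height, width), matching the pixel order
    used by make_rays_2d.
    """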
    pass

📚 Maneesha's Reflection

  1. "Batching" appears everywhere: batch normalization, mini-batch gradient descent, batched matrix operations. What's the common thread?

  2. The trade-off with batching is memory vs. speed. When might you intentionally NOT batch operations?

  3. How would you explain broadcasting to someone who finds it unintuitive? What metaphor would you use?