Ray Tracing: Batched Operations
Single operations are slow. Batched operations are fast.
This chapter teaches you to eliminate loops by thinking in whole-tensor operations.
The Performance Problem
# SLOW: Loop over rays
results = []
for ray in rays:
    results.append(intersect_ray_1d(ray, segment))
# FAST: Process all rays at once
results = intersect_rays_batched(rays, segment)
On a GPU, the batched version can be 1000x faster because:
- GPUs execute many operations in parallel
- Memory is accessed contiguously
- Python loop overhead is eliminated
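To make the claim concrete, you can time the two approaches directly. This is a minimal sketch, assuming intersect_ray_1d from the previous chapter and the intersect_rays_batched function defined later in this chapter; the exact numbers depend on your hardware and batch size.
import time
import torch as t

def time_it(fn, n_repeats=10):
    # Average wall-clock seconds over a few repeats
    start = time.perf_counter()
    for _ in range(n_repeats):
        fn()
    return (time.perf_counter() - start) / n_repeats

rays = t.randn(1_000, 2, 3)    # made-up batch of random rays
segment = t.randn(2, 3)        # one made-up segment

loop_time = time_it(lambda: [intersect_ray_1d(ray, segment) for ray in rays])
batch_time = time_it(lambda: intersect_rays_batched(rays, segment))
print(f"loop: {loop_time:.4f}s   batched: {batch_time:.4f}s")
On a GPU, call torch.cuda.synchronize() before reading the clock, since kernel launches are asynchronous; the gap also widens as the batch grows.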
Broadcasting for Batched Intersection
Recall our intersection equation:
$$\begin{pmatrix} D_x & (L_1 - L_2)_x \\ D_y & (L_1 - L_2)_y \end{pmatrix} \begin{pmatrix} u \\ v \end{pmatrix} = \begin{pmatrix} (L_1 - O)_x \\ (L_1 - O)_y \end{pmatrix}$$
For many rays against one segment:
- `D` has shape `(n_rays, 2)`: different for each ray
- `L1 - L2` has shape `(2,)`: same for all rays
- `L1 - O` has shape `(n_rays, 2)`: different for each ray
import torch as t

def intersect_rays_batched(
    rays: t.Tensor,     # (n_rays, 2, 3)
    segment: t.Tensor   # (2, 3)
) -> t.Tensor:
    """Returns (n_rays,) boolean tensor."""
    # Extract 2D coordinates
    rays_2d = rays[:, :, :2]      # (n_rays, 2, 2)
    segment_2d = segment[:, :2]   # (2, 2)

    O = rays_2d[:, 0]    # (n_rays, 2)
    D = rays_2d[:, 1]    # (n_rays, 2)
    L1 = segment_2d[0]   # (2,)
    L2 = segment_2d[1]   # (2,)

    # Build batched matrix: (n_rays, 2, 2)
    # Column 0: D (different per ray)
    # Column 1: L1 - L2 (same for all rays, expanded to match)
    diff = (L1 - L2).unsqueeze(0).expand(len(rays), -1)  # (n_rays, 2)
    A = t.stack([D, diff], dim=-1)  # (n_rays, 2, 2)

    # Build batched RHS: (n_rays, 2)
    b = L1 - O  # Broadcasts: (2,) - (n_rays, 2) = (n_rays, 2)

    # Solve all systems at once
    try:
        sol = t.linalg.solve(A, b)  # (n_rays, 2)
    except RuntimeError:
        # At least one matrix is singular (ray parallel to the segment):
        # fall back to handling those rays individually
        return handle_singular(rays, segment)

    u, v = sol[:, 0], sol[:, 1]
    return (u >= 0) & (v >= 0) & (v <= 1)
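The except branch above hands the whole batch off to a separate helper as soon as any single matrix is singular (a ray exactly parallel to the segment). A common fully batched alternative, sketched below with an arbitrary 1e-8 tolerance, detects the singular systems by determinant, overwrites them with the identity so the solve succeeds, and masks those rays out of the result:
    # Drop-in replacement for the solve step above (a sketch, not the only option)
    dets = t.linalg.det(A)             # (n_rays,)
    is_singular = dets.abs() < 1e-8    # tolerance is an arbitrary choice
    A[is_singular] = t.eye(2)          # dummy solvable systems for parallel rays
    sol = t.linalg.solve(A, b)         # (n_rays, 2)
    u, v = sol[:, 0], sol[:, 1]
    return (u >= 0) & (v >= 0) & (v <= 1) & ~is_singular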
Using einops for Clarity
The einops library makes batched operations readable:
import einops
# Reshape for batched matrix multiply
# From (batch, points, dims) to (batch, dims, points)
rays_T = einops.rearrange(rays, 'b p d -> b d p')
# Repeat segment for each ray
segment_batched = einops.repeat(
    segment, 'p d -> b p d', b=len(rays)
)
# Reduce across dimension
means = einops.reduce(rays, 'b p d -> b d', 'mean')
The pattern: Describe what you want, not how to get it.
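For comparison, here are the same three operations written with plain tensor methods (a sketch; the results have the same values as the einops versions, but the shape bookkeeping now lives only in the comments):
rays_T = rays.permute(0, 2, 1)                                     # 'b p d -> b d p'
segment_batched = segment.unsqueeze(0).expand(len(rays), -1, -1)   # 'p d -> b p d'
means = rays.mean(dim=1)                                           # 'b p d -> b d', 'mean'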
Many Rays × Many Segments
What if we have multiple segments too?
def intersect_all(
    rays: t.Tensor,      # (n_rays, 2, 3)
    segments: t.Tensor   # (n_segments, 2, 3)
) -> t.Tensor:
    """Returns (n_rays, n_segments) boolean tensor."""
    # Add dimensions for broadcasting
    rays_exp = rays[:, None, :, :]      # (n_rays, 1, 2, 3)
    segs_exp = segments[None, :, :, :]  # (1, n_segments, 2, 3)
    # Now both broadcast to (n_rays, n_segments, 2, 3)
    # ... rest of intersection logic (a full sketch follows below)
This is the power of broadcasting: express "all pairs" without nested loops.
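Filling in the rest of the function above, a complete version might look like the sketch below. It reuses the determinant trick from earlier for parallel ray/segment pairs; the 1e-8 tolerance and the variable names are arbitrary choices.
import torch as t

def intersect_all(
    rays: t.Tensor,      # (n_rays, 2, 3)
    segments: t.Tensor   # (n_segments, 2, 3)
) -> t.Tensor:
    """Returns (n_rays, n_segments) boolean tensor of intersections."""
    NR, NS = rays.shape[0], segments.shape[0]

    # Work in 2D and broadcast rays against segments: (NR, NS, 2, 2)
    rays_2d = rays[:, None, :, :2].expand(NR, NS, 2, 2)
    segs_2d = segments[None, :, :, :2].expand(NR, NS, 2, 2)

    O, D = rays_2d[..., 0, :], rays_2d[..., 1, :]     # (NR, NS, 2) each
    L1, L2 = segs_2d[..., 0, :], segs_2d[..., 1, :]   # (NR, NS, 2) each

    A = t.stack([D, L1 - L2], dim=-1)   # (NR, NS, 2, 2)
    b = L1 - O                          # (NR, NS, 2)

    # Replace singular (parallel) systems with the identity, then mask them out
    is_singular = t.linalg.det(A).abs() < 1e-8   # (NR, NS)
    A[is_singular] = t.eye(2)
    sol = t.linalg.solve(A, b)                   # (NR, NS, 2)
    u, v = sol[..., 0], sol[..., 1]
    return (u >= 0) & (v >= 0) & (v <= 1) & ~is_singular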
Logical Reductions
After computing intersections, you often need to aggregate:
# Does each ray hit ANY segment?
hits_any = intersections.any(dim=1) # (n_rays,)
# Does each ray hit ALL segments? (unlikely but possible)
hits_all = intersections.all(dim=1) # (n_rays,)
# How many segments does each ray hit?
hit_count = intersections.sum(dim=1) # (n_rays,)
# Which segment does each ray hit first? (needs distance, not just bool)
first_hit = distances.argmin(dim=1) # (n_rays,)
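A tiny made-up example (2 rays, 3 segments) helps keep the axes straight:
import torch as t

intersections = t.tensor([
    [True,  False, True ],   # ray 0 hits segments 0 and 2
    [False, False, False],   # ray 1 hits nothing
])
print(intersections.any(dim=1))   # tensor([ True, False])
print(intersections.all(dim=1))   # tensor([False, False])
print(intersections.sum(dim=1))   # tensor([2, 0])

distances = t.tensor([
    [1.5, float("inf"), 0.7],    # distance to each segment, inf = miss
    [float("inf")] * 3,
])
print(distances.argmin(dim=1))    # ray 0's nearest hit is segment 2;
                                  # ray 1 hit nothing, so mask its value with hits_any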
Rendering a 1D Image
Now we can render! Each pixel corresponds to a ray. The pixel is bright if the ray hits an object.
def render_1d(
    rays: t.Tensor,      # (n_pixels, 2, 3)
    segments: t.Tensor   # (n_segments, 2, 3)
) -> t.Tensor:
    """
    Returns: (n_pixels,) float tensor
    1.0 if ray hits any segment, 0.0 otherwise
    """
    intersections = intersect_all(rays, segments)  # (n_pixels, n_segments)
    return intersections.any(dim=1).float()
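As a quick sanity check of render_1d, here is one way to build a small fan of rays by hand and render it against a single segment. All the numbers are made up, and it assumes the complete intersect_all from the sketch above.
import torch as t

n_pixels = 9
# Rays start at the origin and fan out in y; the x-direction is always 1
origins = t.zeros(n_pixels, 3)
directions = t.stack([
    t.ones(n_pixels),                  # x component
    t.linspace(-1.0, 1.0, n_pixels),   # y component
    t.zeros(n_pixels),                 # z component (unused in 2D)
], dim=-1)
rays = t.stack([origins, directions], dim=1)   # (n_pixels, 2, 3)

# One vertical segment at x = 2, spanning y in [-0.5, 0.5]
segments = t.tensor([[[2.0, -0.5, 0.0], [2.0, 0.5, 0.0]]])   # (1, 2, 3)

image = render_1d(rays, segments)   # (n_pixels,) of 0.0 / 1.0
print(image)   # the three central rays hit the segment; the rest miss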
Capstone Connection
Batched operations in attention:
# Attention computes ALL query-key pairs at once
Q = model.W_Q(residual) # (batch, seq, d_model) → (batch, seq, n_heads, d_head)
K = model.W_K(residual) # Same shape
# Compute all attention scores simultaneously
# This is like "does each query intersect each key?"
scores = einops.einsum(
    Q, K,
    "batch seq_q heads d, batch seq_k heads d -> batch heads seq_q seq_k"
)
The attention pattern is a heatmap of "intersections" between queries and keys.
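Stripped of the model, the same einsum on random tensors (a sketch with made-up dimensions) shows the shape bookkeeping end to end:
import torch as t
import einops

batch, seq, n_heads, d_head = 2, 5, 4, 16
Q = t.randn(batch, seq, n_heads, d_head)
K = t.randn(batch, seq, n_heads, d_head)

scores = einops.einsum(
    Q, K,
    "batch seq_q heads d, batch seq_k heads d -> batch heads seq_q seq_k",
)
pattern = (scores / d_head ** 0.5).softmax(dim=-1)   # (batch, heads, seq_q, seq_k)
print(pattern.shape)   # torch.Size([2, 4, 5, 5])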
🎓 Tyla's Exercise
- In `intersect_all`, what is the memory complexity? If n_rays = 1000 and n_segments = 1000, how many bytes for the intermediate tensors?
- Why do we use `[:, None, :, :]` instead of `unsqueeze`? Are they equivalent?
- Prove that `(a & b).any()` equals `a.any() and b.any()`, or find a counterexample.
💻 Aaliyah's Exercise
Implement a full 2D renderer:
def make_rays_2d(
    num_pixels_y: int,
    num_pixels_z: int,
    y_limit: float,
    z_limit: float
) -> t.Tensor:
    """
    Returns: (num_pixels_y * num_pixels_z, 2, 3)
    Rays through a 2D grid of pixels.

    Hint: Use torch.meshgrid or einops to create the grid.
    """
    pass
def render_2d(rays, triangles) -> t.Tensor:
    """
    rays: (n_pixels, 2, 3)
    triangles: (n_triangles, 3, 3) - each triangle has 3 vertices

    Returns: (height, width) image
    """
    pass
📚 Maneesha's Reflection
"Batching" appears everywhere: batch normalization, mini-batch gradient descent, batched matrix operations. What's the common thread?
The trade-off with batching is memory vs. speed. When might you intentionally NOT batch operations?
How would you explain broadcasting to someone who finds it unintuitive? What metaphor would you use?