Ray Tracing: 1D Image Rendering
Ray tracing teaches you to think in batched operations, the core skill for writing efficient PyTorch code.
You'll build a simple graphics renderer, starting with the basics and working up to rendering a 3D Pikachu.
Why Ray Tracing?
This isn't about graphics. It's about:
- Batched operations: Processing many rays simultaneously
- Linear algebra: Solving systems of equations with tensors
- Broadcasting: Making dimensions work together
- Debugging: Finding errors in tensor operations
These exact skills transfer directly to transformers and interpretability work.
The Setup
Our renderer has three components:
- Camera: A point at the origin (0, 0, 0)
- Screen: A plane at x=1
- Objects: Line segments (2D) or triangles (3D)
A ray goes from the camera through a screen pixel. If it hits an object, the pixel lights up.
```
Camera           Screen         Object
   O ──────────→   •   ───────→   ═══
(0,0,0)           x=1
```
Parametric Rays
A ray from origin O in direction D can be written as:
$$R(u) = O + u \cdot D \quad \text{for } u \in [0, \infty)$$
At u=0, we're at the origin. As u increases, we move along the ray.
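To make the formula concrete, here is a small sketch that evaluates R(u) at several values of u at once by broadcasting. The direction (1, 0.5, 0) is an arbitrary choice for illustration. Because its x-component is 1, the point at u = 1 lies on the screen plane x = 1.

```python
import torch as t

# Ray origin O (the camera) and an arbitrary example direction D
O = t.tensor([0.0, 0.0, 0.0])
D = t.tensor([1.0, 0.5, 0.0])

# Several parameter values, shaped (4, 1) so they broadcast against D's shape (3,)
u = t.tensor([0.0, 0.5, 1.0, 2.0]).unsqueeze(-1)

# R(u) = O + u * D for every u at once: result has shape (4, 3)
points = O + u * D
print(points)
# Row 0 (u = 0) is the origin; row 2 (u = 1) lies on the screen plane x = 1
```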
```python
import torch as t

def make_rays_1d(num_pixels: int, y_limit: float) -> t.Tensor:
    """
    Create rays from the origin through a 1D screen.

    num_pixels: Number of pixels (and rays)
    y_limit: At x=1, rays span from -y_limit to +y_limit

    Returns: shape (num_pixels, 2, 3)
        - 2 vectors per ray: [origin, direction]
        - 3 coordinates: [x, y, z]
    """
    rays = t.zeros((num_pixels, 2, 3), dtype=t.float32)
    # All rays start at the origin (already zeros)
    # Directions have x = 1, so each ray reaches the screen at u = 1; only y varies
    rays[:, 1, 0] = 1  # x = 1
    rays[:, 1, 1] = t.linspace(-y_limit, y_limit, num_pixels)
    # z = 0 (already zeros)
    return rays
```
Key insight: We create ALL rays at once using linspace. No loops.
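For contrast, the sketch below builds the same rays with an explicit Python loop and checks that they match the `linspace` version. The helper `make_rays_1d_loop` is hypothetical, written only for this comparison; the vectorized function is repeated so the block runs on its own.

```python
import torch as t

def make_rays_1d_loop(num_pixels: int, y_limit: float) -> t.Tensor:
    """Hypothetical loop-based version, for comparison only. Assumes num_pixels >= 2."""
    rays = t.zeros((num_pixels, 2, 3), dtype=t.float32)
    step = 2 * y_limit / (num_pixels - 1)  # spacing between neighbouring pixels
    for i in range(num_pixels):
        rays[i, 1, 0] = 1.0                  # direction reaches x = 1
        rays[i, 1, 1] = -y_limit + i * step  # y-coordinate of pixel i
    return rays

def make_rays_1d(num_pixels: int, y_limit: float) -> t.Tensor:
    """Vectorized version, as in the text above."""
    rays = t.zeros((num_pixels, 2, 3), dtype=t.float32)
    rays[:, 1, 0] = 1
    rays[:, 1, 1] = t.linspace(-y_limit, y_limit, num_pixels)
    return rays

loop_rays = make_rays_1d_loop(9, 1.0)
vec_rays = make_rays_1d(9, 1.0)
# Compare with allclose, since the two ways of computing y may differ in the last bit
print(t.allclose(loop_rays, vec_rays))  # True
```

The loop does one scalar assignment per pixel from Python; the vectorized version hands the whole column to a single tensor operation, which is what scales to thousands of rays.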
Ray-Segment Intersection
Given:
- Ray: origin O, direction D
- Segment: endpoints L₁, L₂
The ray hits the segment if there exist u ≥ 0 and v ∈ [0, 1] such that:
$$O + uD = L_1 + v(L_2 - L_1)$$
Rearranging into matrix form:
$$\begin{pmatrix} D_x & (L_1 - L_2)_x \\ D_y & (L_1 - L_2)_y \end{pmatrix} \begin{pmatrix} u \\ v \end{pmatrix} = \begin{pmatrix} (L_1 - O)_x \\ (L_1 - O)_y \end{pmatrix}$$
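The rearrangement into matrix form skips a step. Moving the terms containing the unknowns u and v to the left and the known points to the right gives:

$$\begin{aligned} O + uD &= L_1 + v(L_2 - L_1) \\ uD - v(L_2 - L_1) &= L_1 - O \\ uD + v(L_1 - L_2) &= L_1 - O \end{aligned}$$

Reading off the last line: the coefficient of u is the column D, the coefficient of v is the column L₁ − L₂, and the right-hand side is L₁ − O. Keeping only the x and y components gives the 2×2 system, which is how `A` and `b` are built in `intersect_ray_1d`.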
```python
def intersect_ray_1d(ray: t.Tensor, segment: t.Tensor) -> bool:
    """
    Check if a ray intersects a segment.

    ray: shape (2, 3) - [origin, direction]
    segment: shape (2, 3) - [L1, L2]

    Returns: True if intersection exists
    """
    # Use only x and y (ignore z for 2D)
    ray_2d = ray[:, :2]
    seg_2d = segment[:, :2]
    O, D = ray_2d
    L1, L2 = seg_2d

    # Build matrix equation: A @ [u, v] = b
    A = t.stack([D, L1 - L2], dim=-1)  # Shape: (2, 2)
    b = L1 - O  # Shape: (2,)

    try:
        sol = t.linalg.solve(A, b)
    except RuntimeError:
        # Parallel lines - no intersection
        return False

    u, v = sol[0].item(), sol[1].item()
    return u >= 0 and 0 <= v <= 1
```
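For a 2×2 system you do not strictly need `torch.linalg.solve`. The sketch below swaps in Cramer's rule, which turns the parallel case into an explicit determinant check instead of an exception. This is an alternative technique, not the method used above; the name `intersect_ray_1d_cramer` and the tolerance `eps` are assumptions chosen for illustration.

```python
import torch as t

def intersect_ray_1d_cramer(ray: t.Tensor, segment: t.Tensor, eps: float = 1e-8) -> bool:
    """
    Hypothetical alternative to intersect_ray_1d using Cramer's rule.
    ray: shape (2, 3) - [origin, direction]
    segment: shape (2, 3) - [L1, L2]
    eps: assumed tolerance below which the system is treated as singular
    """
    O, D = ray[:, :2]          # origin and direction, x/y only
    L1, L2 = segment[:, :2]    # segment endpoints, x/y only
    col_u, col_v = D, L1 - L2  # the two columns of A
    b = L1 - O                 # right-hand side

    def det2(a: t.Tensor, c: t.Tensor) -> float:
        # Determinant of the 2x2 matrix whose columns are a and c
        return (a[0] * c[1] - a[1] * c[0]).item()

    det_A = det2(col_u, col_v)
    if abs(det_A) < eps:
        return False  # (near-)parallel: treat as no intersection
    # Cramer's rule: replace the unknown's column of A with b
    u = det2(b, col_v) / det_A
    v = det2(col_u, b) / det_A
    return u >= 0 and 0 <= v <= 1


ray = t.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])            # points along +x
hit_seg = t.tensor([[1.0, -1.0, 0.0], [1.0, 1.0, 0.0]])       # crosses the ray at (1, 0)
behind_seg = t.tensor([[-1.0, -1.0, 0.0], [-1.0, 1.0, 0.0]])  # lines meet behind the camera (u = -1)
parallel_seg = t.tensor([[1.0, 1.0, 0.0], [2.0, 1.0, 0.0]])   # same segment as the parallel example

print(intersect_ray_1d_cramer(ray, hit_seg))       # True
print(intersect_ray_1d_cramer(ray, behind_seg))    # False
print(intersect_ray_1d_cramer(ray, parallel_seg))  # False
```

The trade-off: the determinant check also rejects nearly parallel pairs, but the tolerance has to be picked by hand.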
When Does solve Fail?
torch.linalg.solve fails when the matrix is singular (determinant = 0).
This happens when the ray and segment are parallel:
```python
# Parallel case: ray direction is parallel to segment
ray = t.tensor([
    [0.0, 0.0, 0.0],  # Origin
    [1.0, 0.0, 0.0],  # Direction: along x-axis
])
segment = t.tensor([
    [1.0, 1.0, 0.0],  # L1
    [2.0, 1.0, 0.0],  # L2: also along x-axis
])
# D = [1, 0], L1 - L2 = [-1, 0]
# Matrix = [[1, -1], [0, 0]] - determinant = 0!
```
Always wrap solve in try/except when inputs might be parallel.
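To see the failure concretely, the sketch below builds the matrix from the parallel example and calls `solve` on it directly; `b` is L₁ − O = [1, 1] from that example. One caveat: the try/except only catches the exactly singular case. A nearly parallel ray and segment give a tiny but nonzero determinant, so `solve` returns very large values instead of raising, and a determinant tolerance check is a common complement.

```python
import torch as t

# Matrix from the parallel example: columns D = [1, 0] and L1 - L2 = [-1, 0]
A = t.tensor([[1.0, -1.0],
              [0.0,  0.0]])
b = t.tensor([1.0, 1.0])  # L1 - O

print(t.linalg.det(A))  # zero: the matrix is singular

try:
    t.linalg.solve(A, b)
    raised = False
except RuntimeError as err:
    # Recent PyTorch versions raise torch.linalg.LinAlgError, a subclass of RuntimeError
    raised = True
    print(type(err).__name__)
print(raised)  # True
```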
Capstone Connection
Why ray-segment intersection matters for sycophancy:
The attention mechanism is essentially asking: "Does this query vector 'intersect' with this key vector?"
```python
import torch as t

# Attention score = query · key
# Geometrically: how aligned are these vectors?
query = t.randn(64)  # What this position is looking for
key = t.randn(64)    # What this position contains
score = query @ key  # Large and positive if aligned, near zero if orthogonal, negative if opposed
```
Ray tracing builds geometric intuition for these operations.
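A single matmul can score one query against many keys at once, the same batched pattern used for rays. The sketch below uses hand-picked 2D vectors (an assumption for readability; the example above uses 64 dimensions) so the aligned, orthogonal, and opposed cases are easy to read off. Standard scaled dot-product attention also divides the scores by the square root of the key dimension before the softmax; that step is omitted here.

```python
import torch as t

query = t.tensor([1.0, 0.0])

# Three keys stacked into one (3, 2) tensor
keys = t.tensor([
    [1.0, 0.0],   # same direction as the query
    [0.0, 1.0],   # orthogonal to the query
    [-1.0, 0.0],  # opposite direction
])

# One matmul scores every key at once: shape (3,), values 1, 0, -1
scores = keys @ query
print(scores)

# Softmax turns scores into attention weights that sum to 1
weights = t.softmax(scores, dim=-1)
print(weights)  # largest weight on the aligned key, smallest on the opposed one
```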
🎓 Tyla's Exercise
- Prove that if det(A) = 0, the ray and segment are parallel.
- In the intersection formula, what does u < 0 mean geometrically?
- What happens if v < 0 or v > 1? Draw a diagram.
After completing: The constraints on u and v define when an intersection "counts." How is this similar to masking in attention?
💻 Aaliyah's Exercise
Implement the batched version:
```python
def intersect_rays_1d(
    rays: t.Tensor,     # Shape: (n_rays, 2, 3)
    segment: t.Tensor,  # Shape: (2, 3)
) -> t.Tensor:
    """
    Returns: shape (n_rays,) boolean tensor,
    True for each ray that intersects the segment.

    Hint: Use broadcasting to solve all rays at once.
    torch.linalg.solve can take batched inputs!
    """
    pass
```
Test with:
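```python
rays = make_rays_1d(9, 1.0)
# At x = 0.5 the rays reach y in [-0.5, 0.5]; this segment only covers |y| <= 0.3
segment = t.tensor([[0.5, 0.3, 0.0], [0.5, -0.3, 0.0]])
result = intersect_rays_1d(rays, segment)
print(result)  # Expected: the middle five rays hit; the two outermost rays on each side miss
```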
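The docstring's hint about batched inputs can be checked in isolation. The sketch below deliberately uses arbitrary random, well-conditioned matrices rather than ray data, so it shows the shapes without solving the exercise: `A` of shape (n, 2, 2) and `b` of shape (n, 2) give a solution of shape (n, 2).

```python
import torch as t

t.manual_seed(0)
n = 9

# n random 2x2 systems: small perturbations of the identity, so each is well-conditioned
A = t.eye(2) + 0.1 * t.randn(n, 2, 2)  # shape (n, 2, 2)
b = t.randn(n, 2)                       # shape (n, 2): one right-hand side per system

# A single call solves all n systems; x[i] satisfies A[i] @ x[i] == b[i]
x = t.linalg.solve(A, b)
print(x.shape)  # torch.Size([9, 2])

# Verify all systems at once: (n, 2, 2) @ (n, 2, 1) -> (n, 2, 1)
residual = (A @ x.unsqueeze(-1)).squeeze(-1) - b
print(residual.abs().max())  # close to zero
```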
📚 Maneesha's Reflection
- Why do we parameterize rays as O + u*D instead of just storing two points?
- The same ray-object intersection logic underlies both computer graphics and attention mechanisms. What's the common abstraction?
- If you were teaching tensor operations to a visual learner, how would ray tracing help or hurt their understanding?