Ray Tracing: Triangles & 3D Rendering
Every 3D mesh is made of triangles. Your Pikachu will have 412 of them.
This chapter extends ray tracing to 3D and renders actual objects.
Why Triangles?
Triangles are the universal primitive for 3D graphics because:
- Always planar: Any 3 non-collinear points define a unique plane
- Simple intersection: Well-defined inside/outside
- Easy interpolation: Barycentric coordinates
- Universal approximation: Any surface ≈ enough triangles
A complex surface:        Approximated by triangles:
      ~~~                     /\  /\  /\
     ~~~~~                   /  \/  \/  \
    ~~~~~~~                 /____________\
Parametric Triangles
A triangle with vertices A, B, C can be written as:
$$P(s, t) = A + s(B - A) + t(C - A)$$
where $s \geq 0$, $t \geq 0$, and $s + t \leq 1$.
The constraints keep the point inside the triangle (boundary included). Some special cases:
- $s = 0, t = 0$ → point A
- $s = 1, t = 0$ → point B
- $s = 0, t = 1$ → point C
- $s + t = 1$ → edge BC
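The parametric form can be checked with a few concrete numbers. The following is a minimal, dependency-free sketch in plain Python (tuples rather than the tensors used later in this chapter); the helper names `triangle_point`, `barycentric_weights`, and `in_triangle` are illustrative and not part of the course code. It also shows that the same weights $(1 - s - t,\ s,\ t)$ interpolate any per-vertex value, which is the barycentric interpolation mentioned earlier.

```python
def triangle_point(A, B, C, s, t_val):
    """Evaluate P(s, t) = A + s(B - A) + t(C - A) componentwise."""
    return tuple(a + s * (b - a) + t_val * (c - a) for a, b, c in zip(A, B, C))

def barycentric_weights(s, t_val):
    """Weights on (A, B, C); they sum to 1 by construction."""
    return (1 - s - t_val, s, t_val)

def in_triangle(s, t_val):
    """The three constraints from the text."""
    return s >= 0 and t_val >= 0 and s + t_val <= 1

A, B, C = (0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (0.0, 2.0, 0.0)
corner_b = triangle_point(A, B, C, 1.0, 0.0)  # s = 1, t = 0: lands on B
mid_bc = triangle_point(A, B, C, 0.5, 0.5)    # s + t = 1: midpoint of edge BC
outside = in_triangle(0.8, 0.8)               # s + t > 1: not in the triangle

# Interpolate a hypothetical per-vertex brightness with the same weights
vertex_brightness = (0.0, 1.0, 0.5)
w = barycentric_weights(0.25, 0.25)
brightness = sum(wi * vi for wi, vi in zip(w, vertex_brightness))
```

Here `w` is `(0.5, 0.25, 0.25)`, so the interpolated brightness is a weighted average that leans toward vertex A.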
Ray-Triangle Intersection
Combining the ray equation $R(u) = O + uD$ with the triangle equation:
$$O + uD = A + s(B - A) + t(C - A)$$
Rearranging:
$$\begin{pmatrix} -D & B-A & C-A \end{pmatrix} \begin{pmatrix} u \\ s \\ t \end{pmatrix} = O - A$$
This is a 3×3 system! We solve it and check:
- $u \geq 0$ (ray goes forward)
- $s \geq 0$ (inside triangle)
- $t \geq 0$ (inside triangle)
- $s + t \leq 1$ (inside triangle)
import torch as t

def intersect_ray_triangle(
    ray: t.Tensor,       # (2, 3): [origin, direction]
    triangle: t.Tensor,  # (3, 3): [A, B, C]
) -> tuple[bool, float]:
    """
    Returns: (hit, distance)
    - hit: True if ray intersects triangle
    - distance: u value along the ray if hit, else inf
      (u equals the Euclidean distance only when D has unit length)
    """
    O, D = ray
    A, B, C = triangle
    # Build matrix: columns are [-D, B-A, C-A]
    mat = t.stack([-D, B - A, C - A], dim=-1)  # (3, 3)
    vec = O - A  # (3,)
    try:
        sol = t.linalg.solve(mat, vec)
    except RuntimeError:
        # No unique solution (singular matrix): treat as a miss
        return False, float('inf')
    u, s, t_val = sol.tolist()
    inside = (u >= 0) and (s >= 0) and (t_val >= 0) and (s + t_val <= 1)
    return inside, (u if inside else float('inf'))
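To see the 3×3 system with concrete numbers, here is a dependency-free sketch that solves it with Cramer's rule in plain Python. This is a swapped-in technique for illustration, not the `t.linalg.solve` call used in this chapter, and the helper names (`det3`, `ray_triangle_cramer`) and the `eps` threshold are assumptions.

```python
import math

def cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def sub(a, b):
    return tuple(x - y for x, y in zip(a, b))

def det3(c0, c1, c2):
    """Determinant of the 3x3 matrix whose columns are c0, c1, c2."""
    return dot(c0, cross(c1, c2))

def ray_triangle_cramer(O, D, A, B, C, eps=1e-12):
    """Solve [-D, B-A, C-A] (u, s, t)^T = O - A by Cramer's rule."""
    c0 = tuple(-d for d in D)
    c1, c2, rhs = sub(B, A), sub(C, A), sub(O, A)
    det = det3(c0, c1, c2)
    if abs(det) < eps:
        # No unique solution: report a miss
        return False, math.inf
    u = det3(rhs, c1, c2) / det
    s = det3(c0, rhs, c2) / det
    t_val = det3(c0, c1, rhs) / det
    hit = u >= 0 and s >= 0 and t_val >= 0 and s + t_val <= 1
    return hit, (u if hit else math.inf)

# Triangle in the plane x = 1; rays start at x = -1
A, B, C = (1.0, 0.0, 0.0), (1.0, 1.0, 0.0), (1.0, 0.0, 1.0)
hit, dist = ray_triangle_cramer((-1.0, 0.25, 0.25), (1.0, 0.0, 0.0), A, B, C)
miss, _ = ray_triangle_cramer((-1.0, 0.9, 0.9), (1.0, 0.0, 0.0), A, B, C)
parallel, _ = ray_triangle_cramer((-1.0, 0.25, 0.25), (0.0, 1.0, 0.0), A, B, C)
```

The first ray reaches the plane at $u = 2$ with $s = t = 0.25$, so it hits. The second reaches the plane at $s = t = 0.9$, which violates $s + t \leq 1$. The third travels parallel to the plane, so the determinant is zero and no solution exists.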
Batched Triangle Intersection
For a full scene, we intersect each ray with each triangle:
def intersect_rays_triangles(
    rays: t.Tensor,       # (n_rays, 2, 3)
    triangles: t.Tensor,  # (n_triangles, 3, 3)
) -> tuple[t.Tensor, t.Tensor]:
    """
    Returns:
    - hits: (n_rays, n_triangles) bool
    - distances: (n_rays, n_triangles) float
    """
    n_rays = rays.shape[0]
    n_tri = triangles.shape[0]
    # Split into components
    O = rays[:, 0, :]  # (n_rays, 3)
    D = rays[:, 1, :]  # (n_rays, 3)
    A = triangles[:, 0, :]  # (n_tri, 3)
    B = triangles[:, 1, :]  # (n_tri, 3)
    C = triangles[:, 2, :]  # (n_tri, 3)
    # Add singleton dimensions, then expand everything to (n_rays, n_tri, 3).
    # t.stack does not broadcast, so the shapes must match exactly.
    shape = (n_rays, n_tri, 3)
    O = O[:, None, :].expand(shape)
    D = D[:, None, :].expand(shape)
    A = A[None, :, :].expand(shape)
    B = B[None, :, :].expand(shape)
    C = C[None, :, :].expand(shape)
    # Build batched matrices: (n_rays, n_tri, 3, 3)
    col1 = -D     # (n_rays, n_tri, 3)
    col2 = B - A  # (n_rays, n_tri, 3)
    col3 = C - A  # (n_rays, n_tri, 3)
    mat = t.stack([col1, col2, col3], dim=-1)
    vec = O - A  # (n_rays, n_tri, 3)
    # Solve all systems (this raises if any matrix in the batch is singular)
    sol = t.linalg.solve(mat, vec)  # (n_rays, n_tri, 3)
    u = sol[..., 0]
    s = sol[..., 1]
    t_val = sol[..., 2]
    hits = (u >= 0) & (s >= 0) & (t_val >= 0) & (s + t_val <= 1)
    # Misses get an infinite distance so they never win a min() later
    distances = t.where(hits, u, t.full_like(u, float('inf')))
    return hits, distances
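Solving every ray against every triangle at once trades memory for speed, so it is worth estimating the size of the batched system before rendering. The arithmetic below is a rough sketch that assumes 32-bit floats and counts only the stacked matrices and right-hand sides (intermediate tensors and the solver's workspace are ignored); `batched_system_bytes` is an illustrative helper, not part of the course code.

```python
def batched_system_bytes(n_rays, n_triangles, bytes_per_float=4):
    """Bytes held by the (n_rays, n_tri, 3, 3) matrices plus the (n_rays, n_tri, 3) vectors."""
    n_systems = n_rays * n_triangles
    matrix_bytes = n_systems * 9 * bytes_per_float
    rhs_bytes = n_systems * 3 * bytes_per_float
    return matrix_bytes + rhs_bytes

# The scene used in this chapter: a 120 x 120 image and a 412-triangle mesh
n_systems = 120 * 120 * 412
total = batched_system_bytes(120 * 120, 412)
print(f"{n_systems:,} systems, {total / 2**20:.1f} MiB")
```

That is a few hundred MiB for a small image, and the cost grows linearly in both the pixel count and the triangle count. If memory becomes a problem, one option is to process the rays in chunks and concatenate the results.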
Rendering the Image
For each pixel:
- Cast a ray through it
- Find the closest triangle hit
- Color based on the result
def render_mesh(
    rays: t.Tensor,       # (height, width, 2, 3)
    triangles: t.Tensor,  # (n_triangles, 3, 3)
) -> t.Tensor:
    """Returns (height, width) grayscale image."""
    h, w = rays.shape[:2]
    rays_flat = rays.reshape(-1, 2, 3)  # (h*w, 2, 3)
    hits, distances = intersect_rays_triangles(rays_flat, triangles)
    # For each ray, find minimum distance (closest hit)
    min_dist = distances.min(dim=1).values  # (h*w,)
    # Pixel is lit if any triangle was hit
    image = (min_dist < float('inf')).float().reshape(h, w)
    return image
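The "closest hit" step is easiest to follow on a tiny distance grid. The sketch below uses plain Python lists in place of the `(n_rays, n_triangles)` tensor, and the helper `closest_hits` is hypothetical. It returns the index of the nearest triangle as well as its distance, which is what you would need once pixels are shaded per triangle rather than simply lit or unlit. It assumes at least one triangle per row.

```python
import math

def closest_hits(distances):
    """For each ray (row), return (triangle_index, distance) of the nearest hit,
    or (None, inf) if the ray misses every triangle."""
    results = []
    for row in distances:
        idx = min(range(len(row)), key=row.__getitem__)
        if math.isfinite(row[idx]):
            results.append((idx, row[idx]))
        else:
            results.append((None, math.inf))
    return results

inf = math.inf
distances = [
    [inf, 2.5, 1.0],  # ray 0 hits triangles 1 and 2; triangle 2 is closer
    [inf, inf, inf],  # ray 1 misses everything
    [4.0, inf, 4.0],  # ray 2: a tie, the first index wins
]
nearest = closest_hits(distances)
# Same rule as render_mesh: a pixel is lit if any triangle was hit
pixels = [1.0 if idx is not None else 0.0 for idx, _ in nearest]
```

Storing misses as `inf` is what makes a plain minimum work: a miss can never be closer than a real hit.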
Loading and Rendering Pikachu
import matplotlib.pyplot as plt
import torch as t

# Load mesh (412 triangles)
with open("pikachu.pt", "rb") as f:
    triangles = t.load(f)
print(f"Loaded {triangles.shape[0]} triangles")
# Create rays (make_rays_2d is not defined in this chapter;
# it is assumed to be available from earlier material)
rays = make_rays_2d(
    num_pixels_y=120,
    num_pixels_z=120,
    y_limit=1.0,
    z_limit=1.0,
)
# Render!
image = render_mesh(rays.reshape(120, 120, 2, 3), triangles)
plt.imshow(image.numpy(), cmap='gray')
plt.title("Pikachu!")
plt.show()
Capstone Connection
Triangle meshes and embedding spaces:
In transformer interpretability, we study high-dimensional embedding spaces. Just as a Pikachu can be approximated by triangles, the structure of embedding space can be approximated by simpler geometric shapes.
When we look for "sycophancy features" in a model:
- The feature direction is like a triangle's normal vector
- Activation patterns are like ray intersections
- The embedding space geometry tells us what the model "knows"
🎓 Tyla's Exercise
Prove that $s \geq 0$, $t \geq 0$, $s + t \leq 1$ defines exactly the set of points in the triangle (its interior together with its boundary).
What happens if the triangle is degenerate (all three vertices on a line)? How would you detect and handle this?
In the batched version, why do we use t.linalg.solve instead of computing the matrix inverse?
💻 Aaliyah's Exercise
Add depth-based shading:
def render_with_depth(
    rays: t.Tensor,
    triangles: t.Tensor
) -> t.Tensor:
    """
    Returns (height, width) image where:
    - Closer objects are brighter
    - Objects at max_dist or beyond are black
    Normalize distances to [0, 1] range.
    """
    pass
Bonus: Add surface normals for proper lighting (the normal of a triangle is the cross product of two edges).
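As a starting point for the bonus, the cross-product formula can be sketched in plain Python. The helper `triangle_normal` is illustrative only; it computes the normal and leaves the lighting itself as the exercise. Note that the sign of the normal depends on the vertex order, so swapping two vertices flips it.

```python
import math

def triangle_normal(A, B, C):
    """Unit normal from the cross product of edges (B - A) and (C - A)."""
    e1 = tuple(b - a for a, b in zip(A, B))
    e2 = tuple(c - a for a, c in zip(A, C))
    n = (e1[1] * e2[2] - e1[2] * e2[1],
         e1[2] * e2[0] - e1[0] * e2[2],
         e1[0] * e2[1] - e1[1] * e2[0])
    length = math.sqrt(sum(x * x for x in n))
    if length == 0.0:
        raise ValueError("zero-area triangle has no normal")
    return tuple(x / length for x in n)

# A triangle lying in the z = 0 plane
normal = triangle_normal((0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (0.0, 3.0, 0.0))
# Same triangle with B and C swapped: the normal points the other way
flipped = triangle_normal((0.0, 0.0, 0.0), (0.0, 3.0, 0.0), (2.0, 0.0, 0.0))
```

The unnormalized cross product here is `(0, 0, 6)`, whose length is twice the triangle's area, so normalizing matters before using the normal in any lighting calculation.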
📚 Maneesha's Reflection
3D graphics evolved from "render everything" to "render only what's visible" (culling, occlusion). How does this parallel efficiency improvements in attention (sparse attention, local attention)?
The move from triangles to neural radiance fields (NeRFs) represents a shift from explicit to implicit representations. What other domains have seen similar shifts?
If you were designing a curriculum that used ray tracing to teach ML concepts, what sequence would you use?