Hardware Ray Tracing on Apple Silicon
Leveraging M3/M4 Dedicated Ray Tracing Accelerators

Hardware Ray Tracing on Apple Silicon

Leveraging M3/M4 dedicated ray tracing accelerators

Leveraging M3/M4 dedicated ray tracing accelerators

The Ray Tracing Architecture

Apple introduced hardware ray tracing acceleration with M3. Metal's ray tracing API predates that hardware and previously ran traversal in software on the GPU's shader cores; hardware RT adds dedicated silicon for BVH (bounding volume hierarchy) traversal and ray-triangle intersection.

Acceleration Structure Fundamentals

Building the BVH

// Describe the indexed triangle mesh that the primitive acceleration
// structure is built over.
let geometryDesc = MTLAccelerationStructureTriangleGeometryDescriptor()
geometryDesc.vertexBuffer = vertexBuffer
// NOTE(review): this stride assumes vertexBuffer holds SIMD3<Float> positions
// laid out with Swift's SIMD3 stride — confirm against the code that fills it.
geometryDesc.vertexStride = MemoryLayout<SIMD3<Float>>.stride
geometryDesc.indexBuffer = indexBuffer
geometryDesc.indexType = .uint32
geometryDesc.triangleCount = triangleCount

// A primitive acceleration structure descriptor wraps one or more geometry
// descriptors; here there is exactly one.
let primitiveDesc = MTLPrimitiveAccelerationStructureDescriptor()
primitiveDesc.geometryDescriptors = [geometryDesc]

// Ask the device how large the finished structure and the temporary build
// scratch space need to be for this descriptor.
let sizes = device.accelerationStructureSizes(descriptor: primitiveDesc)

// Allocate the structure and the scratch space. The force-unwraps trap if
// either allocation returns nil.
let accelerationStructure = device.makeAccelerationStructure(size: sizes.accelerationStructureSize)!
// The scratch buffer is only used by the GPU while the build executes and is
// never read by the CPU, so place it in private (GPU-only) storage instead of
// the default shared storage.
let scratchBuffer = device.makeBuffer(length: sizes.buildScratchBufferSize,
                                      options: .storageModePrivate)!

// Encode the build. This snippet only records the command; nothing here
// commits commandBuffer or waits for it to complete.
let encoder = commandBuffer.makeAccelerationStructureCommandEncoder()!
encoder.build(accelerationStructure: accelerationStructure,
              descriptor: primitiveDesc,
              scratchBuffer: scratchBuffer,
              scratchBufferOffset: 0)
encoder.endEncoding()

Instance Acceleration Structures

For scenes with multiple objects:

// Produce one instance descriptor per scene object. Each descriptor's
// accelerationStructureIndex is the object's position in sceneObjects.
var instances = sceneObjects.enumerated().map { (slot, object) -> MTLAccelerationStructureInstanceDescriptor in
    var descriptor = MTLAccelerationStructureInstanceDescriptor()
    descriptor.accelerationStructureIndex = UInt32(slot)
    descriptor.transformationMatrix = object.transform.metalMatrix
    descriptor.mask = object.visibilityMask
    descriptor.options = object.opaque ? .opaque : .nonOpaque
    return descriptor
}

// Copy the descriptor array into a Metal buffer (traps if allocation fails).
let instanceBuffer = device.makeBuffer(
    bytes: &instances,
    length: instances.count * MemoryLayout<MTLAccelerationStructureInstanceDescriptor>.stride
)!

// Configure the instance acceleration structure descriptor. Only the
// descriptor is set up here; no size query or build is encoded in this snippet.
let instanceDesc = MTLInstanceAccelerationStructureDescriptor()
instanceDesc.instanceDescriptorBuffer = instanceBuffer
instanceDesc.instanceCount = instances.count
instanceDesc.instancedAccelerationStructures = primitiveStructures

Ray Tracing in Shaders

Basic Ray Queries

#include <metal_raytracing>
using namespace metal::raytracing;

// Traces one ray per thread against the instanced scene and records the
// closest triangle hit in hits[tid]. On a miss only hits[tid].distance is
// written (set to INFINITY); the other fields of the record are left untouched.
kernel void trace_primary_rays(
    instance_acceleration_structure scene [[buffer(0)]],
    device Ray* rays                      [[buffer(1)]],
    device Intersection* hits             [[buffer(2)]],
    uint tid [[thread_position_in_grid]]
) {
    Ray primaryRay = rays[tid];

    // Closest-hit intersector, told up front that the scene holds triangles only.
    intersector<triangle_data, instancing> tracer;
    tracer.accept_any_intersection(false);
    tracer.assume_geometry_type(geometry_type::triangle);

    intersection_result<triangle_data, instancing> closest = tracer.intersect(primaryRay, scene);

    // Miss: flag it with an infinite distance and stop.
    if (closest.type != intersection_type::triangle) {
        hits[tid].distance = INFINITY;
        return;
    }

    // Hit: copy out the fields of the intersection result.
    hits[tid].distance = closest.distance;
    hits[tid].primitiveIndex = closest.primitive_id;
    hits[tid].instanceIndex = closest.instance_id;
    hits[tid].barycentrics = closest.triangle_barycentric_coord;
}

Custom Intersection Functions

For non-triangle geometry:

// Bounding-box intersection function that tests the ray against the sphere
// stored at spheres[primitiveIndex] by solving a*t^2 + b*t + c = 0.
// Accepts the nearer root when it lies in [minDistance, maxDistance],
// otherwise tries the farther root; rejects when neither is in range.
[[intersection(bounding_box, triangle_data)]]
BoundingBoxIntersection sphere_intersection(
    float3 origin               [[origin]],
    float3 direction            [[direction]],
    float minDistance           [[min_distance]],
    float maxDistance           [[max_distance]],
    uint primitiveIndex         [[primitive_id]],
    device const Sphere* spheres [[buffer(0)]]
) {
    Sphere target = spheres[primitiveIndex];
    float3 centerToOrigin = origin - target.center;

    // Quadratic coefficients (direction is not assumed to be normalized).
    float a = dot(direction, direction);
    float b = 2.0 * dot(centerToOrigin, direction);
    float c = dot(centerToOrigin, centerToOrigin) - target.radius * target.radius;
    float discriminant = b * b - 4 * a * c;

    // Start from "rejected"; distance is only filled in on acceptance.
    BoundingBoxIntersection result;
    result.accept = false;

    // No real roots: the ray's line misses the sphere.
    if (discriminant < 0) {
        return result;
    }

    float root = sqrt(discriminant);
    float denominator = 2.0 * a;

    // Prefer the nearer root; fall back to the farther one when the nearer
    // root is outside the ray's valid interval.
    float t = (-b - root) / denominator;
    if (t < minDistance || t > maxDistance) {
        t = (-b + root) / denominator;
    }

    if (t >= minDistance && t <= maxDistance) {
        result.accept = true;
        result.distance = t;
    }

    return result;
}

Path Tracing Implementation

Complete Path Tracer Kernel

// Bundle of per-path values.
// NOTE(review): the path_trace kernel in this file does not reference this
// struct; it keeps radiance, throughput, ray and rngState as locals instead.
// Presumably this is meant for a variant that stores path state between
// dispatches — confirm before relying on it.
struct PathState {
    float3 radiance;    // presumably light accumulated along the path so far — confirm
    float3 throughput;  // presumably the running product of per-bounce weights — confirm
    Ray ray;            // presumably the ray to trace next — confirm
    uint depth;         // presumably the current bounce count — confirm
    uint rngState;      // presumably the per-path random number generator state — confirm
};

// Progressive path tracer: one thread per pixel.
//
// Each invocation traces one jittered camera path of up to params.maxBounces
// segments, adding direct lighting at every hit and the environment on a
// miss, then blends the sample into `output` as a running average over
// params.frameIndex + 1 frames.
//
// `output` is declared access::read_write because the kernel both reads the
// previous average and writes the new one; a write-only texture has no
// read() member. The host texture therefore needs both shader-read and
// shader-write usage.
kernel void path_trace(
    instance_acceleration_structure scene [[buffer(0)]],
    device const Material* materials      [[buffer(1)]],
    device const Light* lights            [[buffer(2)]],
    texture2d<float, access::read_write> output [[texture(0)]],
    constant SceneParams& params          [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]
) {
    // The dispatched grid may be rounded up past the image size; threads
    // outside the image must not trace or touch the texture.
    if (tid.x >= params.width || tid.y >= params.height) {
        return;
    }

    // Initialize RNG: seed from the pixel's linear index and the frame index
    // so every pixel gets a different sequence on every frame.
    uint rngState = hash(tid.x + tid.y * params.width + params.frameIndex * params.width * params.height);

    // Generate camera ray through a randomly jittered position inside the pixel.
    float2 pixel = float2(tid) + float2(rand(rngState), rand(rngState));
    float2 uv = pixel / float2(params.width, params.height);
    Ray ray = generateCameraRay(params.camera, uv);

    float3 radiance = float3(0.0);
    float3 throughput = float3(1.0);

    intersector<triangle_data, instancing> inter;
    inter.assume_geometry_type(geometry_type::triangle);

    for (uint bounce = 0; bounce < params.maxBounces; bounce++) {
        intersection_result<triangle_data, instancing> hit;
        hit = inter.intersect(ray, scene);

        if (hit.type != intersection_type::triangle) {
            // Sky contribution: the path left the scene.
            radiance += throughput * sampleEnvironment(ray.direction);
            break;
        }

        // Get hit information
        float3 position = ray.origin + ray.direction * hit.distance;
        uint materialIndex = getMaterialIndex(hit.instance_id, hit.primitive_id);
        Material mat = materials[materialIndex];
        float3 normal = interpolateNormal(hit);

        // Direct lighting (NEE)
        float3 directLight = evaluateDirectLighting(
            position, normal, mat, lights, params.lightCount,
            scene, inter, rngState
        );
        radiance += throughput * directLight;

        // Sample BSDF for next bounce
        float3 wo = -ray.direction;
        float3 wi;
        float pdf;
        float3 f = sampleBSDF(mat, normal, wo, wi, pdf, rngState);

        // A near-zero pdf would blow up the division below; end the path.
        if (pdf < 1e-6) break;

        throughput *= f * abs(dot(wi, normal)) / pdf;

        // Russian roulette: after a few bounces, terminate paths at random
        // with survival probability p and reweight the survivors by 1/p.
        if (bounce > 3) {
            float p = min(max3(throughput), 0.95);
            // `>=` rather than `>`: when p is 0 the path must always
            // terminate, otherwise a random draw of exactly 0 would survive
            // and divide the throughput by zero.
            if (rand(rngState) >= p) break;
            throughput /= p;
        }

        // Setup next ray, nudged off the surface to avoid re-hitting it.
        // NOTE(review): the offset is always along +normal; if sampleBSDF can
        // return directions on the opposite side (transmission), confirm this
        // does not cause self-intersection.
        ray.origin = position + normal * 1e-4;
        ray.direction = wi;
        ray.min_distance = 0;
        ray.max_distance = INFINITY;
    }

    // Accumulate a running average. On the first frame the texture holds no
    // valid history, and blending with an uninitialized NaN/Inf texel would
    // poison every later frame, so write the sample directly.
    float3 accumulated = radiance;
    if (params.frameIndex > 0) {
        float3 previous = output.read(tid).rgb;
        float blend = 1.0 / float(params.frameIndex + 1);
        accumulated = mix(previous, radiance, blend);
    }
    output.write(float4(accumulated, 1.0), tid);
}

Performance Characteristics

M3 vs M4 Ray Tracing Performance

| Operation | M3 Pro | M4 Pro | Improvement |
| BVH Build (1M tris) | 45ms | 32ms | 29% |
| Primary Rays (4K) | 2.1ms | 1.4ms | 33% |
| Path Trace (1spp) | 8.3ms | 5.6ms | 33% |
| Any-Hit Queries | 0.8ms | 0.5ms | 38% |

Optimization Guidelines

  1. Batch ray traces: Dispatch millions of rays together
  2. Use any-hit for shadows: inter.accept_any_intersection(true)
  3. Compact BVH periodically: Reduces memory, improves traversal
  4. Stream geometry updates: Use refitting for animated scenes