Leveraging M3/M4 Dedicated Ray Tracing Accelerators
Hardware Ray Tracing on Apple Silicon
Leveraging M3/M4 dedicated ray tracing accelerators
Leveraging M3/M4 dedicated ray tracing accelerators
The Ray Tracing Architecture
Apple introduced hardware ray tracing acceleration with the M3 family. On earlier GPUs, Metal's ray tracing API runs in software on the shader cores; hardware RT instead provides dedicated silicon for BVH traversal and ray-triangle intersection.
Acceleration Structure Fundamentals
Building the BVH
// Builds a primitive (bottom-level) acceleration structure for one indexed
// triangle mesh and encodes the build into `commandBuffer`.
// Expects `device`, `commandBuffer`, `vertexBuffer`, `indexBuffer` and
// `triangleCount` to be in scope.

// Create geometry descriptor
let geometryDesc = MTLAccelerationStructureTriangleGeometryDescriptor()
geometryDesc.vertexBuffer = vertexBuffer
// SIMD3<Float> has a 16-byte stride — assumes vertexBuffer stores padded
// SIMD3<Float> positions; TODO confirm against the mesh loader.
geometryDesc.vertexStride = MemoryLayout<SIMD3<Float>>.stride
geometryDesc.indexBuffer = indexBuffer
geometryDesc.indexType = .uint32
geometryDesc.triangleCount = triangleCount

// Create primitive acceleration structure descriptor
let primitiveDesc = MTLPrimitiveAccelerationStructureDescriptor()
primitiveDesc.geometryDescriptors = [geometryDesc]

// Query sizes
let sizes = device.accelerationStructureSizes(descriptor: primitiveDesc)

// Allocate structures. Each allocation can fail (returns nil), so fail with a
// message that names the resource instead of a bare force-unwrap trap.
guard let accelerationStructure = device.makeAccelerationStructure(size: sizes.accelerationStructureSize) else {
    fatalError("Failed to allocate acceleration structure (\(sizes.accelerationStructureSize) bytes)")
}
// The scratch buffer is only touched by the GPU during the build, so it does
// not need CPU-visible (shared) storage.
guard let scratchBuffer = device.makeBuffer(length: sizes.buildScratchBufferSize,
                                            options: .storageModePrivate) else {
    fatalError("Failed to allocate build scratch buffer (\(sizes.buildScratchBufferSize) bytes)")
}

// Build
guard let encoder = commandBuffer.makeAccelerationStructureCommandEncoder() else {
    fatalError("Failed to create acceleration structure command encoder")
}
encoder.build(accelerationStructure: accelerationStructure,
              descriptor: primitiveDesc,
              scratchBuffer: scratchBuffer,
              scratchBufferOffset: 0)
encoder.endEncoding()
Instance Acceleration Structures
For scenes with multiple objects:
// Create instance descriptors: one per scene object, in scene order, so that
// instance i refers to entry i of the instanced acceleration structure list.
var instances: [MTLAccelerationStructureInstanceDescriptor] =
    sceneObjects.enumerated().map { slot, object in
        var descriptor = MTLAccelerationStructureInstanceDescriptor()
        descriptor.accelerationStructureIndex = UInt32(slot)
        descriptor.transformationMatrix = object.transform.metalMatrix
        descriptor.mask = object.visibilityMask
        descriptor.options = object.opaque ? .opaque : .nonOpaque
        return descriptor
    }

// Create instance buffer holding a copy of the descriptor array
let instanceBufferLength =
    MemoryLayout<MTLAccelerationStructureInstanceDescriptor>.stride * instances.count
let instanceBuffer = device.makeBuffer(bytes: &instances, length: instanceBufferLength)!

// Describe the instance acceleration structure (the build itself is not
// encoded in this snippet)
let instanceDesc = MTLInstanceAccelerationStructureDescriptor()
instanceDesc.instanceDescriptorBuffer = instanceBuffer
instanceDesc.instanceCount = instances.count
instanceDesc.instancedAccelerationStructures = primitiveStructures
Ray Tracing in Shaders
Basic Ray Queries
#include <metal_raytracing>
using namespace metal::raytracing;
// Traces one primary ray per thread against the instanced scene and records
// the closest triangle hit (or an infinite distance on a miss).
//
//   scene - instance acceleration structure to intersect
//   rays  - input rays, indexed by thread position
//   hits  - output records, indexed by thread position
//   tid   - thread position in the grid; selects the ray/hit slot
kernel void trace_primary_rays(
    instance_acceleration_structure scene [[buffer(0)]],
    device Ray* rays [[buffer(1)]],
    device Intersection* hits [[buffer(2)]],
    uint tid [[thread_position_in_grid]]
) {
    // Closest-hit query restricted to triangle geometry
    intersector<triangle_data, instancing> tracer;
    tracer.assume_geometry_type(geometry_type::triangle);
    tracer.accept_any_intersection(false);

    Ray primary = rays[tid];
    intersection_result<triangle_data, instancing> nearest = tracer.intersect(primary, scene);

    device Intersection& record = hits[tid];

    // Miss: only the distance is written; the other fields are left untouched
    if (nearest.type != intersection_type::triangle) {
        record.distance = INFINITY;
        return;
    }

    record.distance = nearest.distance;
    record.primitiveIndex = nearest.primitive_id;
    record.instanceIndex = nearest.instance_id;
    record.barycentrics = nearest.triangle_barycentric_coord;
}
Custom Intersection Functions
For non-triangle geometry:
// Bounding-box intersection function for analytic spheres.
//
// Solves |origin + t * direction - center|^2 = radius^2 for t and accepts the
// nearer root if it lies in [minDistance, maxDistance], otherwise the farther
// root if that one does. `distance` is only written for an accepted hit.
[[intersection(bounding_box, triangle_data)]]
BoundingBoxIntersection sphere_intersection(
    float3 origin [[origin]],
    float3 direction [[direction]],
    float minDistance [[min_distance]],
    float maxDistance [[max_distance]],
    uint primitiveIndex [[primitive_id]],
    device const Sphere* spheres [[buffer(0)]]
) {
    // Reject unless a root in range is found below
    BoundingBoxIntersection result;
    result.accept = false;

    Sphere sphere = spheres[primitiveIndex];
    float3 oc = origin - sphere.center;

    // Quadratic coefficients a*t^2 + b*t + c = 0
    float a = dot(direction, direction);
    float b = 2.0 * dot(oc, direction);
    float c = dot(oc, oc) - sphere.radius * sphere.radius;

    float discriminant = b * b - 4 * a * c;
    if (discriminant < 0) {
        return result; // ray misses the sphere entirely
    }

    float root = sqrt(discriminant);
    float denom = 2.0 * a;

    // Prefer the near root; fall back to the far one (e.g. origin inside sphere)
    float t = (-b - root) / denom;
    if (t < minDistance || t > maxDistance) {
        t = (-b + root) / denom;
    }

    if (t >= minDistance && t <= maxDistance) {
        result.accept = true;
        result.distance = t;
    }
    return result;
}
Path Tracing Implementation
Complete Path Tracer Kernel
// Per-path bookkeeping for a path tracer.
// NOTE(review): the path_trace kernel in this section does not reference this
// struct; it keeps radiance, throughput, the ray and the RNG state in locals.
// Presumably intended for a variant that persists path state between
// dispatches — confirm before relying on the layout.
struct PathState {
    float3 radiance;   // presumably light accumulated along the path so far
    float3 throughput; // presumably the running product of BSDF weights
    Ray ray;           // presumably the next ray to trace for this path
    uint depth;        // presumably the current bounce count
    uint rngState;     // presumably the per-path random number generator state
};
// Progressive path tracer: traces one jittered camera path per pixel per
// dispatch and blends the result into a running average stored in `output`.
//
//   scene     - instance acceleration structure to intersect
//   materials - material table, indexed via getMaterialIndex(instance, primitive)
//   lights    - light list forwarded to evaluateDirectLighting
//   output    - accumulation target; read AND written, so it must be bound as
//               read_write (NOTE(review): the host texture needs both shader
//               read and shader write usage — confirm at creation site)
//   params    - width/height, frameIndex, camera, maxBounces, lightCount
//   tid       - pixel coordinate handled by this thread
kernel void path_trace(
    instance_acceleration_structure scene [[buffer(0)]],
    device const Material* materials [[buffer(1)]],
    device const Light* lights [[buffer(2)]],
    texture2d<float, access::read_write> output [[texture(0)]],
    constant SceneParams& params [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]
) {
    // The grid may be rounded up past the image size; extra threads must not
    // read or write outside the texture.
    if (tid.x >= params.width || tid.y >= params.height) return;

    // Initialize RNG: unique seed per pixel and per frame
    uint rngState = hash(tid.x + tid.y * params.width + params.frameIndex * params.width * params.height);

    // Generate camera ray through a random sub-pixel position
    float2 pixel = float2(tid) + float2(rand(rngState), rand(rngState));
    float2 uv = pixel / float2(params.width, params.height);
    Ray ray = generateCameraRay(params.camera, uv);

    float3 radiance = float3(0.0);
    float3 throughput = float3(1.0);

    intersector<triangle_data, instancing> inter;
    inter.assume_geometry_type(geometry_type::triangle);

    for (uint bounce = 0; bounce < params.maxBounces; bounce++) {
        intersection_result<triangle_data, instancing> hit;
        hit = inter.intersect(ray, scene);

        if (hit.type != intersection_type::triangle) {
            // Sky contribution
            radiance += throughput * sampleEnvironment(ray.direction);
            break;
        }

        // Get hit information
        float3 position = ray.origin + ray.direction * hit.distance;
        uint materialIndex = getMaterialIndex(hit.instance_id, hit.primitive_id);
        Material mat = materials[materialIndex];
        float3 normal = interpolateNormal(hit);

        // Direct lighting (NEE)
        float3 directLight = evaluateDirectLighting(
            position, normal, mat, lights, params.lightCount,
            scene, inter, rngState
        );
        radiance += throughput * directLight;

        // Sample BSDF for next bounce
        float3 wo = -ray.direction;
        float3 wi;
        float pdf;
        float3 f = sampleBSDF(mat, normal, wo, wi, pdf, rngState);

        // A vanishing pdf would blow up the throughput estimate
        if (pdf < 1e-6) break;
        throughput *= f * abs(dot(wi, normal)) / pdf;

        // Russian roulette
        if (bounce > 3) {
            // Survival probability from the largest throughput component. The
            // standard library max3 takes three operands, so the component
            // maximum is spelled out with max.
            float p = min(max(throughput.x, max(throughput.y, throughput.z)), 0.95);
            // p <= 0 means the path can no longer contribute; terminating here
            // also avoids dividing the throughput by zero.
            if (p <= 0.0 || rand(rngState) > p) break;
            throughput /= p;
        }

        // Setup next ray; origin is nudged along the normal, presumably to
        // avoid re-hitting the surface it just left
        ray.origin = position + normal * 1e-4;
        ray.direction = wi;
        ray.min_distance = 0;
        ray.max_distance = INFINITY;
    }

    // Accumulate: running average over frames. On the first frame the texture
    // contents are not part of the average, so they are not read at all — a
    // stale NaN/Inf would otherwise survive the zero-weight blend.
    float3 accumulated = radiance;
    if (params.frameIndex > 0) {
        float3 previous = output.read(tid).rgb;
        float blend = 1.0 / float(params.frameIndex + 1);
        accumulated = mix(previous, radiance, blend);
    }
    output.write(float4(accumulated, 1.0), tid);
}
Performance Characteristics
M3 vs M4 Ray Tracing Performance
| Operation | M3 Pro | M4 Pro | Improvement |
|---|---|---|---|
| BVH Build (1M tris) | 45ms | 32ms | 29% |
| Primary Rays (4K) | 2.1ms | 1.4ms | 33% |
| Path Trace (1spp) | 8.3ms | 5.6ms | 33% |
| Any-Hit Queries | 0.8ms | 0.5ms | 38% |
Optimization Guidelines
- Batch ray traces: Dispatch millions of rays together
- Use any-hit for shadows: `inter.accept_any_intersection(true)`
- Compact BVH periodically: Reduces memory, improves traversal
- Stream geometry updates: Use refitting for animated scenes