MSL 4.0 – Cooperative Tensors & SIMD Matrix Operations
The definitive guide to Metal’s tensor acceleration primitives
The Neural Computing Revolution in Shaders
Metal Shading Language 4.0 introduces a paradigm shift: tensor operations are now first-class citizens within shader code. No longer confined to external frameworks, matrix multiplications and convolutions execute directly alongside your rendering logic.
Understanding Cooperative Tensors
Cooperative tensors represent a fundamental departure from traditional SIMD programming. Unlike standard vectors limited to 4 components, cooperative tensors span entire SIMD groups, enabling hardware-accelerated matrix operations.
The Execution Model Hierarchy
MSL 4.0 defines three execution scopes for tensor operations:
// Execution scopes determine hardware utilization
// (thread and threadgroup are reserved words in MSL, so the tags
// carry an execution_ prefix)
enum class execution_group {
execution_thread, // Single thread - maximum flexibility
execution_simdgroup, // 32 threads - balanced performance
execution_threadgroup // Full threadgroup - maximum throughput
};
Thread Execution: Use when tensor operations exhibit divergent control flow or operate on different data per thread. Maximum flexibility, minimum hardware acceleration.
SIMD-group Execution: The sweet spot for most neural network inference. All 32 threads in a SIMD-group collaborate on the same matrix operation. Significant hardware acceleration with reasonable flexibility.
Threadgroup Execution: Maximum throughput when entire threadgroups perform identical operations. Hardware tensor cores fully engaged.
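The cooperative-tensor scopes map onto hardware the same way today's simdgroup-matrix intrinsics do. As a minimal sketch of the simdgroup-scope idea using only the shipping simdgroup-matrix API (the kernel name, buffer bindings, and 8x8 sizes are illustrative assumptions, not part of any Metal header):
#include <metal_stdlib>
using namespace metal;

// Every lane of the SIMD-group must reach these calls together (uniform
// control flow), because each 8x8 tile is physically spread across the
// 32 lanes' registers.
kernel void simdgroup_scope_sketch(device const float* A [[buffer(0)]],
                                   device const float* B [[buffer(1)]],
                                   device float* C [[buffer(2)]]) {
    simdgroup_float8x8 a, b;
    simdgroup_float8x8 acc = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_load(a, A, /*elements_per_row=*/8);   // cooperative load
    simdgroup_load(b, B, /*elements_per_row=*/8);
    simdgroup_multiply_accumulate(acc, a, b, acc);  // one op, 32 lanes
    simdgroup_store(acc, C, /*elements_per_row=*/8);
}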
Tensor Memory Layout
Cooperative tensors use specialized memory layouts optimized for matrix operations:
#include <metal_tensor>
using namespace metal::tensor_ops;
// Define tensor layout for 16x16 matrix tiles
using TensorLayout = tensor_layout<
float, // Element type
16, 16, // Dimensions
layout_order::row_major // Memory ordering
>;
// Create cooperative tensor in threadgroup memory
threadgroup TensorLayout::storage_type tile_storage;
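Before a threadgroup-scope operation can consume a tile, the data usually has to be staged into threadgroup memory cooperatively. A minimal sketch of that staging pattern using plain threadgroup memory and a barrier (the flat 16x16 float array stands in for TensorLayout::storage_type, and the 128-thread group size is an assumption to keep the example short):
#include <metal_stdlib>
using namespace metal;

kernel void stage_tile_sketch(device const float* src [[buffer(0)]],
                              uint lid [[thread_index_in_threadgroup]]) {
    // Assumes a 128-thread threadgroup, set on the host side
    constexpr uint kThreadsPerGroup = 128;

    // 16x16 float tile in threadgroup memory, standing in for the
    // TensorLayout::storage_type alias above
    threadgroup float tile[16 * 16];

    // Each thread copies a strided subset of the tile
    for (uint i = lid; i < 16 * 16; i += kThreadsPerGroup) {
        tile[i] = src[i];
    }

    // Make the staged tile visible to every thread before any
    // threadgroup-scope tensor operation consumes it
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ... threadgroup-scope tensor work on `tile` would follow here ...
}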
Metal Performance Primitives Integration
The <MetalPerformancePrimitives/MetalPerformancePrimitives.h> header exposes optimized tensor operations directly in shader code:
MatMul2D – The Core Operation
Matrix multiplication is the backbone of neural inference:
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
kernel void neural_matmul(
device const half* weights [[buffer(0)]],
device const half* input [[buffer(1)]],
device half* output [[buffer(2)]],
uint2 tid [[thread_position_in_grid]],
uint simd_lane [[thread_index_in_simdgroup]]
) {
// Configure matrix multiplication descriptor
MatMul2DDescriptor desc;
desc.M = 256; // Output rows
desc.N = 128; // Output columns
desc.K = 512; // Inner dimension
desc.alpha = 1.0h; // Scale applied to the A*B product
desc.beta = 0.0h; // Scale applied to the existing C (0 = overwrite)
desc.transA = false; // Don't transpose A
desc.transB = true; // Transpose B (weights are often stored transposed)
// Execute with simdgroup cooperation
matmul2d<execution_simdgroup>(
desc,
weights,
input,
output,
tid,
simd_lane
);
}
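For orientation, the alpha, beta, and transpose fields suggest the conventional BLAS-style contract C = alpha * op(A) * op(B) + beta * C; that reading is an assumption about intent, not something quoted from the Metal headers. A scalar reference sketch of that contract, useful for validating the fast path on small sizes (gemm_reference is a hypothetical helper placed in the same shader file):
// Scalar reference for the assumed GEMM contract; far too slow for real use,
// but handy for checking the accelerated path on small problem sizes.
inline void gemm_reference(device const half* A,   // M x K
                           device const half* B,   // K x N (or N x K if transB)
                           device half* C,         // M x N
                           uint M, uint N, uint K,
                           half alpha, half beta, bool transB) {
    for (uint m = 0; m < M; ++m) {
        for (uint n = 0; n < N; ++n) {
            half acc = 0.0h;
            for (uint k = 0; k < K; ++k) {
                half b = transB ? B[n * K + k] : B[k * N + n];
                acc += A[m * K + k] * b;
            }
            C[m * N + n] = alpha * acc + beta * C[m * N + n];
        }
    }
}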
Convolution2D – Spatial Feature Extraction
For convolutional neural networks, MSL 4.0 provides hardware-accelerated 2D convolution:
struct ConvParams {
uint input_w;
uint input_h;
uint in_channels;
uint out_channels;
};
kernel void conv_layer(
device const half* input_features [[buffer(0)]],
device const half* kernel_weights [[buffer(1)]],
device half* output_features [[buffer(2)]],
constant ConvParams& params [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]
) {
Convolution2DDescriptor desc;
desc.inputWidth = params.input_w;
desc.inputHeight = params.input_h;
desc.inputChannels = params.in_channels;
desc.outputChannels = params.out_channels;
desc.kernelWidth = 3;
desc.kernelHeight = 3;
desc.strideX = 1;
desc.strideY = 1;
desc.paddingX = 1;
desc.paddingY = 1;
desc.dilationX = 1;
desc.dilationY = 1;
convolution2d<execution_simdgroup>(
desc,
input_features,
kernel_weights,
output_features,
tid
);
}
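The descriptor leaves the output resolution implicit. Under the standard convolution arithmetic (a general CNN convention, not a quote from the Metal headers), each output extent is (input + 2*padding - dilation*(kernel - 1) - 1) / stride + 1. A tiny helper makes the check explicit (conv_output_extent is illustrative, not an API):
// Standard convolution output-extent arithmetic (assumed convention)
inline uint conv_output_extent(uint in_extent, uint kernel_extent,
                               uint stride, uint padding, uint dilation) {
    return (in_extent + 2 * padding - dilation * (kernel_extent - 1) - 1)
           / stride + 1;
}
// For the 3x3, stride-1, padding-1, dilation-1 layer above this preserves
// the input resolution: (w + 2 - 2 - 1) / 1 + 1 == w.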
SIMD-Group Matrix Functions
Beyond cooperative tensors, MSL 4.0 enhances traditional SIMD-group matrix operations:
Creating and Loading Matrices
#include <metal_simdgroup_matrix>
kernel void simd_matrix_demo(
device const float* A [[buffer(0)]],
device const float* B [[buffer(1)]],
device float* C [[buffer(2)]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]]
) {
// Declare 8x8 matrix tiles
simdgroup_float8x8 mat_a;
simdgroup_float8x8 mat_b;
simdgroup_float8x8 mat_c;
// Load 8x8 tiles from device memory; the third argument is the source
// matrix's row stride in elements (a 64-column matrix here)
simdgroup_load(mat_a, A, /*elements_per_row=*/64);
simdgroup_load(mat_b, B, /*elements_per_row=*/64);
// Initialize the accumulator to zero
mat_c = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);
// Multiply-accumulate: mat_c = mat_a * mat_b + mat_c
simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);
// Store the result tile back to device memory (same 64-element row stride)
simdgroup_store(mat_c, C, /*elements_per_row=*/64);
}
Supported Matrix Dimensions
MSL 4.0 supports specific tile sizes optimized for Apple Silicon:
| Type | Dimensions | Hardware Support |
|---|---|---|
| simdgroup_half8x8 | 8×8 | All Apple Silicon |
| simdgroup_float8x8 | 8×8 | M1 and later |
| simdgroup_bfloat8x8 | 8×8 | M4 and later |
Practical Example: Neural Material Decompression
Here’s a complete shader implementing neural texture decompression:
#include <metal_stdlib>
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
using namespace metal;
struct NeuralMaterialWeights {
device const half* layer1_weights; // 16x64
device const half* layer1_bias; // 64
device const half* layer2_weights; // 64x64
device const half* layer2_bias; // 64
device const half* layer3_weights; // 64x4
device const half* layer3_bias; // 4 (RGBA output)
};
// ReLU activation helper (the layer loops below inline the scalar equivalent)
inline half4 relu(half4 x) {
return max(x, half4(0.0h));
}
struct VertexOut {
float4 position [[position]];
float2 uv;
};
fragment half4 neural_material_fragment(
VertexOut in [[stage_in]],
constant NeuralMaterialWeights& weights [[buffer(0)]],
texture2d_array<half> latent_texture [[texture(0)]],
sampler tex_sampler [[sampler(0)]]
) {
// Sample compressed latent features (4 array slices x 4 channels = 16)
half4 latent0 = latent_texture.sample(tex_sampler, in.uv, 0);
half4 latent1 = latent_texture.sample(tex_sampler, in.uv, 1);
half4 latent2 = latent_texture.sample(tex_sampler, in.uv, 2);
half4 latent3 = latent_texture.sample(tex_sampler, in.uv, 3);
// Pack into input tensor (thread-local for fragment shader)
half input_features[16] = {
latent0.r, latent0.g, latent0.b, latent0.a,
latent1.r, latent1.g, latent1.b, latent1.a,
latent2.r, latent2.g, latent2.b, latent2.a,
latent3.r, latent3.g, latent3.b, latent3.a
};
// Layer 1: 16 → 64
half hidden1[64];
for (int i = 0; i < 64; i++) {
half sum = weights.layer1_bias[i];
for (int j = 0; j < 16; j++) {
sum += input_features[j] * weights.layer1_weights[j * 64 + i];
}
hidden1[i] = max(sum, 0.0h); // ReLU
}
// Layer 2: 64 → 64
half hidden2[64];
for (int i = 0; i < 64; i++) {
half sum = weights.layer2_bias[i];
for (int j = 0; j < 64; j++) {
sum += hidden1[j] * weights.layer2_weights[j * 64 + i];
}
hidden2[i] = max(sum, 0.0h); // ReLU
}
// Layer 3: 64 → 4 (RGBA)
half4 output;
for (int i = 0; i < 4; i++) {
half sum = weights.layer3_bias[i];
for (int j = 0; j < 64; j++) {
sum += hidden2[j] * weights.layer3_weights[j * 4 + i];
}
output[i] = sum;
}
return saturate(output); // Clamp to [0,1]
}
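The three layer loops share one pattern, so they collapse naturally into a single helper. A hypothetical refactor sketch (fc_layer is not part of any Metal API; it simply matches the row-major w[j * out_count + i] indexing used above, and lives in the same shader file):
// One fully connected layer with optional ReLU; weights are row-major,
// indexed as w[j * out_count + i] exactly as in the loops above.
inline void fc_layer(thread const half* in_vals, uint in_count,
                     device const half* w, device const half* b,
                     thread half* out_vals, uint out_count, bool relu_after) {
    for (uint i = 0; i < out_count; ++i) {
        half sum = b[i];
        for (uint j = 0; j < in_count; ++j) {
            sum += in_vals[j] * w[j * out_count + i];
        }
        out_vals[i] = relu_after ? max(sum, 0.0h) : sum;
    }
}
With it, the shader body reduces to three fc_layer calls followed by the final saturate.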
Performance Optimization Strategies
1. Choose Appropriate Execution Group
// WRONG: per-thread scalar loops for a uniform matrix-vector product.
// Every thread redundantly walks the whole weight matrix - no cooperation,
// no tensor hardware engaged.
for (int i = 0; i < 64; i++) {
result[i] = dot(weights[i], input);
}
// CORRECT: simdgroup-scope execution.
// The 32 lanes cooperate and the tensor hardware performs the whole product in one call.
matmul2d<execution_simdgroup>(desc, weights, input, result, tid, simd_lane);
2. Tile for Cache Efficiency
// Walk K in 8-element steps to match the 8x8 simdgroup tiles held in registers
simdgroup_float8x8 tile_a, tile_b;
simdgroup_float8x8 acc = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);
for (uint k = 0; k < K; k += 8) {
// Load an 8x8 tile of A (row stride = K elements)
simdgroup_load(tile_a, A + row * K + k, K);
// Load an 8x8 tile of B (row stride = N elements)
simdgroup_load(tile_b, B + k * N + col, N);
// Accumulate the partial product in registers
simdgroup_multiply_accumulate(acc, tile_a, tile_b, acc);
}
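Putting the pieces together, here is a complete hypothetical kernel along these lines. It assumes M, N, and K are multiples of 8 and that the host dispatches one 32-thread simdgroup per threadgroup, so each threadgroup owns exactly one 8x8 output tile:
#include <metal_stdlib>
using namespace metal;

kernel void tiled_gemm_sketch(device const float* A [[buffer(0)]],
                              device const float* B [[buffer(1)]],
                              device float* C [[buffer(2)]],
                              constant uint3& dims [[buffer(3)]],  // (M, N, K)
                              uint2 tile_id [[threadgroup_position_in_grid]]) {
    const uint N = dims.y;
    const uint K = dims.z;
    const uint row = tile_id.y * 8;   // top-left corner of this output tile
    const uint col = tile_id.x * 8;

    simdgroup_float8x8 tile_a, tile_b;
    simdgroup_float8x8 acc = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);

    for (uint k = 0; k < K; k += 8) {
        simdgroup_load(tile_a, A + row * K + k, K);  // 8x8 tile of A, stride K
        simdgroup_load(tile_b, B + k * N + col, N);  // 8x8 tile of B, stride N
        simdgroup_multiply_accumulate(acc, tile_a, tile_b, acc);
    }

    simdgroup_store(acc, C + row * N + col, N);      // write the finished tile
}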
3. Use Half Precision When Possible
// Half precision: roughly 2x ALU throughput and half the memory traffic per element
simdgroup_half8x8 mat_a, mat_b, mat_c;
// Reserve float for accumulators or layers where precision is critical
simdgroup_float8x8 high_precision_accumulator;
Hardware Considerations
Apple Silicon Matrix Throughput
| Chip | half8x8 TFLOPS | float8x8 TFLOPS | Neural Engine TOPS |
|---|---|---|---|
| M1 | 2.6 | 2.6 | 11 |
| M2 | 3.6 | 3.6 | 15.8 |
| M3 | 4.1 | 4.1 | 18 |
| M4 | 4.5 | 4.5 | 38 |
When to Use Shader ML vs Neural Engine
Use Shader ML (GPU):
- Small networks (<1M parameters)
- Tight integration with rendering
- Per-pixel inference
- Real-time requirements
Use Neural Engine:
- Large networks (>10M parameters)
- Batch inference
- Training workloads
- Maximum efficiency