MSL 4.0 – Cooperative Tensors & SIMD Matrix Operations

The definitive guide to Metal’s tensor acceleration primitives

The Neural Computing Revolution in Shaders

Metal Shading Language 4.0 introduces a paradigm shift: tensor operations are now first-class citizens within shader code. No longer confined to external frameworks, matrix multiplications and convolutions execute directly alongside your rendering logic.

Understanding Cooperative Tensors

Cooperative tensors represent a fundamental departure from traditional SIMD programming. Unlike standard vectors limited to 4 components, cooperative tensors span entire SIMD groups, enabling hardware-accelerated matrix operations.

The Execution Model Hierarchy

MSL 4.0 defines three execution scopes for tensor operations:

// Execution scope determines hardware utilization
// (illustrative; matches the execution_* tags used in the examples below)
enum class execution_scope {
    execution_thread,       // Single thread - maximum flexibility
    execution_simdgroup,    // 32 threads - balanced performance
    execution_threadgroup   // Full threadgroup - maximum throughput
};

Thread Execution: Use when tensor operations exhibit divergent control flow or operate on different data per thread. Maximum flexibility, minimum hardware acceleration.

SIMD-group Execution: The sweet spot for most neural network inference. All 32 threads in a SIMD-group collaborate on the same matrix operation. Significant hardware acceleration with reasonable flexibility.

Threadgroup Execution: Maximum throughput when entire threadgroups perform identical operations. Hardware tensor cores fully engaged.
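
The practical difference between the first two scopes is easiest to see side by side. The sketch below uses only plain MSL SIMD-group matrix functions (covered in detail later in this guide), no hypothetical API: the per-thread kernel computes one dot product per thread, while the SIMD-group kernel has all 32 threads cooperate on a single 8x8 tile operation.

#include <metal_stdlib>
using namespace metal;

// Thread scope: each thread independently walks one row. Divergent control
// flow is fine here, but no matrix hardware is engaged.
kernel void per_thread_dot(
    device const float* A [[buffer(0)]],
    device const float* x [[buffer(1)]],
    device float* y       [[buffer(2)]],
    uint row [[thread_position_in_grid]]
) {
    float sum = 0.0f;
    for (uint k = 0; k < 64; ++k) {
        sum += A[row * 64 + k] * x[k];
    }
    y[row] = sum;
}

// SIMD-group scope: the 32 threads cooperate on one 8x8 tile, so a single
// call maps onto the matrix units.
kernel void per_simdgroup_tile(
    device const float* A [[buffer(0)]],
    device const float* B [[buffer(1)]],
    device float* C       [[buffer(2)]]
) {
    simdgroup_float8x8 a, b;
    simdgroup_float8x8 acc = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_load(a, A, /*elements_per_row=*/8);
    simdgroup_load(b, B, /*elements_per_row=*/8);
    simdgroup_multiply_accumulate(acc, a, b, acc);
    simdgroup_store(acc, C, /*elements_per_row=*/8);
}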

Tensor Memory Layout

Cooperative tensors use specialized memory layouts optimized for matrix operations:

#include <metal_tensor>
using namespace metal::tensor_ops;

// Define tensor layout for 16x16 matrix tiles
using TensorLayout = tensor_layout<
    float,                    // Element type
    16, 16,                   // Dimensions
    layout_order::row_major   // Memory ordering
>;

// Create cooperative tensor in threadgroup memory
threadgroup TensorLayout::storage_type tile_storage;
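
Whatever the exact tensor API looks like, the layout_order choice only changes how a (row, column) pair maps to a linear element offset. For the 16x16 tile above:

// Linear element offsets implied by each ordering for a 16x16 tile
inline uint row_major_offset(uint row, uint col)    { return row * 16 + col; }  // rows are contiguous
inline uint column_major_offset(uint row, uint col) { return col * 16 + row; }  // columns are contiguous

Matching the tile ordering to the ordering of the source buffer keeps cooperative loads contiguous in memory.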

Metal Performance Primitives Integration

The <MetalPerformancePrimitives/MetalPerformancePrimitives.h> header exposes optimized tensor operations directly in shader code:

MatMul2D – The Core Operation

Matrix multiplication is the backbone of neural inference:

#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>

kernel void neural_matmul(
    device const half* weights [[buffer(0)]],
    device const half* input   [[buffer(1)]],
    device half* output        [[buffer(2)]],
    uint2 tid [[thread_position_in_grid]],
    uint simd_lane [[thread_index_in_simdgroup]]
) {
    // Configure matrix multiplication descriptor
    MatMul2DDescriptor desc;
    desc.M = 256;           // Output rows
    desc.N = 128;           // Output columns
    desc.K = 512;           // Inner dimension
    desc.alpha = 1.0h;      // Scaling factor
    desc.beta = 0.0h;       // Accumulation factor
    desc.transA = false;    // A is used as-is (not transposed)
    desc.transB = true;     // B is read transposed (common for weight matrices)

    // Execute with simdgroup cooperation
    matmul2d<execution_simdgroup>(
        desc,
        weights,
        input,
        output,
        tid,
        simd_lane
    );
}
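
For reference, the arithmetic the accelerated call performs is the standard GEMM below, written as a scalar sketch using the descriptor fields from the listing. A and B stand for the first and second operands of the call and C for the output; the real implementation distributes this work across the SIMD-group rather than looping per thread.

// Scalar reference for what matmul2d computes with this descriptor.
// With transB = true, B is read transposed:
//   C[m][n] = alpha * sum_k A[m][k] * B[n][k] + beta * C[m][n]
for (uint m = 0; m < desc.M; ++m) {
    for (uint n = 0; n < desc.N; ++n) {
        half acc = 0.0h;
        for (uint k = 0; k < desc.K; ++k) {
            acc += A[m * desc.K + k] * B[n * desc.K + k];   // row-major A, row-major (transposed) B
        }
        C[m * desc.N + n] = desc.alpha * acc + desc.beta * C[m * desc.N + n];
    }
}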

Convolution2D – Spatial Feature Extraction

For convolutional neural networks, MSL 4.0 provides hardware-accelerated 2D convolution:

// Per-dispatch convolution parameters (assumed layout)
struct ConvParams {
    uint input_w;
    uint input_h;
    uint in_channels;
    uint out_channels;
};

kernel void conv_layer(
    device const half* input_features  [[buffer(0)]],
    device const half* kernel_weights  [[buffer(1)]],
    device half* output_features       [[buffer(2)]],
    constant ConvParams& params        [[buffer(3)]],
    uint3 tid [[thread_position_in_grid]]
) {
    Convolution2DDescriptor desc;
    desc.inputWidth = params.input_w;
    desc.inputHeight = params.input_h;
    desc.inputChannels = params.in_channels;
    desc.outputChannels = params.out_channels;
    desc.kernelWidth = 3;
    desc.kernelHeight = 3;
    desc.strideX = 1;
    desc.strideY = 1;
    desc.paddingX = 1;
    desc.paddingY = 1;
    desc.dilationX = 1;
    desc.dilationY = 1;

    convolution2d<execution_simdgroup>(
        desc,
        input_features,
        kernel_weights,
        output_features,
        tid
    );
}
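
The descriptor fixes the output spatial dimensions through standard convolution arithmetic; a quick sanity check using the fields from the listing:

// Output size implied by the descriptor:
//   out = (in + 2*padding - dilation*(kernel - 1) - 1) / stride + 1
uint out_w = (desc.inputWidth  + 2 * desc.paddingX
              - desc.dilationX * (desc.kernelWidth  - 1) - 1) / desc.strideX + 1;
uint out_h = (desc.inputHeight + 2 * desc.paddingY
              - desc.dilationY * (desc.kernelHeight - 1) - 1) / desc.strideY + 1;

// With the 3x3 kernel, stride 1, padding 1, and dilation 1 configured above,
// out_w == inputWidth and out_h == inputHeight ("same" padding).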

SIMD-Group Matrix Functions

Beyond cooperative tensors, MSL 4.0 enhances traditional SIMD-group matrix operations:

Creating and Loading Matrices

#include <metal_simdgroup_matrix>

kernel void simd_matrix_demo(
    device const float* A [[buffer(0)]],
    device const float* B [[buffer(1)]],
    device float* C       [[buffer(2)]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]]
) {
    // Declare 8x8 matrix tiles
    simdgroup_float8x8 mat_a;
    simdgroup_float8x8 mat_b;
    simdgroup_float8x8 mat_c;

    // Load 8x8 tiles from device memory (source matrices are 64 elements per row)
    simdgroup_load(mat_a, A, /*elements_per_row=*/64);
    simdgroup_load(mat_b, B, /*elements_per_row=*/64);

    // Initialize the accumulator tile to zero
    mat_c = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);

    // Multiply-accumulate
    simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);

    // Store the result tile back to device memory
    simdgroup_store(mat_c, C, /*elements_per_row=*/64);
}
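
The demo loads a single tile at the start of each buffer and never touches simd_lane or simd_id. In practice each SIMD-group owns its own 8x8 tile of the output; a minimal sketch of that offset arithmetic, assuming the same 64-element-per-row, row-major matrices as above:

// Derive this SIMD-group's 8x8 output tile from its index in the threadgroup
const uint cols = 64;                            // elements per row, as in the loads above
const uint tiles_per_row = cols / 8;
const uint tile_row = simd_id / tiles_per_row;   // which 8-row band of C this SIMD-group owns
const uint tile_col = simd_id % tiles_per_row;   // which 8-column band of C

simdgroup_load(mat_a, A + tile_row * 8 * cols, cols);
simdgroup_load(mat_b, B + tile_col * 8,        cols);
simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);
simdgroup_store(mat_c, C + tile_row * 8 * cols + tile_col * 8, cols);

simd_lane never appears because the cooperative load, multiply, and store functions handle per-lane addressing internally. Accumulating across the full inner dimension is covered in the tiling strategy later in this guide.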

Supported Matrix Dimensions

MSL 4.0 supports specific tile sizes optimized for Apple Silicon:

Type                  Dimensions  Hardware Support
simdgroup_half8x8     8×8         All Apple Silicon
simdgroup_float8x8    8×8         M1 and later
simdgroup_bfloat8x8   8×8         M4 and later

Practical Example: Neural Material Decompression

Here’s a complete shader implementing neural texture decompression:

#include <metal_stdlib>
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
using namespace metal;

// Interpolated vertex output consumed by the fragment shader below (assumed layout)
struct VertexOut {
    float4 position [[position]];
    float2 uv;
};

struct NeuralMaterialWeights {
    device const half* layer1_weights;  // 16x64
    device const half* layer1_bias;     // 64
    device const half* layer2_weights;  // 64x64
    device const half* layer2_bias;     // 64
    device const half* layer3_weights;  // 64x4
    device const half* layer3_bias;     // 4 (RGBA output)
};

// ReLU activation inlined
inline half relu(half x) {
    return max(x, 0.0h);
}

fragment half4 neural_material_fragment(
    VertexOut in [[stage_in]],
    constant NeuralMaterialWeights& weights [[buffer(0)]],
    texture2d_array<half> latent_texture [[texture(0)]],
    sampler tex_sampler [[sampler(0)]]
) {
    // Sample compressed latent features (16 channels stored as four array slices)
    half4 latent0 = latent_texture.sample(tex_sampler, in.uv, 0);
    half4 latent1 = latent_texture.sample(tex_sampler, in.uv, 1);
    half4 latent2 = latent_texture.sample(tex_sampler, in.uv, 2);
    half4 latent3 = latent_texture.sample(tex_sampler, in.uv, 3);

    // Pack into input tensor (thread-local for fragment shader)
    half input_features[16] = {
        latent0.r, latent0.g, latent0.b, latent0.a,
        latent1.r, latent1.g, latent1.b, latent1.a,
        latent2.r, latent2.g, latent2.b, latent2.a,
        latent3.r, latent3.g, latent3.b, latent3.a
    };

    // Layer 1: 16 → 64
    half hidden1[64];
    for (int i = 0; i < 64; i++) {
        half sum = weights.layer1_bias[i];
        for (int j = 0; j < 16; j++) {
            sum += input_features[j] * weights.layer1_weights[j * 64 + i];
        }
        hidden1[i] = relu(sum);
    }

    // Layer 2: 64 → 64
    half hidden2[64];
    for (int i = 0; i < 64; i++) {
        half sum = weights.layer2_bias[i];
        for (int j = 0; j < 64; j++) {
            sum += hidden1[j] * weights.layer2_weights[j * 64 + i];
        }
        hidden2[i] = relu(sum);
    }

    // Layer 3: 64 → 4 (RGBA)
    half4 output;
    for (int i = 0; i < 4; i++) {
        half sum = weights.layer3_bias[i];
        for (int j = 0; j < 64; j++) {
            sum += hidden2[j] * weights.layer3_weights[j * 4 + i];
        }
        output[i] = sum;
    }

    return saturate(output); // Clamp to [0,1]
}

Performance Optimization Strategies

1. Choose Appropriate Execution Group

// WRONG: Per-thread loops for a uniform operation.
// Every thread repeats the same computation, and the matrix hardware sits idle.
for (int i = 0; i < 64; i++) {
    result[i] = dot(weights[i], input);
}

// CORRECT: Using simdgroup execution
// Hardware multiplies entire matrix in one operation
matmul2d<execution_simdgroup>(desc, weights, input, result, tid, simd_lane);

2. Tile for Cache Efficiency

// Process in tiles that fit in the register file
constant uint TILE_M = 32;   // rows of C handled per threadgroup (four 8-row SIMD-group tiles)
constant uint TILE_N = 32;   // columns of C handled per threadgroup
constant uint TILE_K = 8;    // inner-dimension step; matches the 8x8 SIMD-group tiles

for (uint k = 0; k < K; k += TILE_K) {
    // Load one 8x8 tile of A into registers
    simdgroup_load(tile_a, A + row * K + k, K);
    // Load one 8x8 tile of B into registers
    simdgroup_load(tile_b, B + k * N + col, N);
    // Accumulate into this SIMD-group's running output tile
    simdgroup_multiply_accumulate(acc, tile_a, tile_b, acc);
}
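
The loop assumes row and col (this SIMD-group's output tile origin), along with K, N, tile_a, tile_b, and acc, are already in scope. One hypothetical way to derive the tile origin, assuming a dispatch with four SIMD-groups per threadgroup, each producing one 8x8 tile of the output:

// Hypothetical mapping from dispatch coordinates to an 8x8 output tile:
//   uint2 tg_pos  [[threadgroup_position_in_grid]]
//   uint  simd_id [[simdgroup_index_in_threadgroup]]
const uint SIMDGROUPS_PER_TG = 4;                          // assumed threadgroup shape (128 threads)
uint row = (tg_pos.y * SIMDGROUPS_PER_TG + simd_id) * 8;   // first output row of this tile
uint col = tg_pos.x * 8;                                   // first output column of this tile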

3. Use Half Precision When Possible

// Half precision: roughly 2x matrix throughput and half the memory traffic
simdgroup_half8x8 mat_a, mat_b, mat_c;

// Only use float when precision critical
simdgroup_float8x8 high_precision_accumulator;

Hardware Considerations

Apple Silicon Matrix Throughput

Chip  half8x8 TFLOPS  float8x8 TFLOPS  Neural Engine TOPS
M1    2.6             2.6              11
M2    3.6             3.6              15.8
M3    4.1             4.1              18
M4    4.5             4.5              38

When to Use Shader ML vs Neural Engine

Use Shader ML (GPU):

  • Small networks (<1M parameters)
  • Tight integration with rendering
  • Per-pixel inference
  • Real-time requirements

Use Neural Engine:

  • Large networks (>10M parameters)
  • Batch inference
  • Training workloads
  • Maximum efficiency