
MLIR Overview

Prereqs: LLVM IR Tour, Passes & Pipelines. MLIR builds on LLVM’s pass machinery; if those are in your head, MLIR is mostly new vocabulary.

LLVM IR is brilliant for what it is — a small, RISC-like, scalar-CPU-shaped instruction set the C/Rust/Swift world all share. But if you’ve spent any time around AI compilers, you’ve felt the mismatch: LLVM doesn’t know what a tensor is, doesn’t know what a loop is as a structured object, doesn’t know what a GPU thread block is. By the time your AI kernel reaches LLVM IR, the structural information that makes interesting optimizations possible — tile shapes, fusion opportunities, layout choices — is gone.

MLIR exists to keep that information around long enough to optimize against it. Same pass-pipeline philosophy as LLVM, but generalized: instead of one fixed instruction set, MLIR is a framework for representing programs at many levels of abstraction simultaneously, each level a different dialect. A linalg.matmul at the top, an scf.for in the middle, an nvgpu.wgmma near the bottom, an llvm.add at the floor. Lowering rewrites IR from a higher dialect to a lower one. Stack 5–10 of those and you have an AI compiler.

Every modern AI compiler — Triton, IREE, JAX/XLA, ExecuTorch, ONNX-MLIR, Modular MAX — is built on this idea. This lesson is about the framework, not any one compiler.

TL;DR

  • MLIR = Multi-Level IR. One framework for representing programs at many levels of abstraction simultaneously: linear-algebra ops at the top, GPU ops in the middle, LLVM IR at the bottom. Every level is just a different dialect.
  • Every modern AI compiler — Triton, IREE, JAX/XLA, ExecuTorch, OpenXLA, ONNX-MLIR, Modular’s MAX — is built on MLIR. PyTorch’s torch.compile lowers through it (via Inductor → Triton → MLIR → LLVM).
  • The dialect is the unit of vocabulary. linalg, tensor, arith, scf (structured control flow), memref, gpu, nvgpu, vector, llvm are the canonical built-in dialects. Custom domains (TPU, NPU, your hardware) ship their own.
  • Lowering is the central operation: rewrite IR from a higher-level dialect to a lower one, repeatedly, until you’re at llvm and can hand off to LLVM. Every AI compiler is essentially a stack of lowerings.
  • MLIR is younger than LLVM (Google, 2019) but moves much faster. The API churns — code from 2022 often doesn’t compile against current main. Pin your version.

Why this matters

LLVM IR is too low-level for AI: it doesn’t know about tensors, doesn’t know about loops as first-class objects, doesn’t know about GPU programming models. By the time your kernel is in LLVM IR, the structural information that made interesting optimizations possible (loop tiling, fusion, layout selection) is gone. MLIR exists to keep that information around long enough to optimize against it. Every AI compiler today builds on MLIR for exactly this reason.

If you want to write a custom hardware backend, do graph-level optimization, or even just understand what torch.compile is doing, the answer goes through MLIR.

Mental model

The same IR file holds different dialects at each stage. Lowering = rewrite from the higher dialect to the lower one. Repeat 4–6 times.
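As a toy sketch in Python (the pass names and string rewriting are invented for illustration; real conversion passes operate on typed IR, not strings), a lowering pipeline is just a sequence of rewrites applied to the same module, each replacing ops from one dialect with ops from the dialect below it:

# A lowering pipeline as plain function composition (toy illustration only).
def linalg_to_scf(ir):  return ir.replace("linalg.matmul", "scf.for { arith.mulf; arith.addf }")
def scf_to_cf(ir):      return ir.replace("scf.for", "cf.br")
def cf_to_llvm(ir):     return (ir.replace("cf.br", "llvm.br")
                                  .replace("arith.mulf", "llvm.fmul")
                                  .replace("arith.addf", "llvm.fadd"))

ir = "linalg.matmul"
for lowering in (linalg_to_scf, scf_to_cf, cf_to_llvm):
    ir = lowering(ir)
    print(ir)   # the same program, one dialect lower each time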

Concrete walkthrough

A trivial MLIR program at three levels

A tiny matmul-add written first in linalg:

func.func @gemm_bias(%A: tensor<128x256xf32>, %B: tensor<256x64xf32>,
                     %bias: tensor<128x64xf32>) -> tensor<128x64xf32> {
  %c0 = arith.constant 0.0 : f32
  %init = tensor.empty() : tensor<128x64xf32>
  %zero = linalg.fill ins(%c0 : f32) outs(%init : tensor<128x64xf32>) -> tensor<128x64xf32>
  %mm = linalg.matmul ins(%A, %B : tensor<128x256xf32>, tensor<256x64xf32>)
                      outs(%zero : tensor<128x64xf32>) -> tensor<128x64xf32>
  %out = linalg.add ins(%mm, %bias : tensor<128x64xf32>, tensor<128x64xf32>)
                    outs(%init : tensor<128x64xf32>) -> tensor<128x64xf32>
  return %out : tensor<128x64xf32>
}

Notice: tensors are first-class. Operations like linalg.matmul understand the math. %init is “fresh empty memory of this shape” — no concrete allocation yet.

After lowering through convert-linalg-to-loops, the same function looks like nested loops in scf (structured control flow):

func.func @gemm_bias(%A: memref<128x256xf32>, %B: memref<256x64xf32>,
                     %bias: memref<128x64xf32>, %out: memref<128x64xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c64 = arith.constant 64 : index
  %c128 = arith.constant 128 : index
  %c256 = arith.constant 256 : index
  scf.for %i = %c0 to %c128 step %c1 {
    scf.for %j = %c0 to %c64 step %c1 {
      %acc_init = memref.load %bias[%i, %j] : memref<128x64xf32>
      %acc = scf.for %k = %c0 to %c256 step %c1 iter_args(%a = %acc_init) -> (f32) {
        %a_elem = memref.load %A[%i, %k] : memref<128x256xf32>
        %b_elem = memref.load %B[%k, %j] : memref<256x64xf32>
        %m = arith.mulf %a_elem, %b_elem : f32
        %sum = arith.addf %a, %m : f32
        scf.yield %sum : f32
      }
      memref.store %acc, %out[%i, %j] : memref<128x64xf32>
    }
  }
  return
}

Tensors became memrefs (concrete buffers). linalg.matmul exploded into three nested loops. We can now apply loop-level optimizations: tiling, vectorization, parallelization.

After more lowering (convert-scf-to-cf, convert-arith-to-llvm, convert-memref-to-llvm), the function becomes pure llvm dialect — basically LLVM IR wearing an MLIR jacket — and ready to hand off to the LLVM backend.

Same program, three different dialects. Each level kept around exactly the structure the next pass needed.

Dialects you’ll meet

Dialect   What it represents                                    Where it lives in the stack
tensor    Immutable tensors as values                           Top — frontend handoff
linalg    Structured linear-algebra ops over tensors / memrefs  Top
func      Functions, calls                                      All levels
arith     Scalar arithmetic                                     All levels
scf       Structured control flow (for, while, if)              Middle
memref    Buffers in memory with strided layout                 Middle
vector    SIMD-shaped operations                                Middle
affine    Polyhedral-friendly loops with affine bounds          Middle (loop optimization)
gpu       GPU kernel structure (grids, blocks, threads)         Middle-low
nvgpu     NVIDIA-specific (TMA, mma)                            Low
llvm      LLVM IR as an MLIR dialect                            Bottom

Custom hardware ships its own: AMD has rocdl, Apple has private dialects, every NPU vendor has theirs. The point of MLIR is that adding a new dialect is cheap, so each piece of hardware represents itself at the level of detail it needs.

What torch.compile actually emits

Inside Inductor, after autograd and dispatch, the IR is FX (PyTorch’s own graph IR). Inductor lowers FX to a Triton kernel. Triton’s compiler lowers to MLIR (TritonGPU dialect → TritonNvidiaGPU → llvm), then to LLVM IR, then to PTX, then to SASS. Five lowerings. Each can be inspected:

import torch
import torch._inductor.config as config

config.trace.enabled = True          # dump Inductor's debug artifacts to disk

@torch.compile
def f(x, y):
    return torch.relu(x @ y)

x = torch.randn(128, 256, device="cuda")   # Triton kernels are generated for GPU inputs
y = torch.randn(256, 64, device="cuda")
f(x, y)

# Generated artifacts are dropped under torch_compile_debug/
# Look for *.ll (LLVM IR), *.mlir, *.ptx, output_code.py.

This is the pipeline — the same kernel passing through half a dozen representations, with the MLIR dialects sitting in the middle of it.

The IR has regions

LLVM IR has functions and basic blocks. MLIR generalizes: an operation can carry regions of nested IR. A scf.for is one operation whose region holds the loop body. A func.func is one operation whose region is the function body. This compositionality is what makes MLIR so flexible — your gpu.launch op can have a region that’s the kernel body, your custom tpu.matmul_block can have a region for the schedule.

The cost: API churn. The benefit: you can express any control structure as just-another-op.
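A minimal sketch of that generalization in Python (the class and names below are invented for illustration, not MLIR's API): everything is an operation, and nesting is expressed through regions rather than through special-cased functions and basic blocks.

from dataclasses import dataclass, field

@dataclass
class Operation:
    name: str                                      # e.g. "scf.for", "func.func", "gpu.launch"
    regions: list = field(default_factory=list)    # each region is a list of nested Operations

# A loop is one op whose region holds the loop body; a function is one op whose
# region holds the function body; a GPU launch is one op whose region is the kernel.
body   = [Operation("arith.mulf"), Operation("arith.addf")]
loop   = Operation("scf.for", regions=[body])
kernel = Operation("gpu.launch", regions=[[loop]])
module = Operation("func.func", regions=[[kernel]])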

mlir-opt — the workhorse

Like opt for LLVM, but for MLIR. The big move:

# Run a single lowering pass
mlir-opt --convert-linalg-to-loops gemm.mlir

# Compose lowerings
mlir-opt --convert-linalg-to-loops --convert-scf-to-cf \
         --convert-arith-to-llvm --convert-memref-to-llvm gemm.mlir

# Or use a curated pipeline
mlir-opt --pass-pipeline='builtin.module(convert-linalg-to-loops, ...)' gemm.mlir

mlir-opt --help lists every pass. Lowerings follow the shape --convert-X-to-Y; transformations within a dialect carry the dialect's name instead (tiling, vectorization, and so on).

Toy dialect lowerer

A miniature multi-level IR with two dialects (matmul and loop) and a lowering between them.
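The interactive playground isn't reproduced here, so below is a minimal Python sketch of what it describes, under invented names (Op, lower_matmul_to_loops — none of this is MLIR's real API): a two-dialect toy IR and one conversion that rewrites the high-level op into nested low-level loops.

from dataclasses import dataclass, field

# Toy IR: an op has a name, string operands, and regions of nested ops (as in MLIR).
@dataclass
class Op:
    name: str
    operands: list = field(default_factory=list)
    regions: list = field(default_factory=list)

def dump(op, depth=0):
    print("  " * depth + op.name + " " + ", ".join(op.operands))
    for region in op.regions:
        for nested in region:
            dump(nested, depth + 1)

# High-level dialect: a single op that understands the math.
gemm = Op("matmul.matmul", ["%A", "%B", "%C"])

# The "conversion pass": match the high-level op and replace it with an
# equivalent tree of lower-level ops — the shape of every MLIR lowering.
def lower_matmul_to_loops(op):
    if op.name != "matmul.matmul":
        return op
    A, B, C = op.operands
    body = [Op("loop.load",   [A, "%i", "%k"]),
            Op("loop.load",   [B, "%k", "%j"]),
            Op("loop.load",   [C, "%i", "%j"]),
            Op("loop.muladd", ["%a", "%b", "%c"]),
            Op("loop.store",  ["%acc", C, "%i", "%j"])]
    k_loop = Op("loop.for", ["%k"], regions=[body])
    j_loop = Op("loop.for", ["%j"], regions=[[k_loop]])
    return Op("loop.for", ["%i"], regions=[[j_loop]])

dump(gemm)                           # one high-level op
dump(lower_matmul_to_loops(gemm))    # the same computation as three nested loops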

The shape — high-level op gets replaced by an equivalent tree of lower-level ops — is exactly what every MLIR conversion pass does, just with much richer types.

Quick check

Fill in the blank: in MLIR, the unit of vocabulary that bundles related operations and types together (e.g., linalg, tensor, gpu) is the ____. (Hint: the thing that makes MLIR multi-level.)
Quick check
Why does MLIR exist alongside LLVM IR rather than being absorbed into it?

Key takeaways

  1. MLIR = Multi-Level IR. One framework, many dialects. Every level is a dialect.
  2. Lowering is the core operation: rewrite IR from a higher dialect to a lower one. Repeat until LLVM.
  3. Every modern AI compiler runs on it: Triton, IREE, XLA, ExecuTorch, ONNX-MLIR, Modular MAX.
  4. Dialects you’ll meet: tensor and linalg (top), scf and memref (middle), gpu and nvgpu (low), llvm (bottom).
  5. The API churns. Pin your LLVM/MLIR version; tutorials more than 6 months old often don’t compile.

Go deeper