
JAX & Pallas

Prereqs: Triton, MLIR Overview, Operator Fusion. Pallas is “Triton, but for the JAX/XLA stack.”

PyTorch dominates research volume in 2026; JAX dominates TPU work and a meaningful slice of frontier training (Anthropic, Google, parts of OpenAI). The JAX/XLA stack is the other world of AI compilation — the same compiler concepts as PyTorch's Dynamo + Inductor, but with a different IR (StableHLO), a different graph compiler (XLA), a different kernel DSL (Pallas), and a TPU backend that PyTorch simply doesn't have.

The mental model: jax.jit traces a function to StableHLO; XLA compiles it to GPU/TPU/CPU; Pallas is the kernel escape hatch when XLA’s fusion isn’t enough — Triton-shaped, but the same source emits Triton on GPU and Mosaic on TPU. Knowing how this stack differs from Inductor is the price of admission for working in any JAX-using lab. The good news: most concepts (fusion, lowering, autotune) are the same words; only the spelling changes.

TL;DR

  • JAX uses a different compiler stack than PyTorch. jit-decorated functions trace to StableHLO (a stable subset of HLO, the XLA op set), then XLA compiles to GPU/TPU.
  • XLA is the long-standing graph compiler (Google, ~2017+). It does fusion, layout assignment, code-gen for GPU/TPU/CPU. It’s the production compiler that runs Gemini, every Google production model, and JAX-based research everywhere.
  • Pallas is JAX’s kernel DSL — Triton-like, but inside JAX. You write a JAX function, decorate it with pallas.pallas_call, and it lowers to a hand-written kernel for GPU (via Triton) or TPU (via Mosaic, a TPU-specific kernel emitter).
  • The single most distinctive thing about Pallas: the same kernel source can target both GPU and TPU. The dtype/layout primitives are the same; the backend chooses how to lower.
  • For 2026 production: JAX is the dominant choice for TPU work, neck-and-neck with PyTorch on GPU. Pallas is the kernel escape hatch — used heavily inside Google for Gemini training, increasingly used externally for high-perf JAX kernels.

Mental model

XLA is the umbrella; StableHLO is its IR; Pallas is the user-written-kernel escape hatch that hooks into the same lowering chain.

A JAX function and what XLA does to it

import jax
import jax.numpy as jnp

@jax.jit
def f(x, w, b):
    return jax.nn.relu(x @ w + b)

# First call: trace + compile (slow). Subsequent calls: cached, fast.
out = f(jnp.ones((128, 256)), jnp.ones((256, 64)), jnp.zeros(64))

What happens in order:

  1. JAX’s tracer runs f with abstract values, producing a jaxpr (JAX’s internal IR).
  2. The jaxpr lowers to StableHLO ops: dot, add, maximum (the relu).
  3. XLA receives the StableHLO module and runs ~100 passes: fusion, layout assignment, autotune, codegen.
  4. Output: a binary kernel cached for this (input_shapes, dtypes) signature.
  5. Subsequent calls of f with matching signatures: invoke the cached kernel directly.

You can dump every stage:

print(jax.make_jaxpr(f)(x, w, b))                      # jaxpr
print(jax.jit(f).lower(x, w, b).compile().as_text())   # compiled HLO

The HLO dump is the JAX equivalent of torch.compile’s output_code.py.
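
If you also want the stage in between (the StableHLO module that XLA receives, before any of its passes run), the same AOT handle exposes it. A small sketch, assuming a recent JAX where the lowered module prints as StableHLO MLIR:

lowered = jax.jit(f).lower(x, w, b)
print(lowered.as_text())             # StableHLO, before XLA's ~100 passes
print(lowered.compile().as_text())   # optimized HLO, after fusion / layout / codegen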

XLA fusion vs Inductor fusion

Both fuse pointwise ops. Differences:

                          | Inductor (PyTorch)               | XLA (JAX/TF)
Default kernel codegen    | Triton                           | LLVM (CPU/GPU) + Mosaic (TPU)
Aggression                | Greedy, often correct            | More conservative, but very strong on TPU
Custom kernel             | torch.library.custom_op + Triton | Pallas
Ahead-of-time compilation | Limited                          | First-class (XLA is AOT-by-default)
Cross-platform            | GPU-mostly                       | GPU + TPU + CPU equally

A subtle structural difference: XLA expects shapes at compile time, either concrete or supplied abstractly at trace time (e.g. via jax.ShapeDtypeStruct). PyTorch's dynamism is more flexible — a torch.compile'd function handles dynamic shapes gracefully via Dynamo's guards. XLA pays for its TPU performance with this rigidity.
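
To see the static-shape contract in action, here is a minimal sketch using jit and the same AOT calls as above (g is a made-up toy function, not from the walkthrough):

import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    return jnp.tanh(x) * 2.0

g(jnp.ones((128,)))   # first call at this shape: trace + compile
g(jnp.ones((128,)))   # same (shape, dtype) signature: cached, no recompilation
g(jnp.ones((256,)))   # new shape: new trace and a fresh XLA compilation

# The AOT path makes the one-executable-per-signature contract explicit.
exe_128 = g.lower(jnp.ones((128,))).compile()
exe_256 = g.lower(jnp.ones((256,))).compile()   # a separate executable
print(exe_128(jnp.ones((128,))))                # the compiled artifact is directly callable

This is the rigidity the table refers to: every new input signature is a full compile in XLA.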

Pallas — the kernel escape hatch

When XLA’s fusion doesn’t get you the performance you need, Pallas lets you write the kernel directly:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

BM, BN, BK = 128, 128, 64   # tile sizes for the matmul

def matmul_kernel(x_ref, w_ref, o_ref):
    """Inner kernel: each program computes one (BM, BN) output tile by
    looping over K in chunks of BK and accumulating."""
    K = x_ref.shape[1]
    acc = jnp.zeros((BM, BN), dtype=jnp.float32)
    for k in range(0, K, BK):
        x_tile = pl.load(x_ref, (slice(None), pl.ds(k, BK)))   # (BM, BK)
        w_tile = pl.load(w_ref, (pl.ds(k, BK), slice(None)))   # (BK, BN)
        acc += jnp.dot(x_tile, w_tile, preferred_element_type=jnp.float32)
    o_ref[:, :] = acc.astype(o_ref.dtype)

@jax.jit
def matmul(x, w):
    M, K = x.shape
    _, N = w.shape
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((M, N), x.dtype),
        grid=(M // BM, N // BN),                               # one program per (M, N) tile
        in_specs=[pl.BlockSpec((BM, K), lambda i, j: (i, 0)),  # full row strip of x
                  pl.BlockSpec((K, BN), lambda i, j: (0, j))], # full col strip of w
        out_specs=pl.BlockSpec((BM, BN), lambda i, j: (i, j)),
    )(x, w)

The structure is exactly Triton’s: a kernel function with an explicit K-loop accumulator, a grid spec, block specs that describe slicing. The Python is JAX-flavored (uses jax.ShapeDtypeStruct), but the model — outer 2D grid over (M, N) tiles, inner sequential K-loop, FP32 accumulator — matches the canonical Triton GEMM line-for-line.

The gift: this same kernel runs on GPU (via Triton) and TPU (via Mosaic). On TPU you get tile sizes that map to the systolic array; on GPU you get tensor-core mma. Same source.
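
Calling it is plain JAX. A quick sanity check against XLA's own matmul, as a sketch; the shapes below are illustrative (multiples of the tile sizes) and it assumes a GPU or TPU backend is attached:

x = jnp.ones((512, 256), dtype=jnp.float32)   # M=512, K=256: multiples of BM and BK
w = jnp.ones((256, 384), dtype=jnp.float32)   # K=256, N=384: multiples of BK and BN
out = matmul(x, w)
print(out.shape)                              # (512, 384)
print(jnp.allclose(out, x @ w, atol=1e-3))    # True: matches XLA's own dot

For CPU-only debugging, pallas_call also accepts interpret=True, which runs the kernel body with plain JAX ops instead of lowering it to Triton or Mosaic.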

The JAX → TPU story

Pallas TPU support is the reason JAX is the production language inside Google. TPU has a fundamentally different programming model from GPU — wide systolic arrays, explicit MXU (matrix unit) tiles, no thread-level parallelism in the GPU sense. Pallas-on-TPU compiles via Mosaic, an MLIR-based TPU kernel emitter. The Pallas program describes a tile-level computation; Mosaic emits TPU-native code that schedules the systolic array.

For non-TPU work, the choice between Triton and Pallas is mostly stylistic — both are tile-DSLs over Triton’s lowering chain on GPU. For TPU work, Pallas is the only ergonomic option outside writing raw XLA/HLO.

Picking JAX vs PyTorch

Decision tree, brutally:

  • TPU pod, frontier training, Google ecosystem → JAX.
  • GPU-only, fast iteration, large existing PyTorch codebase → PyTorch + torch.compile.
  • Research where you need the absolute lowest-overhead trace → JAX (the AOT-compile model has lower per-call overhead than Dynamo).
  • Production serving with tight latency → PyTorch / vLLM (better OSS serving stack maturity in 2026).

Both are excellent. The skill of moving between them is increasingly common; the underlying compiler concepts (fusion, lowering, kernel escape hatches) are the same.

Run it in your browser — XLA-style HLO simulator

A tiny HLO-shaped IR with fusion: lower a JAX-style program to HLO ops, then fuse the pointwise chain.

The shape — pointwise chains collapse into one fused op — is exactly what XLA’s Fusion pass produces in real HLO dumps.
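
For reference, here is a minimal sketch of the same idea in plain Python: a toy instruction list standing in for HLO, and a greedy pass that absorbs each pointwise op into a pointwise producer. The Instr class and op names are illustrative, not real XLA data structures:

from dataclasses import dataclass

POINTWISE = {"add", "maximum", "multiply", "subtract"}

@dataclass
class Instr:
    name: str
    op: str
    inputs: list               # names of producer instructions or parameters

def fuse_pointwise(program):
    """Greedily absorb each pointwise instr into a pointwise (or already
    fused) producer, so pointwise chains collapse into one fusion op."""
    by_name = {i.name: i for i in program}
    out = []
    for instr in program:
        if instr.op in POINTWISE:
            producers = [by_name[n] for n in instr.inputs if n in by_name]
            fusible = [p for p in producers
                       if p.op in POINTWISE or p.op.startswith("fusion")]
            if fusible:
                parent = fusible[0]
                parent.op = f"fusion({parent.op}+{instr.op})"
                by_name[instr.name] = parent    # later consumers resolve to the parent
                continue
        out.append(instr)
    return out

# relu(x @ w + b): the dot stays standalone, add + maximum collapse together.
program = [
    Instr("t0", "dot",     ["x", "w"]),
    Instr("t1", "add",     ["t0", "b"]),
    Instr("t2", "maximum", ["t1", "zero"]),
]
for instr in fuse_pointwise(program):
    print(instr.name, instr.op, instr.inputs)
# t0 dot ['x', 'w']
# t1 fusion(add+maximum) ['t0', 'b']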

Quick check

Fill in the blank
The JAX kernel-DSL that compiles to Triton on GPU and Mosaic on TPU:
Greek goddess of wisdom and warfare; appropriate for kernel work.
Quick check
An engineer wants to write a custom flash-attention variant that runs identically on both H100 and TPU v5p. The right tool:

Key takeaways

  1. JAX → StableHLO → XLA → GPU/TPU/CPU. The mirror image of PyTorch’s Dynamo → Inductor → Triton.
  2. XLA fusion is conservative but powerful, especially on TPU. Inductor is greedier; the trade is GPU-mostly vs cross-platform.
  3. Pallas is JAX’s kernel escape hatch — Triton-shaped, with a TPU backend (Mosaic). Same source, two targets.
  4. JAX wins for TPU and frontier training inside Google ecosystem; PyTorch wins for fast iteration on GPU. Both are real careers in 2026.
  5. The compiler concepts transfer. Fusion, lowering, autotune are the same words in both stacks; only the spelling changes.

Go deeper
