JAX & Pallas
Prereqs: Triton, MLIR Overview, Operator Fusion. Pallas is “Triton, but for the JAX/XLA stack.”
PyTorch dominates research volume in 2026; JAX dominates TPU work and a meaningful slice of frontier training (Anthropic, Google, parts of OpenAI). The JAX/XLA stack is the other world of AI compilation — the same compiler concepts as PyTorch’s Dynamo + Inductor, but with a different IR (StableHLO), a different graph compiler (XLA), a different kernel DSL (Pallas), and a TPU backend that PyTorch simply doesn’t have.
The mental model: jax.jit traces a function to StableHLO; XLA compiles it to GPU/TPU/CPU; Pallas is the kernel escape hatch when XLA’s fusion isn’t enough — Triton-shaped, but the same source emits Triton on GPU and Mosaic on TPU. Knowing how this stack differs from Inductor is the price of admission for working in any JAX-using lab. The good news: most concepts (fusion, lowering, autotune) are the same words; only the spelling changes.
TL;DR
- JAX uses a different compiler stack than PyTorch. `jax.jit`-decorated functions trace to StableHLO (a stable subset of HLO, the XLA op set), then XLA compiles to GPU/TPU.
- XLA is the long-standing graph compiler (Google, ~2017+). It does fusion, layout assignment, and codegen for GPU/TPU/CPU. It’s the production compiler that runs Gemini, every Google production model, and JAX-based research everywhere.
- Pallas is JAX’s kernel DSL — Triton-like, but inside JAX. You write a kernel function, wrap it with `pl.pallas_call`, and it lowers to a hand-written kernel for GPU (via Triton) or TPU (via Mosaic, a TPU-specific kernel emitter).
- The single most distinctive thing about Pallas: the same kernel source can target both GPU and TPU. The dtype/layout primitives are the same; the backend chooses how to lower.
- For 2026 production: JAX is the dominant choice for TPU work, neck-and-neck with PyTorch on GPU. Pallas is the kernel escape hatch — used heavily inside Google for Gemini training, increasingly used externally for high-perf JAX kernels.
Mental model
XLA is the umbrella; StableHLO is its IR; Pallas is the user-written-kernel escape hatch that hooks into the same lowering chain.
A JAX function and what XLA does to it
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x, w, b):
    return jax.nn.relu(x @ w + b)

# First call: trace + compile (slow). Subsequent calls: cached, fast.
out = f(jnp.ones((128, 256)), jnp.ones((256, 64)), jnp.zeros(64))
```
What happens in order:
- JAX’s tracer runs `f` with abstract values, producing a jaxpr (JAX’s internal IR).
- The jaxpr lowers to StableHLO — `dot`, `add`, `maximum` (the relu).
- XLA receives the StableHLO module and runs ~100 passes: fusion, layout assignment, autotuning, codegen.
- Output: a binary kernel cached for this `(input_shapes, dtypes)` signature.
- Subsequent calls of `f` with a matching signature invoke the cached kernel directly (timed in the sketch below).
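To see the cache behavior concretely — a minimal sketch (timings are illustrative and hardware-dependent; `block_until_ready()` forces JAX’s async dispatch to finish before the clock stops):
```python
import time

x, w, b = jnp.ones((128, 256)), jnp.ones((256, 64)), jnp.zeros(64)

t0 = time.perf_counter()
f(x, w, b).block_until_ready()   # trace + compile + run
t1 = time.perf_counter()
f(x, w, b).block_until_ready()   # signature matches: cached kernel, run only
t2 = time.perf_counter()
print(f"first call: {t1 - t0:.3f}s   cached call: {t2 - t1:.6f}s")
```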
You can dump every stage:
```python
print(jax.make_jaxpr(f)(x, w, b))                     # jaxpr
print(jax.jit(f).lower(x, w, b).compile().as_text())  # compiled HLO
```
The HLO dump is the JAX equivalent of torch.compile’s output_code.py.
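There is also an intermediate stop between the jaxpr and the compiled HLO: the StableHLO module as lowered, before XLA’s passes touch it. A small sketch using the same `f` (`.lower(...).as_text()` is part of JAX’s AOT API):
```python
lowered = jax.jit(f).lower(x, w, b)
print(lowered.as_text())  # StableHLO (MLIR text) as handed to XLA, pre-fusion
```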
XLA fusion vs Inductor fusion
Both fuse pointwise ops. Differences:
| | Inductor (PyTorch) | XLA (JAX/TF) |
|---|---|---|
| Default kernel codegen | Triton | LLVM (CPU/GPU) + Mosaic (TPU) |
| Aggression | Greedy, often correct | More conservative, but very strong on TPU |
| Custom kernel | torch.library.custom_op + Triton | Pallas |
| Ahead-of-time compilation | Limited | First-class (XLA is AOT-by-default) |
| Cross-platform | GPU-mostly | GPU + TPU + CPU equally |
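One way to see XLA’s fusion choices for yourself: fused computations show up as fusion ops in the compiled HLO text, so a grep over the dump from earlier reveals what got merged. A small sketch (assumes the `f` and inputs defined above; the exact op names in the dump vary by backend):
```python
hlo_text = jax.jit(f).lower(x, w, b).compile().as_text()
for line in hlo_text.splitlines():
    if "fusion" in line:          # fused computations appear as fusion ops
        print(line.strip())
```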
A subtle structural difference: XLA expects shapes at compile time (concrete arrays, or abstract shapes such as jax.ShapeDtypeStruct in AOT lowering). PyTorch’s dynamism is more flexible — a torch.compile’d function gracefully handles dynamic shapes via Dynamo’s guards. XLA pays for its performance on TPU with this rigidity: a new shape means a full retrace and recompile.
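A minimal sketch of that rigidity, plus the AOT path (`jax.ShapeDtypeStruct` stands in for a concrete array at compile time):
```python
@jax.jit
def g(x):
    return jnp.tanh(x) * 2.0

g(jnp.ones((128,)))  # trace + compile for f32[128]
g(jnp.ones((128,)))  # same signature: cache hit
g(jnp.ones((256,)))  # new shape: full retrace + recompile

# AOT: compile against an abstract shape, no concrete data required
abstract = jax.ShapeDtypeStruct((128,), jnp.float32)
compiled = jax.jit(g).lower(abstract).compile()
```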
Pallas — the kernel escape hatch
When XLA’s fusion doesn’t get you the performance you need, Pallas lets you write the kernel directly:
```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

BM, BN, BK = 128, 128, 64  # tile sizes for the matmul

def matmul_kernel(x_ref, w_ref, o_ref):
    """Inner kernel: each program computes one (BM, BN) output tile by
    looping over K in chunks of BK and accumulating."""
    K = x_ref.shape[1]
    acc = jnp.zeros((BM, BN), dtype=jnp.float32)
    for k in range(0, K, BK):
        x_tile = pl.load(x_ref, (slice(None), pl.ds(k, BK)))  # (BM, BK)
        w_tile = pl.load(w_ref, (pl.ds(k, BK), slice(None)))  # (BK, BN)
        acc += jnp.dot(x_tile, w_tile, preferred_element_type=jnp.float32)
    o_ref[:, :] = acc.astype(o_ref.dtype)

@jax.jit
def matmul(x, w):
    M, K = x.shape
    _, N = w.shape
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((M, N), x.dtype),
        grid=(M // BM, N // BN),  # one program per (M, N) tile
        in_specs=[pl.BlockSpec((BM, K), lambda i, j: (i, 0)),   # full row strip of x
                  pl.BlockSpec((K, BN), lambda i, j: (0, j))],  # full col strip of w
        out_specs=pl.BlockSpec((BM, BN), lambda i, j: (i, j)),
    )(x, w)
```
The structure is exactly Triton’s: a kernel function with an explicit K-loop accumulator, a grid spec, and block specs that describe slicing. The Python is JAX-flavored (jax.ShapeDtypeStruct), but the model — outer 2D grid over (M, N) tiles, inner sequential K-loop, FP32 accumulator — matches the canonical Triton GEMM line-for-line.
The gift: this same kernel runs on GPU (via Triton) and TPU (via Mosaic). On TPU you get tile sizes that map to the systolic array; on GPU you get tensor-core mma. Same source.
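A usage sketch, under the assumption baked into the grid spec that M, N divide evenly by the tile sizes — with a correctness check against XLA’s stock matmul:
```python
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (512, 1024), dtype=jnp.float32)
w = jax.random.normal(k2, (1024, 256), dtype=jnp.float32)

out = matmul(x, w)                        # the Pallas kernel
ref = x @ w                               # XLA's built-in dot
print(jnp.allclose(out, ref, atol=1e-2))  # True, modulo accumulation order
```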
The JAX → TPU story
Pallas TPU support is the reason JAX is the production language inside Google. TPU has a fundamentally different programming model from GPU — wide systolic arrays, explicit MXU (matrix unit) tiles, no thread-level parallelism in the GPU sense. Pallas-on-TPU compiles via Mosaic, an MLIR-based TPU kernel emitter. The Pallas program describes a tile-level computation; Mosaic emits TPU-native code that schedules the systolic array.
For non-TPU work, the choice between Triton and Pallas is mostly stylistic — both are tile-DSLs over Triton’s lowering chain on GPU. For TPU work, Pallas is the only ergonomic option outside writing raw XLA/HLO.
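A practical debugging note: pl.pallas_call accepts interpret=True, which runs the kernel in a slow emulation mode on any backend — useful for checking kernel logic on a CPU-only box before going near Triton or Mosaic. A sketch reusing matmul_kernel from above (matmul_debug is a hypothetical name for illustration):
```python
def matmul_debug(x, w):
    M, K = x.shape
    _, N = w.shape
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((M, N), x.dtype),
        grid=(M // BM, N // BN),
        in_specs=[pl.BlockSpec((BM, K), lambda i, j: (i, 0)),
                  pl.BlockSpec((K, BN), lambda i, j: (0, j))],
        out_specs=pl.BlockSpec((BM, BN), lambda i, j: (i, j)),
        interpret=True,  # emulate tile-by-tile; no GPU/TPU required
    )(x, w)
```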
Picking JAX vs PyTorch
Decision tree, brutally:
- TPU pod, frontier training, Google ecosystem → JAX.
- GPU-only, fast iteration, large existing PyTorch codebase → PyTorch + torch.compile.
- Research where you need the absolute lowest-overhead trace → JAX (the AOT-compile model has lower per-call overhead than Dynamo).
- Production serving with tight latency → PyTorch / vLLM (better OSS serving stack maturity in 2026).
Both are excellent. The skill of moving between them is increasingly common; the underlying compiler concepts (fusion, lowering, kernel escape hatches) are the same.
Key takeaways
- JAX → StableHLO → XLA → GPU/TPU/CPU. The mirror image of PyTorch’s Dynamo → Inductor → Triton.
- XLA fusion is conservative but powerful, especially on TPU. Inductor is greedier; the trade is GPU-mostly vs cross-platform.
- Pallas is JAX’s kernel escape hatch — Triton-shaped, with a TPU backend (Mosaic). Same source, two targets.
- JAX wins for TPU and frontier training inside Google ecosystem; PyTorch wins for fast iteration on GPU. Both are real careers in 2026.
- The compiler concepts transfer. Fusion, lowering, autotune are the same words in both stacks; only the spelling changes.
Go deeper
- Docs: JAX — Pallas Documentation. Up-to-date. The "Pallas Quickstart" section plus the matmul tutorial are the right starting point.
- Docs: StableHLO — OpenXLA. The IR JAX lowers to. The "StableHLO Specification" section has the op-by-op semantics.
- Docs: XLA Documentation. The compiler. The "Architecture" section is the JAX-side analog of Inductor's fusion docs.
- Paper: Pallas: A Programming Model for AI Accelerators. The Pallas paper. Section 3 explains the GPU + TPU lowering split.
- Blog: JAX Blog — Pallas: Bringing Triton-like Power to JAX/TPU. Practitioner intro with worked examples.
- Repo: google/jax. The reference. `jax/_src/pallas/` for the Pallas implementation; `jax/experimental/mosaic/` for the TPU codegen.
- Repo: openxla/xla. The XLA compiler. `xla/service/gpu/` for the GPU pipeline; `xla/service/cpu/` for CPU.