
Tensor Parallel

Prereqs: Data Parallel & DDP, Multi-Head Attention. TP shards the math inside a layer, not the batch.

DDP shards the batch — every GPU runs the same model on different data. FSDP shards the parameters — the model lives across ranks but each rank reconstructs full layers on the fly. TP shards the math inside each individual layer: the matmul itself runs across GPUs, with each GPU computing a slice of every weight matrix.

Megatron-LM (Shoeybi et al., 2019) is the canonical recipe. The trick is to compose two complementary slice patterns. The first matmul of a transformer block (the QKV projection in attention, the up-projection in MLP) is column-parallel — split along the output dim, each GPU gets a slice of the output, no communication needed. The second matmul (the output projection in attention, the down-projection in MLP) is row-parallel — split along the input dim, each GPU produces a partial sum of the output, and an AllReduce sums them. Two AllReduces per transformer layer, one per major sub-block.

In PyTorch this is parallelize_module(layer, mesh, plan): the tensor-parallel API lets you spell out the sharding scheme, and the framework inserts the right collective at the right place. What’s actually running underneath is NCCL ring AllReduce on the activations, twice per layer, every forward pass — and the bandwidth needed is so high that TP only works within an NVLink domain. Span TP across nodes and it collapses on InfiniBand. The universal recipe in 2026 is TP=8 within a node, FSDP across data-parallel ranks, and pipeline parallelism for cross-node depth.

TL;DR

  • Tensor parallel (TP) splits each individual matmul (and embedding, attention) across multiple GPUs. Each GPU holds a slice of every weight matrix; activations get AllReduced after the sliced matmuls.
  • Megatron-LM introduced the canonical TP design (Shoeybi et al., 2019): column-parallel for the first matmul of an MLP, row-parallel for the second, two AllReduces per transformer layer (one in attention, one in MLP).
  • TP shards weights and compute → memory and FLOPs scale roughly linearly with TP degree. The cost is communication: AllReduce on the activations every layer, twice.
  • TP works best within a single NVLink domain (one DGX node, ~8 GPUs). Across nodes, comm bandwidth tanks and TP scaling collapses. Production: TP=8 within node, DP / PP across nodes.
  • The 2026 default for training a 70B+ model: TP=8 + FSDP (DP / sharded) + PP=4. This is what TorchTitan, Megatron-Core, NeMo all assume.

Mental model

The trick: a column-parallel first matmul feeding a row-parallel second matmul means each GPU computes a partial sum of the final output, and a single AllReduce finishes it.

Concrete walkthrough

Column-parallel matmul

A linear layer y = x @ W where W is (in, out). Column-parallel: split W along the output dimension across G GPUs:

  • GPU g holds W[:, g*out/G : (g+1)*out/G], a slice of shape (in, out/G).
  • Each GPU computes y_g = x @ W[:, slice], output shape (B, T, out/G).
  • The concatenation of all y_g along the last dim equals the full output.

No comm needed for the forward output — each GPU has its slice. Comm comes in the backward, where the input gradient dx = dy @ W.T requires accumulating across GPUs.
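This is easy to check numerically. The sketch below simulates the G "GPUs" on one process with NumPy (sizes are made up for illustration): each rank's column slice produces an output shard, and concatenating the shards recovers the full matmul with no communication.

```python
# Column-parallel matmul, simulated on one process with NumPy.
# G "GPUs" are just G column slices of W; sizes are illustrative.
import numpy as np

rng = np.random.default_rng(0)
G, B, T, d_in, d_out = 4, 2, 3, 8, 16

x = rng.standard_normal((B, T, d_in))
W = rng.standard_normal((d_in, d_out))

# Each "GPU" g holds a (d_in, d_out/G) column slice and computes its shard.
shards = [x @ W[:, g * d_out // G : (g + 1) * d_out // G] for g in range(G)]

# Concatenating the shards along the last dim recovers the full output.
y = np.concatenate(shards, axis=-1)
assert np.allclose(y, x @ W)
```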

Row-parallel matmul

The complement: split W along the input dimension:

  • GPU g holds W[g*in/G : (g+1)*in/G, :], shape (in/G, out).
  • Input x must already be sharded along its last dim (its in/G slice on each GPU).
  • Each GPU computes y_g = x_g @ W[slice, :], output (B, T, out) — but it’s a partial sum, not the final result.
  • AllReduce sums the partial outputs across GPUs → final y.

Row-parallel has the AllReduce in forward; column-parallel has it in backward.
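A matching sketch for the row-parallel case, again simulating the G ranks with NumPy slices (illustrative sizes): each rank's output has the full shape but only partial values, and summing the partials, which is what the AllReduce does, recovers the exact result.

```python
# Row-parallel matmul, simulated with NumPy. Each "GPU" holds an (in/G, out)
# row slice of W and sees only its slice of the already-sharded input.
import numpy as np

rng = np.random.default_rng(1)
G, B, T, d_in, d_out = 4, 2, 3, 16, 8

x = rng.standard_normal((B, T, d_in))
W = rng.standard_normal((d_in, d_out))

partials = []
for g in range(G):
    sl = slice(g * d_in // G, (g + 1) * d_in // G)
    partials.append(x[..., sl] @ W[sl, :])  # full (B, T, d_out) shape, partial values

# The AllReduce is a sum of the partial outputs across ranks.
y = sum(partials)
assert np.allclose(y, x @ W)
```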

The Megatron MLP block

The genius of Megatron’s design is to compose the two:

MLP(x) = gelu(x @ W_up) @ W_down
  • W_up: (d, d_ffn). Column-parallel — each GPU holds (d, d_ffn/G). Output: (B, T, d_ffn/G) per GPU. No comm.
  • gelu: pointwise, no comm.
  • W_down: (d_ffn, d). Row-parallel — each GPU holds (d_ffn/G, d). Input is already sharded (output of column-parallel). Output: full (B, T, d), partial sum. AllReduce to finish.

One AllReduce per MLP block in forward. No intermediate AllReduces. The (B, T, d_ffn/G) intermediate stays sharded across the gelu without needing communication.
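The whole block can be verified the same way. The sketch below (NumPy, illustrative sizes, and a hypothetical tanh-approximation gelu helper) runs the column-parallel up-projection, applies gelu to the sharded intermediate with no communication, then does the row-parallel down-projection and a single summing "AllReduce":

```python
# Megatron MLP sharding simulated with NumPy: W_up column-parallel,
# W_down row-parallel, gelu applied locally to the sharded intermediate.
import numpy as np

def gelu(z):  # tanh approximation of GELU
    return 0.5 * z * (1 + np.tanh(np.sqrt(2 / np.pi) * (z + 0.044715 * z**3)))

rng = np.random.default_rng(2)
G, B, T, d, d_ffn = 4, 2, 3, 8, 32

x = rng.standard_normal((B, T, d))
W_up = rng.standard_normal((d, d_ffn))
W_down = rng.standard_normal((d_ffn, d))

partials = []
for g in range(G):
    sl = slice(g * d_ffn // G, (g + 1) * d_ffn // G)
    h_g = gelu(x @ W_up[:, sl])           # column-parallel up-proj + local gelu
    partials.append(h_g @ W_down[sl, :])  # row-parallel down-proj → partial sum

y = sum(partials)                         # the single per-block "AllReduce"
assert np.allclose(y, gelu(x @ W_up) @ W_down)
```

This works because gelu is pointwise: it commutes with the sharding of the intermediate, so nothing needs to be gathered between the two matmuls.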

The Megatron attention block

Same composition for attention:

  • W_QKV: (d, 3d). Column-parallel — each GPU holds (d, 3d/G), gets its own Q,K,V slice. Naturally maps to n_heads / G heads per GPU.
  • Per-head attention: each GPU runs flash-attention on its local heads. No comm.
  • W_O: (d, d). Row-parallel — input is the per-GPU attention output, sharded across heads. Output is full (B, T, d), partial sum. AllReduce.

One AllReduce per attention block in forward. Combined with MLP: two AllReduces per layer (one per major sub-block).
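The attention composition checks out the same way. This NumPy sketch (illustrative sizes; softmax and heads are hypothetical helpers, and it uses plain attention rather than flash-attention) gives each of G simulated ranks n_heads/G heads plus the matching row slice of W_O, and sums the partial outputs:

```python
# Head-sharded attention simulated with NumPy: column-parallel QKV maps to
# local heads per "GPU"; row-parallel W_O turns local outputs into partial
# sums; summing the partials equals full multi-head attention.
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(3)
G, T, n_heads, d_head = 2, 5, 4, 8
d = n_heads * d_head

x = rng.standard_normal((T, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
Wo = rng.standard_normal((d, d))

def heads(W, h0, h1):  # column slice of W covering heads h0..h1
    return W[:, h0 * d_head : h1 * d_head]

hpg = n_heads // G  # heads per "GPU"
partials = []
for g in range(G):
    h0, h1 = g * hpg, (g + 1) * hpg
    q, k, v = x @ heads(Wq, h0, h1), x @ heads(Wk, h0, h1), x @ heads(Wv, h0, h1)
    out = np.concatenate([
        softmax(q[:, i*d_head:(i+1)*d_head] @ k[:, i*d_head:(i+1)*d_head].T
                / np.sqrt(d_head)) @ v[:, i*d_head:(i+1)*d_head]
        for i in range(hpg)], axis=-1)                   # local heads, no comm
    partials.append(out @ Wo[h0 * d_head : h1 * d_head, :])  # row-parallel W_O

y = sum(partials)  # the single AllReduce
```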

For a 70B model with 80 layers, that’s 160 AllReduces per forward step. Each AllReduce moves B × T × d × bytes per GPU. At B=4, T=8K, d=8192, BF16: each AllReduce is 4 × 8192 × 8192 × 2 bytes = 512 MB / GPU. NVLink Gen 4 at 100 GB/s does that in ~5 ms; across those 160 AllReduces that’s ~800 ms of comm per step.
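As a sanity check on those numbers (the 100 GB/s effective bandwidth is the text's assumption; real NVLink generations differ):

```python
# Back-of-envelope AllReduce cost, as plain arithmetic.
B, T, d, bytes_per = 4, 8192, 8192, 2   # BF16
layers, ar_per_layer = 80, 2

msg = B * T * d * bytes_per             # bytes per AllReduce per GPU
nvlink = 100e9                          # assumed 100 GB/s effective
per_ar_ms = msg / nvlink * 1e3
total_ms = per_ar_ms * layers * ar_per_layer

print(msg / 2**20, per_ar_ms, total_ms)  # 512 MiB, ~5.4 ms, ~860 ms before rounding
```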

TP cost summary

For TP degree G on a single node:

Component            | Memory (per GPU)                                          | Comm
---------------------|-----------------------------------------------------------|--------------------------------
Model weights        | params / G × bytes                                        | none
Activations          | (B × T × d) / G (sharded for some, replicated for others) | depends
Attention forward    | B × T × T × n_heads / G (with FA: O(N) effective)         | one AllReduce in W_O
MLP forward          | B × T × d_ffn / G                                         | one AllReduce in W_down
Per-layer total comm | -                                                         | 2 AllReduces of B × T × d bytes

Per layer, per GPU, the ring-AllReduce traffic is 2 × 2(G-1)/G × B × T × d × bytes. For typical configs this is hundreds of MB per layer; with 80 layers, the network is busy essentially all the time.
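Plugging a smaller, single-node-style config into that formula (all sizes illustrative):

```python
# Per-layer, per-GPU TP traffic for an assumed config
# (B=1, T=4096, d=4096, BF16, G=8). Two ring AllReduces per layer,
# each moving 2(G-1)/G × message bytes per GPU.
B, T, d, bytes_per, G = 1, 4096, 4096, 2, 8

msg = B * T * d * bytes_per              # one AllReduce's payload
per_layer = 2 * (2 * (G - 1) / G) * msg  # both AllReduces, ring factor included
print(msg / 2**20, per_layer / 2**20)    # 32.0 MiB payload, 112.0 MiB per layer
```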

Cross-node bandwidth is ~25–100 GB/s effective on InfiniBand (vs 100–600 GB/s NVLink). The same 512 MB AllReduce that takes 5 ms on NVLink takes 30+ ms across IB. Multiplied by 160 AllReduces per step, the comm dominates and your “8-way TP across two nodes” runs slower than 8-way TP within one node. Hence the rule: TP=8 within node, scale across nodes via PP / DP.

Composing with FSDP / PP

Production setups stack:

Mesh: 1024 GPUs

  • TP dim: 8 (within node, NVLink)
  • PP dim: 4 (across-node pipeline)
  • DP dim: 32 (replication, FSDP-sharded weights)

8 × 4 × 32 = 1024 ✓

PyTorch’s DeviceMesh API (and TorchTitan, Megatron-Core, NeMo) lets you spell this out:

from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (32, 4, 8), mesh_dim_names=("dp", "pp", "tp"))
# Now you have a logical 3D grid; ops know which axes to communicate along.

The challenge of multi-dimensional parallelism is keeping the comm patterns separate — TP AllReduces happen on the TP axis only, FSDP AllGathers on the DP axis only. DTensor (PyTorch) and the equivalent in Megatron-Core handle this for you.
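To see how flat ranks land on that (dp, pp, tp) grid, here is the arithmetic for a row-major rank layout (an assumption for illustration; check your mesh's actual rank ordering):

```python
# Mapping a flat rank onto a (dp=32, pp=4, tp=8) grid, last dim fastest.
DP, PP, TP = 32, 4, 8

def coords(rank):
    tp = rank % TP
    pp = (rank // TP) % PP
    dp = rank // (TP * PP)
    return dp, pp, tp

# Ranks 0..7 share dp=0, pp=0: one TP group, i.e. one NVLink node.
assert [coords(r)[2] for r in range(8)] == list(range(8))
assert coords(1023) == (31, 3, 7)
```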

Run it in your browser — TP overhead simulator

The simulator computes per-layer comm time and total step time for various TP degrees and bandwidth scenarios.
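In place of the interactive widget, a minimal stand-in looks like this (bandwidths and sizes are illustrative; it includes the ring-AllReduce traffic factor 2(G-1)/G):

```python
# Per-step TP comm time for a given config, TP degree, and bandwidth.
def step_comm_ms(B, T, d, layers, G, bw_gbps, bytes_per=2):
    msg = B * T * d * bytes_per            # one AllReduce's payload per GPU
    ring = 2 * (G - 1) / G * msg           # ring AllReduce traffic per GPU
    per_ar_ms = ring / (bw_gbps * 1e9) * 1e3
    return per_ar_ms * 2 * layers          # two AllReduces per layer

for name, bw in [("NVLink", 100), ("InfiniBand", 25)]:
    t = step_comm_ms(B=4, T=8192, d=8192, layers=80, G=8, bw_gbps=bw)
    print(f"TP=8 over {name:10s}: {t:8.0f} ms of comm per step")
```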

You’ll see TP=8 on NVLink runs ~5× faster than no TP; the same TP=8 across IB runs slower than no TP because comm dominates. This gap is why every frontier training stack pins TP inside the node.

Quick check

Fill in the blank
The number of AllReduce communications per transformer layer in a Megatron-style TP forward pass: two, one per major sub-block (attention and MLP).
Quick check
A team trains a 70B model on 16 GPUs split across 2 nodes connected by 200 Gbps InfiniBand. They configure TP=16 (across both nodes). Throughput is way below expectations. Best fix: set TP=8 within each node and go data-parallel (or pipeline-parallel) across the two nodes, so the TP AllReduces stay on NVLink.

Key takeaways

  1. TP shards each layer’s matmuls. Column-parallel input matmul, row-parallel output matmul, two AllReduces per layer.
  2. Megatron’s MLP design: W_up column-parallel, W_down row-parallel; one AllReduce per MLP. Same pattern for attention.
  3. TP within NVLink (TP=8 per node) is the universal recipe. Cross-node TP collapses on comm.
  4. TP, PP, DP/FSDP compose as orthogonal axes of a 3D / 4D / 5D mesh. PyTorch DeviceMesh is how you express it.
  5. Megatron-Core, TorchTitan, NeMo are the production references. Read their config files to see real-world TP × PP × DP shapes.

Why this matters

When the model doesn’t fit on a single GPU, TP is one of the two answers (the other is FSDP). For training, TP is preferred over FSDP for the largest tensors (attention QKV, FFN up/down) because the comm pattern is fixed-size and predictable. Knowing how TP shards each op, what the AllReduces cost, and when to combine TP with PP / FSDP is the foundation of any frontier training system. No engineer who can’t draw the Megatron diagram can talk credibly about training-systems architecture.

Go deeper