Build a Tensor Library
PyTorch and JAX feel like magic until you build one yourself. The mystery dissolves quickly: a working tensor library with autograd, five ops, and enough functionality to train MNIST is genuinely about 200 lines of Python. The load-bearing concepts that make these frameworks possible — strided tensors, the forward/backward pair, the topological reverse-walk we call “autograd” — are all present at this scale.
This lesson is the do-it-yourself version of the prior two. We’ll build four pieces: a buffer (the bytes), a tensor (buffer + walking instructions), an op (forward + backward), and an engine (the graph + the reverse traversal). You’ll write a complete autograd in 60 lines in your browser. After you’ve done it, loss.backward() stops being magic — you’ll picture exactly which buffers it allocates, which graph it walks, and what requires_grad means at the implementation level.
The point isn’t to ship your own framework. It’s that once you’ve written the toy version, every later debugging session — including reading PyTorch internals or the Compilers track — becomes tractable. The abstraction is small; the leverage is large.
TL;DR
- A tensor library is smaller than people think: ~200 lines of Python gets you strided views, broadcasting, and a handful of ops good enough to train MNIST.
- The four types you need: Buffer (raw memory), Tensor (buffer + shape + strides + offset), Op (forward + backward), Engine (the autograd graph + topological execution).
- Once you’ve written one yourself, every bit of “magic” in PyTorch / JAX becomes legible: shape errors, stride bugs, retain_graph weirdness, why .detach() exists.
- Real production libraries (PyTorch, JAX, tinygrad) add dispatch to backends (CUDA, MPS, ROCm), op fusion, JIT compilation, and distributed semantics. All of this is bolted onto the same core abstraction.
- The lesson companion is the module capstone — a 200-line tinygrad clone running MNIST.
Mental model
Four boxes. Build them in order, MNIST runs.
Step 1 — the buffer
A buffer is a 1D array. Use a Python list for the toy version and a numpy.ndarray for the not-toy version (we use a numpy backing array for speed, but the abstraction is the same).
import numpy as np
class Buffer:
    """1D contiguous storage. The actual bytes."""
    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.float32).ravel()
    def __len__(self): return len(self.data)
    def __getitem__(self, i): return self.data[i]
    def __setitem__(self, i, v): self.data[i] = v

That’s it for memory. The interesting structure is on top.
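A quick usage sketch of the class above, to make the flattening concrete (the printed values are noted in comments):

buf = Buffer([[1, 2, 3], [4, 5, 6]])   # nested input is flattened row-major into float32
print(len(buf))    # 6
print(buf[4])      # 5.0, one flat array regardless of the input's nesting
buf[0] = 99.0      # writes go straight into the underlying storage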
Step 2 — the strided tensor
Same as the Strides & Layout lesson, with a few utility methods:
class Tensor:
    def __init__(self, buf, shape, strides=None, offset=0, requires_grad=False):
        self.buf = buf if isinstance(buf, Buffer) else Buffer(buf)
        self.shape = tuple(shape)
        if strides is None:
            # Default to contiguous row-major
            s = []
            cur = 1
            for d in reversed(self.shape):
                s.append(cur)
                cur *= d
            strides = tuple(reversed(s))
        self.strides = tuple(strides)
        self.offset = offset
        self.requires_grad = requires_grad
        self.grad = None
        self._ctx = None  # autograd: which Op produced me + inputs

    def numel(self):
        n = 1
        for d in self.shape: n *= d
        return n

    def __getitem__(self, idx):
        if isinstance(idx, int): idx = (idx,)
        i = self.offset + sum(x * s for x, s in zip(idx, self.strides))
        return self.buf[i]

    # View-producing ops (cheap)
    def transpose(self, a, b):
        s = list(self.shape); st = list(self.strides)
        s[a], s[b] = s[b], s[a]; st[a], st[b] = st[b], st[a]
        return Tensor(self.buf, s, st, self.offset, self.requires_grad)

Operations that change strides without copying are O(1). All the cheap manipulations from the previous lesson get one method here.
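Here is a small sketch of the view machinery in use: a transpose is just a stride swap over the same Buffer, so a write through the view shows up in the original.

t = Tensor([1, 2, 3, 4, 5, 6], (2, 3))
print(t.strides)              # (3, 1): contiguous row-major
tt = t.transpose(0, 1)
print(tt.shape, tt.strides)   # (3, 2) (1, 3): same bytes, different walking order
print(t[1, 2], tt[2, 1])      # 6.0 6.0, the same element reached through two views
tt.buf[5] = 60.0              # writing through the view...
print(t[1, 2])                # 60.0, ...is visible in the original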
Step 3 — the op (forward + backward)
An op is a class with forward() and backward(). We attach _ctx to the output so backward can find what produced it.
class Op:
    def __init__(self, *inputs):
        self.inputs = inputs
        self.requires_grad = any(t.requires_grad for t in inputs)
    def forward(self):
        raise NotImplementedError
    def backward(self, grad):
        raise NotImplementedError
    def apply(self):
        out = self.forward()
        if self.requires_grad:
            out._ctx = self  # remember who produced me
            out.requires_grad = True
        return out

class Add(Op):
    def forward(self):
        a, b = self.inputs
        # Assume contiguous + same shape for simplicity
        out_data = a.buf.data + b.buf.data
        return Tensor(Buffer(out_data), a.shape)
    def backward(self, grad):
        # d(a+b)/da = 1, d(a+b)/db = 1
        return [grad, grad]

class MatMul(Op):
    def forward(self):
        a, b = self.inputs
        out = np.matmul(a.buf.data.reshape(a.shape),
                        b.buf.data.reshape(b.shape))
        return Tensor(Buffer(out), out.shape)
    def backward(self, grad):
        a, b = self.inputs
        ga = np.matmul(grad.buf.data.reshape(grad.shape),
                       b.buf.data.reshape(b.shape).T)
        gb = np.matmul(a.buf.data.reshape(a.shape).T,
                       grad.buf.data.reshape(grad.shape))
        return [Tensor(Buffer(ga), a.shape), Tensor(Buffer(gb), b.shape)]

class ReLU(Op):
    def forward(self):
        x, = self.inputs
        out = np.maximum(x.buf.data, 0)
        return Tensor(Buffer(out), x.shape)
    def backward(self, grad):
        x, = self.inputs
        mask = (x.buf.data > 0).astype(np.float32)
        return [Tensor(Buffer(grad.buf.data * mask), grad.shape)]

Five ops (add, matmul, relu, plus mul and softmax/cross-entropy) are enough for an MLP. Each is the math of forward plus the math of backward.
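The lesson doesn’t spell out the last two ops, so here is one possible sketch of Mul and a fused softmax/cross-entropy in the same style, under the same simplifying assumptions (contiguous data, logits of shape (N, C), labels held in a Tensor of N integer class ids). The label convention and the (1,) loss shape are assumptions for illustration, not part of the lesson code.

class Mul(Op):
    def forward(self):
        a, b = self.inputs
        return Tensor(Buffer(a.buf.data * b.buf.data), a.shape)
    def backward(self, grad):
        a, b = self.inputs
        # d(a*b)/da = b, d(a*b)/db = a
        return [Tensor(Buffer(grad.buf.data * b.buf.data), a.shape),
                Tensor(Buffer(grad.buf.data * a.buf.data), b.shape)]

class CrossEntropy(Op):
    """Softmax + negative log-likelihood, averaged over the batch."""
    def forward(self):
        logits, y = self.inputs
        z = logits.buf.data.reshape(logits.shape)
        z = z - z.max(axis=1, keepdims=True)           # numerical stability
        e = np.exp(z)
        self.probs = e / e.sum(axis=1, keepdims=True)   # saved for backward
        self.labels = y.buf.data.astype(np.int64)
        n = logits.shape[0]
        loss = -np.log(self.probs[np.arange(n), self.labels]).mean()
        return Tensor(Buffer(np.array([loss])), (1,))
    def backward(self, grad):
        n = self.probs.shape[0]
        g = self.probs.copy()
        g[np.arange(n), self.labels] -= 1.0             # softmax(logits) - one_hot(y)
        g *= grad.buf.data[0] / n                       # chain rule through the mean
        return [Tensor(Buffer(g), self.inputs[0].shape), None]

Returning None for the label input works because the engine below skips any input with requires_grad=False.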
Step 4 — the autograd engine
Topological sort of the op graph; walk it in reverse, accumulating gradients into .grad.
def backward(loss_tensor):
    # 1. Build the topological order of nodes ending at loss_tensor.
    visited = set()
    topo = []
    def dfs(node):
        if id(node) in visited or node._ctx is None: return
        visited.add(id(node))
        for inp in node._ctx.inputs:
            dfs(inp)
        topo.append(node)
    dfs(loss_tensor)
    # 2. Seed the gradient at the loss with 1.
    loss_tensor.grad = Tensor(Buffer(np.ones_like(loss_tensor.buf.data)), loss_tensor.shape)
    # 3. Walk in reverse, calling each op's backward.
    for node in reversed(topo):
        op = node._ctx
        grads = op.backward(node.grad)
        for inp, g in zip(op.inputs, grads):
            if not inp.requires_grad: continue
            if inp.grad is None:
                inp.grad = g
            else:
                inp.grad = Tensor(Buffer(inp.grad.buf.data + g.buf.data), inp.shape)

That’s autograd. Topological order, reverse walk, accumulate. PyTorch’s loss.backward() is the same algorithm with thousands of ops in the registry instead of five.
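One way to convince yourself the engine is right is a finite-difference check. A sketch, reusing the ops above; the shapes, step size, and the trick of calling backward on a non-scalar output (which, with the ones seed, differentiates the sum of the output) are illustrative choices:

a = Tensor(Buffer(np.random.randn(2, 3).astype(np.float32)), (2, 3), requires_grad=True)
b = Tensor(Buffer(np.random.randn(3, 2).astype(np.float32)), (3, 2), requires_grad=True)
out = ReLU(MatMul(a, b).apply()).apply()
backward(out)                     # seeding with ones means we differentiate sum(out)

def f():
    # recompute sum(relu(a @ b)) directly in numpy
    return np.maximum(a.buf.data.reshape(2, 3) @ b.buf.data.reshape(3, 2), 0).sum()

eps = 1e-3
a.buf.data[0] += eps; plus = f()
a.buf.data[0] -= 2 * eps; minus = f()
a.buf.data[0] += eps              # restore the original value
print(a.grad.buf.data[0], (plus - minus) / (2 * eps))   # should roughly agree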
Putting it together — train MNIST
# Two-layer MLP
W1 = Tensor(Buffer(np.random.randn(784, 128).astype(np.float32) * 0.02), (784, 128), requires_grad=True)
W2 = Tensor(Buffer(np.random.randn(128, 10).astype(np.float32) * 0.02), (128, 10), requires_grad=True)

for epoch in range(5):
    for x, y in batches(mnist):
        h = ReLU(MatMul(x, W1).apply()).apply()
        logits = MatMul(h, W2).apply()
        loss = CrossEntropy(logits, y).apply()
        backward(loss)
        # SGD step
        for p in (W1, W2):
            p.buf.data -= 0.01 * p.grad.buf.data
            p.grad = None

This is a complete training loop. Real frameworks add: param groups, optimizer abstractions, learning-rate schedulers, mixed-precision casting, distributed all-reduce. None of that changes the four-box mental model — they’re just additional layers on top.
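The loop assumes a batches helper and an mnist dataset object that the lesson never defines. A hypothetical version, assuming mnist is a pair of numpy arrays (images flattened to 784 floats, integer labels):

def batches(dataset, batch_size=64):
    images, labels = dataset                       # assumed layout: (N, 784) floats, (N,) ints
    idx = np.random.permutation(len(images))
    for start in range(0, len(images), batch_size):
        sel = idx[start:start + batch_size]
        x = Tensor(Buffer(images[sel]), (len(sel), 784))
        y = Tensor(Buffer(labels[sel]), (len(sel),))
        yield x, y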
What you skip in the toy version
- Dispatch to backends. Real frameworks abstract over CPU, CUDA, MPS, etc.; each op needs a per-backend implementation (see the sketch after this list).
- Op fusion / JIT. The Compilers track is exactly this story.
- Lazy evaluation. PyTorch is eager; tinygrad and JAX are lazy/traced. Lazy lets you fuse; eager is simpler to debug.
- Memory pooling. Real frameworks pre-allocate from a pool; the toy version trusts the OS allocator.
- Distributed. All-reduce, broadcast, gather. The training distributed module covers this.
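To make the first item concrete, here is a hypothetical flavor of backend dispatch. None of this exists in the toy code, and real frameworks use far more machinery (device tags on tensors, registration decorators, fallbacks); this just shows the shape of the idea.

# Hypothetical dispatch table: (op name, device) -> kernel. Only a CPU path exists here.
KERNELS = {
    ("add", "cpu"): lambda a, b: a + b,            # numpy arrays in, numpy array out
    ("relu", "cpu"): lambda x: np.maximum(x, 0),
    # ("add", "cuda"): ...                          # a GPU backend would register its kernels here
}

def run_kernel(name, device, *args):
    return KERNELS[(name, device)](*args)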
What you keep:
- The autograd algorithm is the same.
- The strided tensor representation is the same.
- The op-as-class abstraction is the same.
That’s why writing the toy version is so productive — the load-bearing concepts are all present in 200 lines.
Run it in your browser — a tiny working tensor + autograd
The training loop converges in ~100 steps. You wrote a working ML framework. The next 199 lines (broadcasting, more ops, MNIST) are straightforward variations.
Key takeaways
- A tensor library is small. ~200 lines for strided tensors + autograd + 5 ops + MNIST.
- Four boxes: Buffer, Tensor, Op (forward+backward), Engine (topo sort + reverse walk).
- Strides are the cost model, autograd is the math, ops are the API. Each is independent and learnable.
- Real frameworks scale this with dispatch, fusion, JIT, distributed. The core is the same.
- Build it once. PyTorch / JAX stop being magic; debugging them becomes tractable; the Compilers track immediately makes more sense.
Go deeper
- Video: micrograd — Andrej Karpathy. ~50-line autograd in Python from scratch. The clearest tutorial on the planet for the engine half of this lesson.
- Repo: karpathy/micrograd. The reference repo. Read in 30 minutes.
- Repo: tinygrad/tinygrad. ~5000 lines for a real framework with strides + 25 ops + multiple backends. The "scale up the toy" reference.
- Blog: PyTorch Internals — Edward Z. Yang. How the production version is structured. Read after writing your own toy.
- Paper: Automatic Differentiation in Machine Learning: A Survey. Foundational reference on forward-mode vs reverse-mode autograd. Section 3 frames why reverse-mode is the right choice for ML.
- Docs: PyTorch — Extending Autograd. How to write a custom autograd Function in real PyTorch. The same forward+backward pattern, dispatched into the framework graph.
Prereqs: Strides & Layout, Contiguous vs Non-Contiguous. This lesson is the do-it-yourself version of those two.