PyTorch Caching Allocator
cudaMalloc is shockingly slow. Asking the GPU driver for a fresh chunk of memory costs about 50 microseconds — a glacial eternity in GPU time, where individual matrix multiplications complete in single-digit microseconds. If every PyTorch call to torch.empty(...) actually went to the driver, training a 7B model would spend more time waiting for memory than doing math.
This is the same problem the Stack vs Heap lesson set up: don’t touch the allocator on the hot path. The trick is to grab a giant region of memory once at startup and hand out little pieces of it ourselves. When a tensor is freed, we keep its memory in our own pool and reuse it for the next allocation.
That’s the PyTorch caching allocator. Steady-state allocation cost: ~1 microsecond. Same job, 50× faster. The same pattern shows up in every other production system that needs fast malloc: jemalloc, tcmalloc, slab allocators, arena allocators. The caching allocator is the canonical example of the pattern, and this lesson walks through it.
Prereqs: Stack vs Heap, Smart Pointers & RAII. This lesson is about the “you’re not actually calling cudaMalloc” piece.
TL;DR
- cudaMalloc and cudaFree are slow — ~50 μs and ~10 μs respectively. Calling them on every forward / backward step would dominate small-model training.
- The PyTorch caching allocator is a userspace pool: it grabs large chunks from cudaMalloc upfront, hands out sub-allocations, and reuses them as tensors come and go. torch.empty(N) typically costs 0–1 μs in steady state.
- The same pattern shows up everywhere production C++ goes fast: jemalloc, tcmalloc, arena allocators, slab allocators. “Don’t call malloc on the hot path” is the universal performance discipline.
- The PyTorch allocator splits, merges, and recycles blocks; tracks per-stream allocations to avoid hazards on async work; and offers an “expandable segments” mode (2024+) that grows as needed.
- Read torch.cuda.memory_summary() to see what the allocator is doing. Fragmentation, peak allocation, and reserved-vs-allocated are the metrics that matter.
Why this matters
A 7B-model training step allocates and frees thousands of intermediate tensors. If each cudaMalloc took 50 μs, just allocations would consume hundreds of ms per step — more than the actual math. The caching allocator amortizes this to near zero. Knowing that it exists, what its failure modes are (fragmentation, OOM that’s not actually OOM), and how to read its stats is the first move when debugging memory in any modern ML system.
Mental model
99% of allocations hit the cache. 1% (warmup, growth events, OOM-recovery) call into cudaMalloc.
Concrete walkthrough
The lifecycle of one tensor allocation
x = torch.empty(1024, 1024, device='cuda')  # 4 MB allocation
What happens under the hood:
- PyTorch’s allocator looks for a free block of >= 4 MB in its pool, on the right device + stream.
- Cache hit: returns the block. Cost: ~1 μs.
- Cache miss: the allocator calls cudaMalloc for a larger chunk (say 64 MB), splits off 4 MB, and returns it. Cost: ~50 μs (the cudaMalloc).
- When x goes out of scope: the block is not freed back to CUDA — it’s returned to the allocator’s pool, marked free, available for reuse.
The pool grows on demand and, by default, never shrinks until process exit. This is fine for training: peak memory is allocated once, then reused forever.
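To see the warmup-versus-steady-state difference, time repeated allocations of the same size. A minimal sketch, assuming this is the first CUDA work in the process (exact numbers vary by GPU and driver); the first iteration pays for cudaMalloc, the rest hit the cache:

import time
import torch

torch.cuda.synchronize()                          # make sure CUDA is initialized
for i in range(5):
    t0 = time.perf_counter()
    x = torch.empty(1024, 1024, device='cuda')    # same 4 MB request each time
    t1 = time.perf_counter()
    del x                                         # block returns to the pool, not to CUDA
    print(f"iteration {i}: {(t1 - t0) * 1e6:.1f} us")
# Expected shape of the output: iteration 0 takes tens of microseconds
# (cache miss -> cudaMalloc); later iterations take ~1 us (cache hit).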
Splitting and merging
The allocator maintains free blocks of various sizes. When a request comes in:
- Find the smallest free block >= the request.
- If it’s exactly the right size: take it.
- If it’s larger: split it into (request) + (remainder); add the remainder back to the free list.
When a block is freed:
- Mark it free.
- Try to coalesce with adjacent free blocks (forming larger free regions).
This is buddy-allocator-like but with PyTorch-specific tweaks (alignment, stream-awareness, fragmentation heuristics).
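To make the split-and-coalesce mechanics concrete, here is a toy free-list allocator in Python. This is not PyTorch’s implementation, just a sketch of the same idea: best-fit search, split on allocation, merge adjacent free blocks on free.

class ToyPool:
    def __init__(self, size):
        self.free = [(0, size)]                          # one big free block: (offset, size)

    def alloc(self, size):
        candidates = [b for b in self.free if b[1] >= size]
        if not candidates:
            raise MemoryError("no contiguous block large enough (fragmentation or true OOM)")
        off, blk = min(candidates, key=lambda b: b[1])   # smallest free block that fits
        self.free.remove((off, blk))
        if blk > size:
            self.free.append((off + size, blk - size))   # split: remainder stays on the free list
        return off

    def release(self, off, size):
        self.free.append((off, size))
        self.free.sort()
        merged = [self.free[0]]
        for o, s in self.free[1:]:                       # coalesce adjacent free blocks
            last_o, last_s = merged[-1]
            if last_o + last_s == o:
                merged[-1] = (last_o, last_s + s)
            else:
                merged.append((o, s))
        self.free = merged

pool = ToyPool(64)
a = pool.alloc(16)       # splits the 64-unit region into 16 + 48
b = pool.alloc(16)
pool.release(a, 16)
pool.release(b, 16)      # adjacent frees coalesce back into one region
print(pool.free)         # [(0, 64)]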
Stream awareness
GPUs are async. Two CUDA streams can run kernels concurrently. If thread A frees a tensor while thread B’s kernel is still using it (because B’s kernel hasn’t completed yet), you have a use-after-free that won’t blow up immediately — it will silently corrupt your results.
The caching allocator solves this by tagging each block with the stream that allocated it. A block freed on stream A can only be reused by an allocation on stream A — until A’s events have synchronized. Cross-stream reuse requires explicit record_stream() calls.
For most users this is invisible. For multi-stream code (custom kernels, overlapping data loading with compute), getting it wrong manifests as race conditions and memory corruption.
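A hedged sketch of the multi-stream case follows; the names (side_stream, x, y) are illustrative, and the key call is Tensor.record_stream(), which tells the allocator the tensor is in use on another stream so its block isn’t handed out again too early.

import torch

side_stream = torch.cuda.Stream()
x = torch.randn(1024, 1024, device='cuda')              # allocated on the default stream

side_stream.wait_stream(torch.cuda.current_stream())    # don't read x before it's ready
with torch.cuda.stream(side_stream):
    y = x * 2                                           # kernel on side_stream reads x
    x.record_stream(side_stream)                        # mark x as in use on side_stream

del x   # safe: the allocator won't reuse x's block until side_stream's work finishes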
The two big environment knobs
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
PYTORCH_CUDA_ALLOC_CONF=garbage_collection_threshold:0.8,max_split_size_mb:128

- expandable_segments: enables a 2024+ allocator mode where memory regions grow on demand instead of being fixed-size. Reduces fragmentation; usually a win for training. The default is being flipped to True in newer versions, but check yours.
- max_split_size_mb: don’t split blocks above this size. Prevents fragmenting a giant pool into many small pieces. Useful when you have one big tensor + many small ones.
- garbage_collection_threshold: when the allocator approaches OOM, force a GC pass that frees unused cached blocks back to CUDA before retrying.
These knobs matter when you hit “CUDA OOM but nvidia-smi shows free memory.” That’s almost always fragmentation — the allocator has the bytes but no contiguous block big enough.
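The settings are read when the allocator initializes, so they have to be in the environment before the first CUDA allocation; setting the variable before importing torch is the safe pattern. A sketch (the specific knob value is just an example):

import os
# Must be set before the first CUDA allocation; before importing torch is safest.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
x = torch.empty(1024, 1024, device='cuda')    # allocator now runs with the configured mode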
Reading the memory summary
print(torch.cuda.memory_summary())
Key sections:
- Allocated memory: bytes the allocator has handed out to live tensors.
- Reserved memory: bytes the allocator has grabbed from CUDA (≥ allocated; the difference is cached free blocks).
- Active vs Inactive blocks: allocated vs freed-but-cached.
- Allocations / Frees per stream: high values per step → heavy alloc churn.
Healthy training: allocated < reserved by a moderate margin (~10–30%); slowly-changing across steps. Unhealthy: reserved bloats while allocated stays flat → fragmentation.
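For quick checks inside a training loop, the scalar counters are enough. A minimal sketch; the 0.7 threshold below is an illustrative heuristic, not an official number:

import torch

alloc = torch.cuda.memory_allocated()         # bytes currently held by live tensors
reserved = torch.cuda.memory_reserved()       # bytes the caching allocator holds from CUDA
peak = torch.cuda.max_memory_allocated()      # high-water mark since the last reset

print(f"allocated {alloc / 2**20:.0f} MiB, reserved {reserved / 2**20:.0f} MiB, peak {peak / 2**20:.0f} MiB")
if reserved > 0 and alloc / reserved < 0.7:   # big gap = lots of cached-but-unused blocks
    print("reserved far exceeds allocated: look for fragmentation or alloc churn")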
Same pattern, other places
The general pattern — pool memory upfront, reuse — appears in:
- jemalloc / tcmalloc: replacement CPU allocators. Production C++ standard.
- Linux slab allocator: kernel object allocator.
- Arena allocators in Rust (bumpalo), C++ (std::pmr::monotonic_buffer_resource), Go (sync.Pool).
- Custom AI kernel pools (cuBLAS workspace, FlashAttention scratch, etc.).
If you write performance-critical code, an arena or pool is usually somewhere in your design.
Custom allocators in C++
#include <memory_resource>
#include <vector>

std::pmr::monotonic_buffer_resource pool(1 << 20);  // 1 MB pool
std::pmr::vector<float> v{&pool};                   // allocates from pool

monotonic_buffer_resource allocates from a pre-sized buffer and never frees individual blocks (it frees the whole pool at the end). Perfect for “transient” data with bounded lifetime — e.g., a per-request workspace in a serving system.
For tensor-heavy workloads, the production move is something like PyTorch’s allocator: a free-list of chunks, refcounted blocks, stream-aware. The C++ standard library doesn’t ship a perfect equivalent; teams roll their own.
Run it in your browser — pool allocator simulator
You’ll typically see ~95–100% hit rate after a few iterations of the warmup loop — exactly how the PyTorch caching allocator behaves once your model’s memory pattern stabilizes.
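Since the in-browser simulator isn’t reproduced here, the following small Python script approximates it: it replays a fixed per-step allocation pattern against a size-keyed pool and reports how many requests were served from cache. The sizes and step count are made up for illustration.

free_blocks = {}                       # size (MB) -> count of cached blocks
sizes = [4, 4, 16, 64, 4, 16]          # a fixed per-step allocation pattern

for step in range(5):
    hits = misses = 0
    for size in sizes:
        if free_blocks.get(size, 0) > 0:
            free_blocks[size] -= 1     # cache hit: reuse a pooled block
            hits += 1
        else:
            misses += 1                # cache miss: this is where cudaMalloc would be called
    for size in sizes:                 # everything is freed back to the pool at step end
        free_blocks[size] = free_blocks.get(size, 0) + 1
    print(f"step {step}: {hits}/{hits + misses} allocations served from the pool")
# Step 0 is all misses (warmup); every later step is 6/6 hits.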
Key takeaways
- cudaMalloc is slow; the caching allocator amortizes it. Steady-state allocation is ~1 μs.
- Pool, split, merge, recycle. Same pattern as jemalloc, slab allocators, arenas.
- Stream-aware to prevent use-after-free across async streams.
- expandable_segments and max_split_size_mb are the OOM-debugging knobs. Fragmentation is the usual failure mode.
- torch.cuda.memory_summary() is your friend. Check allocated vs reserved; high reserved / low allocated = fragmentation.
Go deeper
- Docs: PyTorch — CUDA Memory Management. Authoritative; see the "Caching Allocator" and "Memory Snapshot" sections.
- Blog: PyTorch Blog — Understanding GPU Memory. Two-part series with profiling and debugging recipes.
- Docs: PyTorch — torch.cuda.memory APIs. memory_snapshot, memory_summary, memory_stats — the introspection toolkit.
- Blog: tcmalloc Overview. The CPU-side analogue; reading it helps cement the pool-allocator design pattern.
- Repo: PyTorch — CUDACachingAllocator. The source: ~3000 lines of C++, the most-read PyTorch internals file. Worth reading once.
- Video: PyTorch DevCon — Inside the CUDA Caching Allocator. PyTorch core devs walk through the implementation.