
API reference

Auto-generated from the docstrings in src/scree/. If you find a discrepancy between this page and the code, the code is right — please file an issue.

The library is intentionally small: one data class, five core operations, six bridges, four reference kernels, and three GPU Triton kernels (plus an autograd wrapper around the attention kernel).


Core type and operations — scree.*

scree.Array

A packed values+offsets array with one variable-length dimension.

Variable-length sequences stored as a flat values buffer plus offsets pointing at row boundaries.

Example

Three sequences of lengths [4, 2, 5], each with feature dim 8:

values: shape (11, 8)   # 4+2+5 along ragged_dim=0
offsets: [0, 4, 6, 11]  # length B+1
ragged_dim: 0

Construct with scree.pack([seq1, seq2, seq3]).
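The layout in the example can be reproduced with plain NumPy. This is a standalone sketch, not scree code: it builds the same values and offsets by hand and derives the per-row quantities. It assumes that batch_size is len(offsets) - 1 and total_length is offsets[-1]; the page lists those properties without defining them.

```python
import numpy as np

# Three sequences of lengths 4, 2, 5, each token carrying 8 features.
rng = np.random.default_rng(0)
seqs = [rng.standard_normal((n, 8)).astype(np.float32) for n in (4, 2, 5)]

# offsets[0] = 0 and offsets[i + 1] = offsets[i] + len(seqs[i]).
lengths = np.array([s.shape[0] for s in seqs])
offsets = np.zeros(len(seqs) + 1, dtype=np.int32)
offsets[1:] = np.cumsum(lengths)

# All rows are concatenated along the ragged dimension (dim 0 here).
values = np.concatenate(seqs, axis=0)

# Derived quantities (assumed to match the Array properties).
batch_size = len(offsets) - 1       # number of sequences
total_length = int(offsets[-1])     # rows in the flat buffer
row_lengths = np.diff(offsets)      # lengths[i] = offsets[i+1] - offsets[i]

# Row i is recovered by slicing between consecutive offsets.
second = values[offsets[1]:offsets[2]]

print(values.shape, offsets.tolist(), batch_size, total_length)
# (11, 8) [0, 4, 6, 11] 3 11
```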

batch_size property

batch_size: int

lengths property

lengths: Any

Per-row lengths: lengths[i] = offsets[i+1] - offsets[i].

total_length property

total_length: int

dtype property

dtype: Any

feature_shape property

feature_shape: tuple

__post_init__

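__post_init__() -> None

Validates the invariants on construction: values is at least 1-D, ragged_dim indexes a valid axis of values, offsets is 1-D with at least two entries, offsets[0] == 0, and offsets[-1] == values.shape[ragged_dim]. Offsets are not checked for monotonicity.
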
Source code in src/scree/_core.py
def __post_init__(self) -> None:
    if self.values.ndim < 1:
        raise ValueError(f"values must be at least 1-D, got {self.values.ndim}-D")
    if not (0 <= self.ragged_dim < self.values.ndim):
        raise ValueError(
            f"ragged_dim={self.ragged_dim} out of range for {self.values.ndim}-D values"
        )
    if self.offsets.ndim != 1:
        raise ValueError(f"offsets must be 1-D, got {self.offsets.ndim}-D")
    if len(self.offsets) < 2:
        raise ValueError("offsets must have length >= 2")
    if int(self.offsets[0]) != 0:
        raise ValueError(f"offsets[0] must be 0, got {int(self.offsets[0])}")
    ragged_size = self.values.shape[self.ragged_dim]
    if int(self.offsets[-1]) != ragged_size:
        raise ValueError(
            f"offsets[-1] ({int(self.offsets[-1])}) must equal "
            f"values.shape[ragged_dim={self.ragged_dim}] ({ragged_size})"
        )

scree.pack

Pack a list of arrays into a single scree.Array.

All arrays must share dtype and all non-ragged dims. The first array determines the backend (numpy or torch).

Source code in src/scree/_core.py
def pack(arrays: List[Any], ragged_dim: int = 0) -> Array:
    """Pack a list of arrays into a single scree.Array.

    All arrays must share dtype and all non-ragged dims. The first array
    determines the backend (numpy or torch).
    """
    if not arrays:
        raise ValueError("Cannot pack an empty list")
    first = arrays[0]
    lengths = [a.shape[ragged_dim] for a in arrays]

    if _is_torch(first):
        import torch

        offsets = torch.zeros(len(arrays) + 1, dtype=torch.int32, device=first.device)
        offsets[1:] = torch.tensor(lengths, dtype=torch.int32, device=first.device).cumsum(0)
        values = torch.cat(arrays, dim=ragged_dim)
    elif _is_mlx(first):
        import mlx.core as mx

        cumsum = 0
        offsets_list = [0]
        for length in lengths:
            cumsum += length
            offsets_list.append(cumsum)
        offsets = mx.array(offsets_list, dtype=mx.int32)
        values = mx.concatenate(arrays, axis=ragged_dim)
    elif _is_jax(first):
        import jax.numpy as jnp

        cumsum = 0
        offsets_list = [0]
        for length in lengths:
            cumsum += length
            offsets_list.append(cumsum)
        offsets = jnp.array(offsets_list, dtype=jnp.int32)
        values = jnp.concatenate(arrays, axis=ragged_dim)
    else:
        import numpy as np

        offsets = np.zeros(len(arrays) + 1, dtype=np.int32)
        offsets[1:] = np.cumsum(lengths)
        values = np.concatenate(arrays, axis=ragged_dim)

    return Array(values=values, offsets=offsets, ragged_dim=ragged_dim)

scree.unpack

Unpack a scree.Array into a list of arrays.

Returned slices are views into the original values where possible.

Source code in src/scree/_core.py
def unpack(arr: Array) -> List[Any]:
    """Unpack a scree.Array into a list of arrays.

    Returned slices are views into the original ``values`` where possible.
    """
    out: List[Any] = []
    rd = arr.ragged_dim
    ndim = arr.values.ndim
    for i in range(arr.batch_size):
        start = int(arr.offsets[i])
        end = int(arr.offsets[i + 1])
        slc = [slice(None)] * ndim
        slc[rd] = slice(start, end)
        out.append(arr.values[tuple(slc)])
    return out
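unpack builds one slice tuple per row and indexes values with it, so for NumPy inputs each returned row is a basic-slicing view. The standalone sketch below mirrors the pack/unpack logic for a non-zero ragged_dim (here 1, e.g. a (heads, tokens, features) layout), which pack accepts even though to_padded does not. The helper names pack_along and unpack_along are hypothetical, not scree API.

```python
import numpy as np

def pack_along(arrays, ragged_dim):
    """Hypothetical NumPy-only analogue of pack(): concatenate along ragged_dim."""
    lengths = [a.shape[ragged_dim] for a in arrays]
    offsets = np.zeros(len(arrays) + 1, dtype=np.int32)
    offsets[1:] = np.cumsum(lengths)
    return np.concatenate(arrays, axis=ragged_dim), offsets

def unpack_along(values, offsets, ragged_dim):
    """Hypothetical analogue of unpack(): slice each row out as a view."""
    rows = []
    for i in range(len(offsets) - 1):
        slc = [slice(None)] * values.ndim
        slc[ragged_dim] = slice(int(offsets[i]), int(offsets[i + 1]))
        rows.append(values[tuple(slc)])
    return rows

# Two items shaped (heads=2, tokens, features=3) with 3 and 5 tokens.
a = np.ones((2, 3, 3))
b = np.zeros((2, 5, 3))
values, offsets = pack_along([a, b], ragged_dim=1)
rows = unpack_along(values, offsets, ragged_dim=1)

print(values.shape, offsets.tolist())     # (2, 8, 3) [0, 3, 8]
print(np.shares_memory(rows[1], values))  # True: basic slicing returns a view
```

Because the rows are views, writing into a returned row writes into the packed buffer.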

scree.to_padded

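Convert a scree.Array to a padded dense array + mask.

Returns (padded, mask) where:
  • padded.shape == (batch_size, max_len, *feature_dims)
  • mask.shape == (batch_size, max_len), True for valid positions

side ("right" or "left") selects where the padding goes, and fill_value is written into the padded positions. Only ragged_dim=0 is supported; any other value raises NotImplementedError.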

Source code in src/scree/_core.py
def to_padded(arr: Array, side: str = "right", fill_value: float = 0.0) -> Tuple[Any, Any]:
    """Convert a scree.Array to a padded dense array + mask.

    Returns ``(padded, mask)`` where:
    - ``padded.shape == (batch_size, max_len, *feature_dims)``
    - ``mask.shape == (batch_size, max_len)`` — True for valid positions
    """
    if arr.ragged_dim != 0:
        raise NotImplementedError("to_padded supports ragged_dim=0 only in v0.1")
    if side not in ("right", "left"):
        raise ValueError(f"side must be 'right' or 'left', got {side!r}")

    lengths = [int(arr.offsets[i + 1] - arr.offsets[i]) for i in range(arr.batch_size)]
    max_len = max(lengths) if lengths else 0
    feature_shape = arr.values.shape[1:]
    batch = arr.batch_size

    if _is_mlx(arr.values):
        # MLX prefers a mutation-free construction (lazy graph).
        import mlx.core as mx

        rows_padded = []
        rows_mask = []
        for i, length in enumerate(lengths):
            start = int(arr.offsets[i])
            row = arr.values[start : start + length]
            pad_shape = (max_len - length, *feature_shape)
            pad = mx.full(pad_shape, fill_value, dtype=arr.values.dtype)
            valid_mask = mx.ones((length,), dtype=mx.bool_)
            pad_mask = mx.zeros((max_len - length,), dtype=mx.bool_)
            if side == "right":
                rows_padded.append(mx.concatenate([row, pad], axis=0))
                rows_mask.append(mx.concatenate([valid_mask, pad_mask], axis=0))
            else:
                rows_padded.append(mx.concatenate([pad, row], axis=0))
                rows_mask.append(mx.concatenate([pad_mask, valid_mask], axis=0))
        padded = mx.stack(rows_padded, axis=0)
        mask = mx.stack(rows_mask, axis=0)
        return padded, mask

    if _is_jax(arr.values):
        # JAX arrays are immutable — same mutation-free pattern as MLX.
        import jax.numpy as jnp

        rows_padded = []
        rows_mask = []
        for i, length in enumerate(lengths):
            start = int(arr.offsets[i])
            row = arr.values[start : start + length]
            pad_shape = (max_len - length, *feature_shape)
            pad = jnp.full(pad_shape, fill_value, dtype=arr.values.dtype)
            valid_mask = jnp.ones((length,), dtype=jnp.bool_)
            pad_mask = jnp.zeros((max_len - length,), dtype=jnp.bool_)
            if side == "right":
                rows_padded.append(jnp.concatenate([row, pad], axis=0))
                rows_mask.append(jnp.concatenate([valid_mask, pad_mask], axis=0))
            else:
                rows_padded.append(jnp.concatenate([pad, row], axis=0))
                rows_mask.append(jnp.concatenate([pad_mask, valid_mask], axis=0))
        padded = jnp.stack(rows_padded, axis=0)
        mask = jnp.stack(rows_mask, axis=0)
        return padded, mask

    if _is_torch(arr.values):
        import torch

        padded = torch.full(
            (batch, max_len, *feature_shape),
            fill_value,
            dtype=arr.values.dtype,
            device=arr.values.device,
        )
        mask = torch.zeros((batch, max_len), dtype=torch.bool, device=arr.values.device)
    else:
        import numpy as np

        padded = np.full((batch, max_len, *feature_shape), fill_value, dtype=arr.values.dtype)
        mask = np.zeros((batch, max_len), dtype=np.bool_)

    for i, length in enumerate(lengths):
        start = int(arr.offsets[i])
        row = arr.values[start : start + length]
        if side == "right":
            padded[i, :length] = row
            mask[i, :length] = True
        else:  # left
            padded[i, max_len - length :] = row
            mask[i, max_len - length :] = True

    return padded, mask
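to_padded fills the mask row by row. For intuition, the same mask can be written as one broadcast comparison of positions against row lengths. This NumPy sketch is an alternative formulation, not what scree does internally; padding_mask is a hypothetical name.

```python
import numpy as np

def padding_mask(lengths, side="right"):
    """Boolean (batch, max_len) mask, True at valid positions (sketch)."""
    lengths = np.asarray(lengths)
    max_len = int(lengths.max()) if lengths.size else 0
    pos = np.arange(max_len)[None, :]                # (1, max_len)
    if side == "right":
        return pos < lengths[:, None]                # valid prefix of each row
    if side == "left":
        return pos >= (max_len - lengths)[:, None]   # valid suffix of each row
    raise ValueError(f"side must be 'right' or 'left', got {side!r}")

print(padding_mask([3, 1, 2]).astype(int).tolist())
# [[1, 1, 1], [1, 0, 0], [1, 1, 0]]
print(padding_mask([3, 1, 2], side="left").astype(int).tolist())
# [[1, 1, 1], [0, 0, 1], [0, 1, 1]]
```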

scree.from_padded

Convert (padded, mask) to a scree.Array.

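Assumes right-padding: each row's valid positions must form a contiguous prefix (mask True on the left side of each row). Row lengths are taken from mask.sum(), so left-padded input is not rejected; it silently yields the wrong slices.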

Source code in src/scree/_core.py
def from_padded(padded: Any, mask: Any) -> Array:
    """Convert ``(padded, mask)`` to a scree.Array.

    Assumes right-padding (mask is True on the left side of each row).
    """
    batch = padded.shape[0]

    if _is_torch(padded):
        import torch

        lengths = [int(mask[i].sum().item()) for i in range(batch)]
        rows = [padded[i, : lengths[i]] for i in range(batch)]
        values = torch.cat(rows, dim=0)
        offsets = torch.zeros(batch + 1, dtype=torch.int32, device=padded.device)
        offsets[1:] = torch.tensor(lengths, dtype=torch.int32, device=padded.device).cumsum(0)
    elif _is_mlx(padded):
        import mlx.core as mx

        lengths = [int(mask[i].sum().item()) for i in range(batch)]
        rows = [padded[i, : lengths[i]] for i in range(batch)]
        values = mx.concatenate(rows, axis=0)
        cumsum = 0
        offsets_list = [0]
        for length in lengths:
            cumsum += length
            offsets_list.append(cumsum)
        offsets = mx.array(offsets_list, dtype=mx.int32)
    elif _is_jax(padded):
        import jax.numpy as jnp

        lengths = [int(mask[i].sum()) for i in range(batch)]
        rows = [padded[i, : lengths[i]] for i in range(batch)]
        values = jnp.concatenate(rows, axis=0)
        cumsum = 0
        offsets_list = [0]
        for length in lengths:
            cumsum += length
            offsets_list.append(cumsum)
        offsets = jnp.array(offsets_list, dtype=jnp.int32)
    else:
        import numpy as np

        lengths = [int(mask[i].sum()) for i in range(batch)]
        rows = [padded[i, : lengths[i]] for i in range(batch)]
        values = np.concatenate(rows, axis=0)
        offsets = np.zeros(batch + 1, dtype=np.int32)
        offsets[1:] = np.cumsum(lengths)

    return Array(values=values, offsets=offsets, ragged_dim=0)
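Because from_padded keeps the first mask.sum() positions of each row, a left-padded mask is misread without any error. A guard like the one below can check that every mask row is a contiguous valid prefix before conversion; is_right_padded is a hypothetical helper, not part of scree.

```python
import numpy as np

def is_right_padded(mask):
    """True if each row of a (batch, seq_len) mask is ones followed by zeros."""
    mask = np.asarray(mask).astype(bool)
    lengths = mask.sum(axis=1)
    prefix = np.arange(mask.shape[1])[None, :] < lengths[:, None]
    return bool(np.array_equal(mask, prefix))

right = np.array([[1, 1, 0], [1, 0, 0]])
left = np.array([[0, 1, 1], [0, 0, 1]])
print(is_right_padded(right), is_right_padded(left))  # True False
```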

scree.from_cu_seqlens

Construct a scree.Array from FlashAttention's cu_seqlens convention.

FlashAttention's cu_seqlens is exactly scree's offsets. Zero-copy.

Source code in src/scree/_core.py
def from_cu_seqlens(values: Any, cu_seqlens: Any) -> Array:
    """Construct a scree.Array from FlashAttention's cu_seqlens convention.

    FlashAttention's ``cu_seqlens`` is exactly scree's ``offsets``. Zero-copy.
    """
    return Array(values=values, offsets=cu_seqlens, ragged_dim=0)
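FlashAttention's variable-length entry points take cu_seqlens (int32, leading 0, length batch + 1) together with the longest sequence length. from_cu_seqlens stores cu_seqlens as-is, and Array's own checks do not cover monotonicity, so a sketch for building and sanity-checking these inputs from per-sequence lengths might look like this. Both helper names and the extra monotonicity check are assumptions, not scree API.

```python
import numpy as np

def cu_seqlens_from_lengths(lengths):
    """Hypothetical helper: (cu_seqlens, max_seqlen) in the FlashAttention convention."""
    lengths = np.asarray(lengths, dtype=np.int64)
    if lengths.ndim != 1 or (lengths < 0).any():
        raise ValueError("lengths must be a 1-D array of non-negative ints")
    cu = np.zeros(len(lengths) + 1, dtype=np.int32)
    cu[1:] = np.cumsum(lengths)
    max_seqlen = int(lengths.max()) if lengths.size else 0
    return cu, max_seqlen

def check_offsets(offsets, total_len):
    """Array's invariants on offsets, plus monotonicity (which Array does not check)."""
    offsets = np.asarray(offsets)
    return (
        offsets.ndim == 1
        and len(offsets) >= 2
        and int(offsets[0]) == 0
        and int(offsets[-1]) == total_len
        and bool(np.all(np.diff(offsets) >= 0))
    )

cu, max_len = cu_seqlens_from_lengths([4, 2, 5])
print(cu.tolist(), max_len)                                     # [0, 4, 6, 11] 5
print(check_offsets(cu, 11), check_offsets([0, 6, 4, 11], 11))  # True False
```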

Bridges — scree.bridges

Migration helpers between scree and existing ecosystem objects. Each bridge is zero-copy where the underlying memory layout allows.

to_torch_nested / from_torch_nested

Convert a scree.Array to a torch.NestedTensor (jagged layout).

The conversion materializes per-row views from the packed buffer and hands them to torch.nested.nested_tensor. Internally torch may share the underlying storage; we don't promise zero-copy.

Source code in src/scree/bridges/_torch_nested.py
def to_torch_nested(arr: Array) -> Any:
    """Convert a scree.Array to a torch.NestedTensor (jagged layout).

    The conversion materializes per-row views from the packed buffer
    and hands them to ``torch.nested.nested_tensor``. Internally torch
    may share the underlying storage; we don't promise zero-copy.
    """
    import torch

    if arr.ragged_dim != 0:
        raise NotImplementedError("to_torch_nested supports ragged_dim=0 only")
    rows = unpack(arr)
    return torch.nested.nested_tensor(list(rows), layout=torch.jagged)

Convert a torch.nested.NestedTensor (jagged) to a scree.Array.

Uses the jagged NestedTensor's underlying values + offsets directly (zero-copy when supported by the torch version).

Source code in src/scree/bridges/_torch_nested.py
def from_torch_nested(nt: Any) -> Array:
    """Convert a torch.nested.NestedTensor (jagged) to a scree.Array.

    Uses the jagged NestedTensor's underlying values + offsets directly
    (zero-copy when supported by the torch version).
    """
    if not getattr(nt, "is_nested", False):
        raise TypeError(f"expected a NestedTensor, got {type(nt).__name__}")
    import torch

    # torch.jagged NT exposes .values() and .offsets() on modern torch versions.
    # Fall back to unbinding + repacking if those aren't available.
    try:
        values = nt.values()
        offsets = nt.offsets().to(torch.int32)
    except (AttributeError, RuntimeError):
        rows = list(nt.unbind())
        values = torch.cat(rows, dim=0)
        lengths = [r.shape[0] for r in rows]
        offsets = torch.zeros(len(rows) + 1, dtype=torch.int32, device=values.device)
        offsets[1:] = torch.tensor(lengths, dtype=torch.int32, device=values.device).cumsum(0)

    return Array(values=values, offsets=offsets, ragged_dim=0)

to_hf_padded / from_hf_padded

Convert a scree.Array to HF (hidden_states, attention_mask).

Returns (hidden_states, attention_mask) where attention_mask is int64 with 1 for valid positions, 0 for padding (HF convention).

Source code in src/scree/bridges/_hf_padded.py
def to_hf_padded(arr: Array) -> Tuple[Any, Any]:
    """Convert a scree.Array to HF ``(hidden_states, attention_mask)``.

    Returns ``(hidden_states, attention_mask)`` where attention_mask is
    int64 with 1 for valid positions, 0 for padding (HF convention).
    """
    padded, mask = to_padded(arr, side="right")
    if _is_torch(mask):
        import torch

        mask = mask.to(dtype=torch.int64)
    else:
        import numpy as np

        mask = mask.astype(np.int64)
    return padded, mask

Convert HF (hidden_states, attention_mask) to a scree.Array.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | array-like, shape (batch, seq_len, *features) | Right-padded hidden states. | required |
| attention_mask | array-like, shape (batch, seq_len) | 1 for real tokens, 0 for padding (HF convention). | required |

Source code in src/scree/bridges/_hf_padded.py
def from_hf_padded(hidden_states: Any, attention_mask: Any) -> Array:
    """Convert HF ``(hidden_states, attention_mask)`` to a scree.Array.

    Parameters
    ----------
    hidden_states : array-like, shape (batch, seq_len, *features)
    attention_mask : array-like, shape (batch, seq_len)
        1 for real tokens, 0 for padding (HF convention).
    """
    if _is_torch(attention_mask):
        mask = attention_mask.to(dtype=__import__("torch").bool)
    else:
        mask = attention_mask.astype(bool)
    return from_padded(hidden_states, mask)
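from_hf_padded only casts the 0/1 mask to bool and defers to from_padded. The arithmetic that results on a typical right-padded Hugging Face batch is small enough to sketch in NumPy; the input values below are made up for illustration.

```python
import numpy as np

# Right-padded HF-style batch: 2 sequences, seq_len 4, hidden size 3.
hidden_states = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
attention_mask = np.array([[1, 1, 1, 0],
                           [1, 1, 0, 0]], dtype=np.int64)

mask = attention_mask.astype(bool)   # same cast as from_hf_padded
lengths = mask.sum(axis=1)           # per-row valid token counts
offsets = np.zeros(len(lengths) + 1, dtype=np.int32)
offsets[1:] = np.cumsum(lengths)
values = np.concatenate([hidden_states[i, :n] for i, n in enumerate(lengths)], axis=0)

print(lengths.tolist(), offsets.tolist(), values.shape)  # [3, 2] [0, 3, 5] (5, 3)
```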

to_torch / to_numpy

Re-export a scree.Array with its values/offsets as torch tensors.

Zero-copy on CPU via torch.from_numpy; zero-copy on GPU via DLPack. If values is already a torch tensor, arr is returned unchanged.

Source code in src/scree/bridges/_dlpack.py
def to_torch(arr: Array) -> Array:
    """Re-export a scree.Array with its values/offsets as torch tensors.

    Zero-copy on CPU via ``torch.from_numpy``; zero-copy on GPU via DLPack.
    """
    import torch

    if _is_torch(arr.values):
        return arr

    # numpy -> torch
    values = torch.from_numpy(arr.values) if hasattr(arr.values, "__array__") else _via_dlpack(arr.values, torch)
    offsets = torch.from_numpy(arr.offsets) if hasattr(arr.offsets, "__array__") else _via_dlpack(arr.offsets, torch)
    return Array(values=values, offsets=offsets.to(torch.int32), ragged_dim=arr.ragged_dim)

Re-export a scree.Array with its values/offsets as numpy arrays.

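Zero-copy from CPU torch tensors; GPU torch tensors are copied to host. If values is not a torch tensor, arr is returned unchanged, so MLX- or JAX-backed arrays are not converted.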

Source code in src/scree/bridges/_dlpack.py
def to_numpy(arr: Array) -> Array:
    """Re-export a scree.Array with its values/offsets as numpy arrays.

    Zero-copy from CPU torch tensors; for GPU torch tensors, copies to host.
    """
    import numpy as np

    if not _is_torch(arr.values):
        return arr

    # torch -> numpy
    values = arr.values.detach().cpu().numpy()
    offsets = arr.offsets.detach().cpu().numpy()
    return Array(values=values, offsets=offsets.astype(np.int32), ragged_dim=arr.ragged_dim)

Reference kernels — scree.kernels.reference

NumPy, PyTorch, MLX, and JAX implementations of the four varlen kernels, used as ground truth in CI tests of the optimized Triton kernels. They are not meant for production speed: varlen_attention and varlen_softmax loop over sequences at the Python level.

varlen_attention

Reference (slow but correct) implementation of varlen self-attention.

varlen_attention

varlen_attention(q: Array, k: Array, v: Array, causal: bool = False) -> Array

Variable-length self-attention.

Each sequence in the batch attends only to itself — no cross-sequence attention. This is the operation that powers FlashAttention-varlen and the packed inference path of vLLM/SGLang; here we ship the obviously correct slow reference for use as a ground truth in CI.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| q, k, v | Array | Each with shape (total_len, n_heads, head_dim) and identical offsets; mismatched offsets raise ValueError. | required |
| causal | bool | If True, apply a lower-triangular mask within each sequence: token i attends only to tokens j <= i of the same sequence. | False |

Returns:

| Type | Description |
|---|---|
| Array | Shape (total_len, n_heads, head_dim), same offsets as q. |

Source code in src/scree/kernels/reference/varlen_attention.py
def varlen_attention(q: Array, k: Array, v: Array, causal: bool = False) -> Array:
    """Variable-length self-attention.

    Each sequence in the batch attends only to itself — no cross-sequence
    attention. This is the operation that powers FlashAttention-varlen and
    the packed inference path of vLLM/SGLang; here we ship the obviously
    correct slow reference for use as a ground truth in CI.

    Parameters
    ----------
    q, k, v : scree.Array
        Each with shape ``(total_len, n_heads, head_dim)`` and matching
        ``offsets``.
    causal : bool
        If True, apply a lower-triangular mask within each sequence.

    Returns
    -------
    scree.Array
        Same offsets as ``q``.
    """
    head_dim = q.values.shape[-1]
    scale = 1.0 / math.sqrt(head_dim)

    if _is_jax(q.values):
        import jax.nn as jnn
        import jax.numpy as jnp

        if not jnp.array_equal(q.offsets, k.offsets) or not jnp.array_equal(
            q.offsets, v.offsets
        ):
            raise ValueError("q, k, v must have identical offsets")

        out_rows = []
        for i in range(q.batch_size):
            s = int(q.offsets[i])
            e = int(q.offsets[i + 1])
            qi = q.values[s:e]
            ki = k.values[s:e]
            vi = v.values[s:e]
            scores = jnp.einsum("ihd,jhd->hij", qi, ki) * scale
            if causal:
                length = qi.shape[0]
                cmask = jnp.triu(jnp.ones((length, length), dtype=bool), k=1)
                scores = jnp.where(cmask, -jnp.inf, scores)
            attn = jnn.softmax(scores, axis=-1)
            out_i = jnp.einsum("hij,jhd->ihd", attn, vi)
            out_rows.append(out_i)
        values = jnp.concatenate(out_rows, axis=0)
        return Array(values=values, offsets=q.offsets, ragged_dim=0)

    if _is_mlx(q.values):
        import mlx.core as mx

        if not mx.array_equal(q.offsets, k.offsets).item() or not mx.array_equal(
            q.offsets, v.offsets
        ).item():
            raise ValueError("q, k, v must have identical offsets")

        out_rows = []
        for i in range(q.batch_size):
            s = int(q.offsets[i])
            e = int(q.offsets[i + 1])
            qi = q.values[s:e]  # (Li, H, D)
            ki = k.values[s:e]
            vi = v.values[s:e]
            scores = mx.einsum("ihd,jhd->hij", qi, ki) * scale
            if causal:
                length = qi.shape[0]
                # Build causal mask: True on upper triangle (positions to mask)
                row_idx = mx.arange(length).reshape((length, 1))
                col_idx = mx.arange(length).reshape((1, length))
                cmask = col_idx > row_idx
                scores = mx.where(cmask, mx.array(-mx.inf, dtype=scores.dtype), scores)
            attn = mx.softmax(scores, axis=-1)
            out_i = mx.einsum("hij,jhd->ihd", attn, vi)
            out_rows.append(out_i)
        values = mx.concatenate(out_rows, axis=0)
        return Array(values=values, offsets=q.offsets, ragged_dim=0)

    if _is_torch(q.values):
        import torch

        if not torch.equal(q.offsets, k.offsets) or not torch.equal(q.offsets, v.offsets):
            raise ValueError("q, k, v must have identical offsets")

        out_rows = []
        for i in range(q.batch_size):
            s = int(q.offsets[i])
            e = int(q.offsets[i + 1])
            qi = q.values[s:e]  # (Li, H, D)
            ki = k.values[s:e]
            vi = v.values[s:e]
            scores = torch.einsum("ihd,jhd->hij", qi, ki) * scale  # (H, Li, Li)
            if causal:
                length = qi.shape[0]
                mask = torch.triu(
                    torch.ones(length, length, device=qi.device, dtype=torch.bool),
                    diagonal=1,
                )
                scores = scores.masked_fill(mask, float("-inf"))
            attn = torch.softmax(scores, dim=-1)
            out_i = torch.einsum("hij,jhd->ihd", attn, vi)
            out_rows.append(out_i)
        values = torch.cat(out_rows, dim=0)
    else:
        import numpy as np

        if not np.array_equal(q.offsets, k.offsets) or not np.array_equal(q.offsets, v.offsets):
            raise ValueError("q, k, v must have identical offsets")

        out_rows = []
        for i in range(q.batch_size):
            s = int(q.offsets[i])
            e = int(q.offsets[i + 1])
            qi = q.values[s:e]
            ki = k.values[s:e]
            vi = v.values[s:e]
            scores = np.einsum("ihd,jhd->hij", qi, ki) * scale
            if causal:
                length = qi.shape[0]
                mask = np.triu(np.ones((length, length), dtype=bool), k=1)
                scores = np.where(mask, -np.inf, scores)
            scores_max = scores.max(axis=-1, keepdims=True)
            scores_exp = np.exp(scores - scores_max)
            attn = scores_exp / scores_exp.sum(axis=-1, keepdims=True)
            out_i = np.einsum("hij,jhd->ihd", attn, vi)
            out_rows.append(out_i)
        values = np.concatenate(out_rows, axis=0)

    return Array(values=values, offsets=q.offsets, ragged_dim=0)
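One way to read "each sequence attends only to itself": the per-sequence loop is equivalent to ordinary attention over the whole packed buffer with a block-diagonal mask built from offsets (lower-triangular inside each block when causal). The NumPy sketch below checks that equivalence. It materializes a (total_len, total_len) mask, so it only suits small test inputs, and the function names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attn_loop(q, k, v, offsets, causal=False):
    """Per-sequence reference, same structure as the NumPy branch above."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = []
    for s, e in zip(offsets[:-1], offsets[1:]):
        scores = np.einsum("ihd,jhd->hij", q[s:e], k[s:e]) * scale
        if causal:
            n = e - s
            scores = np.where(np.triu(np.ones((n, n), dtype=bool), k=1), -np.inf, scores)
        out.append(np.einsum("hij,jhd->ihd", softmax(scores), v[s:e]))
    return np.concatenate(out, axis=0)

def attn_block_diagonal(q, k, v, offsets, causal=False):
    """Same result via one dense attention call with a block-diagonal mask."""
    total = q.shape[0]
    seg = np.repeat(np.arange(len(offsets) - 1), np.diff(offsets))  # segment id per token
    allowed = seg[:, None] == seg[None, :]                           # same sequence only
    if causal:
        idx = np.arange(total)
        allowed &= idx[None, :] <= idx[:, None]                      # key j <= query i
    scores = np.einsum("ihd,jhd->hij", q, k) / np.sqrt(q.shape[-1])
    scores = np.where(allowed[None], scores, -np.inf)
    return np.einsum("hij,jhd->ihd", softmax(scores), v)

rng = np.random.default_rng(0)
offsets = np.array([0, 3, 4, 8])
q, k, v = (rng.standard_normal((8, 2, 4)) for _ in range(3))
for causal in (False, True):
    print(causal, np.allclose(attn_loop(q, k, v, offsets, causal),
                              attn_block_diagonal(q, k, v, offsets, causal)))
```

The dense version needs memory quadratic in total_len, which is the cost the per-sequence formulation avoids.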

varlen_layernorm

Reference (slow but correct) implementation of varlen layernorm.

Layernorm is per-token: each token is normalized over the last (feature) dim, with no interaction between rows. The only reason it needs a varlen implementation is to operate directly on a packed scree.Array without unpacking; the offsets pass through unchanged. Every backend uses the population (biased) variance, matching torch.nn.functional.layer_norm.

varlen_layernorm

varlen_layernorm(arr: Array, weight: object | None = None, bias: object | None = None, eps: float = 1e-05) -> Array

LayerNorm over the last dim of a packed scree.Array.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| arr | Array | Packed values of shape (total_len, ..., feature_dim). | required |
| weight, bias | optional | Scale and shift parameters of shape (feature_dim,). | None |
| eps | float | Numerical stability epsilon. | 1e-05 |

Source code in src/scree/kernels/reference/varlen_layernorm.py
def varlen_layernorm(
    arr: Array,
    weight: object | None = None,
    bias: object | None = None,
    eps: float = 1e-5,
) -> Array:
    """LayerNorm over the last dim of a packed scree.Array.

    Parameters
    ----------
    arr : scree.Array
        Packed values of shape ``(total_len, ..., feature_dim)``.
    weight, bias : optional
        Scale and shift parameters of shape ``(feature_dim,)``.
    eps : float
        Numerical stability epsilon.
    """
    if _is_mlx(arr.values):
        import mlx.core as mx

        x = arr.values
        mean = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        y = (x - mean) / mx.sqrt(var + eps)
        if weight is not None:
            y = y * weight
        if bias is not None:
            y = y + bias
    elif _is_jax(arr.values):
        import jax.numpy as jnp

        x = arr.values
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        y = (x - mean) / jnp.sqrt(var + eps)
        if weight is not None:
            y = y * weight
        if bias is not None:
            y = y + bias
    elif _is_torch(arr.values):
        import torch

        x = arr.values
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        y = (x - mean) / torch.sqrt(var + eps)
        if weight is not None:
            y = y * weight
        if bias is not None:
            y = y + bias
    else:
        import numpy as np

        x = arr.values
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        y = (x - mean) / np.sqrt(var + eps)
        if weight is not None:
            y = y * weight
        if bias is not None:
            y = y + bias

    return Array(values=y, offsets=arr.offsets, ragged_dim=arr.ragged_dim)
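Since the reduction runs only over the feature axis, normalizing the packed buffer in one call gives the same result as normalizing each sequence separately; the offsets never enter the computation. A short NumPy check of that claim, using the same population-variance formula as the NumPy branch above on random data (layernorm here is a local helper, not scree API):

```python
import numpy as np

def layernorm(x, weight=None, bias=None, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)   # population variance (ddof=0)
    y = (x - mean) / np.sqrt(var + eps)
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y

rng = np.random.default_rng(0)
offsets = [0, 4, 6, 11]
values = rng.standard_normal((11, 8))
weight, bias = rng.standard_normal(8), rng.standard_normal(8)

packed = layernorm(values, weight, bias)
per_seq = np.concatenate(
    [layernorm(values[s:e], weight, bias) for s, e in zip(offsets[:-1], offsets[1:])]
)
print(np.allclose(packed, per_seq))  # True
```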

varlen_rmsnorm

Reference (slow but correct) implementation of varlen RMSNorm.

RMSNorm (Zhang & Sennrich, 2019) drops the mean-subtraction step from LayerNorm and normalizes by the root-mean-square only: y = x / sqrt(mean(x * x) + eps) over the last dim, optionally multiplied by weight. There is no bias term. It is the norm used by LLaMA, Mistral, Mixtral, DeepSeek, Qwen, and many other recent open transformers, and has largely replaced LayerNorm in recent open LLM architectures.

Like LayerNorm, RMSNorm is per-token (no cross-row interaction), so for variable-length data it can be applied to the packed buffer directly, without consulting the offsets.

varlen_rmsnorm

varlen_rmsnorm(arr: Array, weight: object | None = None, eps: float = 1e-06) -> Array

RMSNorm over the last dim of a packed scree.Array.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| arr | Array | Packed values of shape (total_len, ..., feature_dim). | required |
| weight | optional | Scale parameter of shape (feature_dim,). | None |
| eps | float | Numerical stability epsilon (typical: 1e-6 for LLaMA-family). | 1e-06 |

Source code in src/scree/kernels/reference/varlen_rmsnorm.py
def varlen_rmsnorm(
    arr: Array,
    weight: object | None = None,
    eps: float = 1e-6,
) -> Array:
    """RMSNorm over the last dim of a packed scree.Array.

    Parameters
    ----------
    arr : scree.Array
        Packed values of shape ``(total_len, ..., feature_dim)``.
    weight : optional
        Scale parameter of shape ``(feature_dim,)``.
    eps : float
        Numerical stability epsilon (typical: 1e-6 for LLaMA-family).
    """
    if _is_mlx(arr.values):
        import mlx.core as mx

        x = arr.values
        rms = mx.sqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)
        y = x / rms
        if weight is not None:
            y = y * weight
    elif _is_jax(arr.values):
        import jax.numpy as jnp

        x = arr.values
        rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
        y = x / rms
        if weight is not None:
            y = y * weight
    elif _is_torch(arr.values):
        import torch

        x = arr.values
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        y = x / rms
        if weight is not None:
            y = y * weight
    else:
        import numpy as np

        x = arr.values
        rms = np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)
        y = x / rms
        if weight is not None:
            y = y * weight

    return Array(values=y, offsets=arr.offsets, ragged_dim=arr.ragged_dim)
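The "drops the mean-subtraction step" relationship can be checked numerically: on inputs whose features already have zero mean, RMSNorm and LayerNorm (no affine parameters, same eps) coincide, because the variance then equals the mean square. A NumPy sketch with hypothetical local helpers rmsnorm and layernorm:

```python
import numpy as np

def rmsnorm(x, weight=None, eps=1e-6):
    rms = np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)
    y = x / rms
    return y if weight is None else y * weight

def layernorm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((11, 8))
x_centered = x - x.mean(axis=-1, keepdims=True)   # force zero mean per token

print(np.allclose(rmsnorm(x_centered), layernorm(x_centered)))  # True
print(np.allclose(rmsnorm(x), layernorm(x)))                    # False in general
```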

varlen_softmax

Reference (slow but correct) implementation of varlen softmax.

Softmax along the ragged dimension. Unlike layernorm, this is non-trivial for packed data because softmax must be computed within each sequence separately — not across the full concatenated buffer.

varlen_softmax

varlen_softmax(arr: Array) -> Array

Softmax along the ragged dimension, per-sequence.

Each row (sequence) is softmaxed independently. The output has the same shape and offsets as the input.

Source code in src/scree/kernels/reference/varlen_softmax.py
def varlen_softmax(arr: Array) -> Array:
    """Softmax along the ragged dimension, per-sequence.

    Each row (sequence) is softmaxed independently. The output has the
    same shape and offsets as the input.
    """
    if arr.ragged_dim != 0:
        raise NotImplementedError("varlen_softmax supports ragged_dim=0 only in v0.1")

    if _is_mlx(arr.values):
        import mlx.core as mx

        out_rows = []
        for i in range(arr.batch_size):
            s = int(arr.offsets[i])
            e = int(arr.offsets[i + 1])
            out_rows.append(mx.softmax(arr.values[s:e], axis=0))
        values = mx.concatenate(out_rows, axis=0)
    elif _is_jax(arr.values):
        import jax.nn as jnn
        import jax.numpy as jnp

        out_rows = []
        for i in range(arr.batch_size):
            s = int(arr.offsets[i])
            e = int(arr.offsets[i + 1])
            out_rows.append(jnn.softmax(arr.values[s:e], axis=0))
        values = jnp.concatenate(out_rows, axis=0)
    elif _is_torch(arr.values):
        import torch

        out_rows = []
        for i in range(arr.batch_size):
            s = int(arr.offsets[i])
            e = int(arr.offsets[i + 1])
            out_rows.append(torch.softmax(arr.values[s:e], dim=0))
        values = torch.cat(out_rows, dim=0)
    else:
        import numpy as np

        out_rows = []
        for i in range(arr.batch_size):
            s = int(arr.offsets[i])
            e = int(arr.offsets[i + 1])
            row = arr.values[s:e]
            row_max = row.max(axis=0, keepdims=True)
            row_exp = np.exp(row - row_max)
            out_rows.append(row_exp / row_exp.sum(axis=0, keepdims=True))
        values = np.concatenate(out_rows, axis=0)

    return Array(values=values, offsets=arr.offsets, ragged_dim=0)
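The per-sequence loop can also be written without a Python loop using segment reductions. This NumPy sketch is a swapped-in technique, not scree's implementation: it uses np.maximum.reduceat and np.add.reduceat over offsets[:-1], and it assumes every sequence is non-empty, because reduceat does not produce empty segments. segment_softmax and loop_softmax are hypothetical names.

```python
import numpy as np

def segment_softmax(values, offsets):
    """Softmax along axis 0 within each [offsets[i], offsets[i+1]) segment."""
    offsets = np.asarray(offsets)
    lengths = np.diff(offsets)
    if (lengths <= 0).any():
        raise ValueError("this sketch requires non-empty sequences")
    starts = offsets[:-1]
    seg_max = np.maximum.reduceat(values, starts, axis=0)   # (B, ...)
    shifted = np.exp(values - np.repeat(seg_max, lengths, axis=0))
    seg_sum = np.add.reduceat(shifted, starts, axis=0)      # (B, ...)
    return shifted / np.repeat(seg_sum, lengths, axis=0)

def loop_softmax(values, offsets):
    """Same per-sequence loop as the NumPy branch above."""
    out = []
    for s, e in zip(offsets[:-1], offsets[1:]):
        row = values[s:e]
        ex = np.exp(row - row.max(axis=0, keepdims=True))
        out.append(ex / ex.sum(axis=0, keepdims=True))
    return np.concatenate(out, axis=0)

rng = np.random.default_rng(0)
offsets = [0, 4, 6, 11]
values = rng.standard_normal((11, 3))
out = segment_softmax(values, offsets)
print(np.allclose(out, loop_softmax(values, offsets)))  # True
```

Softmaxing the whole buffer at once would instead normalize across sequence boundaries, which is exactly what the per-sequence definition rules out.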

Triton kernels — scree.kernels.triton

CUDA-only. Importing scree.kernels.triton is safe on non-CUDA platforms (TRITON_AVAILABLE is False and no kernel symbols are exported), but calling the kernels without CUDA raises an informative error.

varlen_attention_triton

The forward kernel — 1.30× of FA-2 on H100 for the headline workload.

varlen_attention_triton_autograd

Autograd-aware wrapper. Forward + backward both run on Triton kernels (FA-2 style: preprocess + dKV + dQ). Use this when you need gradients to flow through q, k, v. Full training step at 1.61× of FA-2.

varlen_rmsnorm_triton

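13.97× speedup vs the PyTorch reference implementation on H100. PyTorch 2.4 and later also ship a native torch.nn.RMSNorm; this figure is measured against the reference implementation, not that module.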

varlen_layernorm_triton

1.31× speedup vs torch.nn.functional.layer_norm on H100.


What's NOT in the public API

Names with a leading underscore in any module are private and subject to change without notice. In particular:

  • scree._core._is_torch, _is_mlx, _is_jax — backend dispatch predicates
  • scree.kernels.triton._varlen_attn_fwd_kernel — the raw Triton kernel
  • scree.kernels.triton._varlen_attn_bwd_*_kernel — raw backward kernels
  • scree.kernels.triton._backward.varlen_attention_triton_backward — the host-side backward orchestrator (used by the autograd wrapper)