[NIXL] Add CacheLayout meta-tensor abstraction for descriptor generation #44362

Draft
ZhanqiuHu wants to merge 13 commits into vllm-project:main from ZhanqiuHu:zhanqiu/cache-layout-on-tp-mapping

@ZhanqiuHu ZhanqiuHu commented Jun 2, 2026

Summary

  • Introduce CacheLayout, a meta-tensor-based abstraction that encodes KV cache memory layout (shape, strides, shard_axis) and generates NIXL descriptors without manual byte arithmetic.
  • Wire CacheLayout into _build_fa_local and _build_fa_remote, replacing manual if is_blocks_first / K-V split / TP head slicing logic with descriptors(), sub_block(), and narrow().
  • Depends on [NIXL][2/N] Cache TP slicing and mapping redesign #43151 (TPTransferSlice / get_tp_transfer_slices).
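
As a rough illustration of what "without manual byte arithmetic" means, the sketch below derives per-block descriptors from nothing but a shape and a stride tuple. It is plain Python with no torch or NIXL dependency; the function name `block_descriptors` and the example sizes are assumptions made for illustration, not code from this PR. The tuple order `(addr, size_bytes, device_id)` follows the `descriptors()` docstring in the v2 comment.

```python
from math import prod


def block_descriptors(shape, strides, elem_size, base_addr, device_id):
    """Return one (addr, size_bytes, device_id) tuple per index of dim 0.

    shape and strides are given in elements. Dims 1.. are assumed to be
    C-contiguous, so each block is a single flat byte span.
    """
    payload_bytes = prod(shape[1:]) * elem_size
    return [
        (base_addr + i * strides[0] * elem_size, payload_bytes, device_id)
        for i in range(shape[0])
    ]


# Hypothetical region: 4 blocks, 2 KV heads, block_size 16, head_size 8,
# 2-byte elements, row-major strides.
shape = (4, 2, 16, 8)
strides = (256, 128, 8, 1)
descs = block_descriptors(shape, strides, elem_size=2,
                          base_addr=0x1000, device_id=0)
```

Changing only `strides[0]` (for example to model a padded page stride) moves the block addresses without touching the payload size, which is the property the layout abstraction relies on.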

Key changes

layout.py (new file):

  • CacheLayout dataclass with meta (PyTorch meta tensor) and shard_axis
  • from_tensor() — local path, captures actual tensor's shape/strides
  • from_physical() — remote path, builds from raw shape/strides
  • narrow(), select(), sub_block(), descriptors() — TP slicing, block splitting, and descriptor generation
  • build_attn_layout() — factory for attention region layouts
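
The block-splitting step can be sketched with plain tuples. The example below mirrors the shape/stride update that `sub_block()` performs in the v2 listing (more blocks, smaller block_size, tighter block stride), under the assumption that the block_size dim is the outermost payload dim so each sub-block stays contiguous. The helper name `sub_block_view` and the sizes are hypothetical.

```python
def sub_block_view(shape, strides, blk, bsz, ratio):
    """View each block as `ratio` sub-blocks of block_size / ratio.

    blk is the index of the blocks dim, bsz the index of the block_size
    dim; strides are in elements.
    """
    if shape[bsz] % ratio or strides[blk] % ratio:
        raise ValueError("ratio must divide block_size and block stride")
    new_shape, new_strides = list(shape), list(strides)
    new_shape[blk] *= ratio       # more blocks
    new_shape[bsz] //= ratio      # smaller block_size
    new_strides[blk] //= ratio    # tighter block stride
    return tuple(new_shape), tuple(new_strides)


# Hypothetical (blocks, block_size, channels) region, row-major.
new_shape, new_strides = sub_block_view(
    (4, 16, 8), (128, 8, 1), blk=0, bsz=1, ratio=2)

# Element offsets at which each sub-block starts.
starts = [i * new_strides[0] for i in range(new_shape[0])]
```

The eight sub-blocks of 64 elements cover exactly the same 512 elements as the original four blocks of 128.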

worker.py (modified):

  • _region_layouts: list[CacheLayout] built in register_kv_caches (one per NIXL region, shard_axis=1 for Mamba, shard_axis=2 for blocks-first attention, shard_axis=1 for non-blocks-first)
  • _build_fa_local: iterates _region_layouts, calls sub_block() + descriptors()
  • _build_fa_remote: uses build_attn_layout() + TPTransferSlice.remote_read_offset via narrow() for TP head slicing, hoisted outside loop since block_lens are uniform for non-MLA
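
For readers unfamiliar with TP head slicing, here is a minimal sketch of the byte arithmetic that narrowing the head axis replaces. It assumes a row-major `(blocks, heads, block_size, head_size)` region; the function name and the `head_start` / `head_count` parameters are hypothetical stand-ins and are not meant to describe the actual fields of `TPTransferSlice`.

```python
def narrowed_head_descriptors(num_blocks, num_heads, block_size, head_size,
                              elem_size, head_start, head_count,
                              base_addr, device_id):
    """Descriptors for heads [head_start, head_start + head_count) per block.

    Assumes row-major (B, H, N, C), so a run of consecutive heads inside
    one block is a single contiguous byte span.
    """
    stride_h = block_size * head_size      # elements per head
    stride_b = num_heads * stride_h        # elements per block
    size_bytes = head_count * stride_h * elem_size
    head_off = head_start * stride_h * elem_size
    return [
        (base_addr + b * stride_b * elem_size + head_off, size_bytes, device_id)
        for b in range(num_blocks)
    ]


# Hypothetical case: the remote holds 8 KV heads, the local rank reads
# heads 4..7, with 2 blocks and 2-byte elements.
descs = narrowed_head_descriptors(
    num_blocks=2, num_heads=8, block_size=16, head_size=8, elem_size=2,
    head_start=4, head_count=4, base_addr=0, device_id=0)
```

With a layout object, the same result comes from narrowing the head dim and then generating descriptors, so the offset and size never have to be written out by hand.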

Test plan

  • Unit tests: 59/59 passed (test_nixl_connector.py)
  • lm_eval 4P1D Nemotron (mamba hybrid, hetero TP): accuracy 0.8461 (expected 0.84), KV hit rate 99.90%, 0 transfer errors — ALL CHECKS PASSED
  • Sweep: dense (Qwen3-0.6B), mamba (Nemotron), SWA (Gemma-3-4b-it), MLA (deepseek-vl2-tiny) across TP configs — running

ZhanqiuHu added 11 commits June 2, 2026 16:38
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Introduce CacheLayout, a meta-tensor-based abstraction that encodes
KV cache memory layout (shape, strides, shard_axis) and generates
NIXL descriptors without manual byte arithmetic.

- layout.py: CacheLayout with narrow(), sub_block(), descriptors(),
  from_tensor() (local path), from_physical() (remote path), and
  build_attn_layout() factory for attention regions.
- worker.py: Wire CacheLayout into _build_fa_local (uses stored
  _region_layouts) and _build_fa_remote (uses build_attn_layout +
  TPTransferSlice.narrow() for TP head slicing). Mamba regions get
  CacheLayout with shard_axis=1; remote Mamba path unchanged.

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
ZhanqiuHu added 2 commits June 3, 2026 15:00
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Replace the single `shard_axis` int and `virtually_split_kv` boolean
with two explicit tuple annotations:

- `iter_axes`: ordered dim indices that descriptors() iterates over
  (outer→inner), e.g. K/V split and blocks
- `shard_axes`: dim indices sliced for TP sharding

This makes the K/V split a proper axis in the layout shape rather
than a boolean flag, inspired by TVM's named axis types.

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
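
A small sketch of what "ordered, outer→inner" iteration means for descriptor order, using plain Python and integer axis indices as in this commit. The function name `iter_descriptors` and the sizes are assumptions for illustration only.

```python
from itertools import product
from math import prod


def iter_descriptors(shape, strides, iter_axes, elem_size, base_addr=0):
    """Emit (addr, size_bytes) per combination of iter-axis indices.

    iter_axes lists dim indices outer→inner; every other dim is payload
    and is assumed contiguous. Strides are in elements.
    """
    payload = prod(s for d, s in enumerate(shape)
                   if d not in iter_axes) * elem_size
    out = []
    for idx in product(*(range(shape[d]) for d in iter_axes)):
        off = sum(i * strides[d] for i, d in zip(idx, iter_axes))
        out.append((base_addr + off * elem_size, payload))
    return out


# Hypothetical (B=2, KV=2, H=1, N=4, C=2) region, row-major, 2-byte elements.
shape = (2, 2, 1, 4, 2)
strides = (16, 8, 8, 2, 1)
# K/V split outer (dim 1), blocks inner (dim 0).
descs = iter_descriptors(shape, strides, iter_axes=(1, 0), elem_size=2)
addrs = [addr for addr, _ in descs]
```

With K/V as the outer axis, all K descriptors (one per block) come first, followed by all V descriptors, even though K and V are interleaved per block in memory.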
@ZhanqiuHu

CacheLayout v2 — named-axis design

Proposed next iteration: replace integer iter_axes/shard_axes with named dimensions using RFC #42082 vocabulary (B, H, N, C, KV, L). This eliminates index-shifting bookkeeping and makes all operations self-documenting (e.g. narrow('H', offset, heads) instead of narrow(layout.shard_axes[0], ...)).

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CacheLayout v2: named-axis layout for NIXL descriptor generation.

Dimensions use RFC #42082 vocabulary — B (blocks), H (heads),
N (block_size), C (channels), KV (K/V split), L (layers).

- **iter_axes**: ordered names ``descriptors()`` loops over (outer→inner).
- **shard_axes**: names sliced for TP sharding.
- Remaining dims are payload (contiguous byte span per descriptor).
"""

from __future__ import annotations

from dataclasses import dataclass
from math import prod

import torch


@dataclass(frozen=True)
class CacheLayout:
    meta: torch.Tensor            # shape/strides on meta device (zero memory)
    dim_names: tuple[str, ...]    # name per dim, e.g. ('B','H','N','C')
    iter_axes: tuple[str, ...]    # dims to iterate, outer→inner
    shard_axes: frozenset[str]    # dims sliced for TP

    def __post_init__(self) -> None:
        names = set(self.dim_names)
        assert len(self.dim_names) == self.meta.ndim
        assert len(names) == len(self.dim_names)  # no duplicates
        assert set(self.iter_axes) <= names and self.shard_axes <= names
        assert not (set(self.iter_axes) & self.shard_axes)  # no overlap
        # Non-iter dims must be C-contiguous for flat NIXL descriptors.
        iter_set = set(self.iter_axes)
        expected = 1  # stride accumulator, innermost-out
        for d in reversed(
            [i for i, n in enumerate(self.dim_names) if n not in iter_set]
        ):
            if self.meta.stride(d) != expected:
                raise ValueError(
                    f"dim '{self.dim_names[d]}' stride "
                    f"{self.meta.stride(d)} != {expected}")
            expected *= self.meta.shape[d]

    def _axis(self, name: str) -> int:
        """Resolve dim name → integer index."""
        return self.dim_names.index(name)

    @property
    def descriptor_size_bytes(self) -> int:
        """Product of non-iter dim sizes × element_size."""
        iter_set = set(self.iter_axes)
        return (prod(s for s, n in zip(self.meta.shape, self.dim_names)
                      if n not in iter_set) * self.meta.element_size())

    def narrow(self, name: str, start: int, length: int) -> CacheLayout:
        """Slice a dim (e.g. narrow('H', offset, local_heads) for TP)."""
        return CacheLayout(
            meta=self.meta.narrow(self._axis(name), start, length),
            dim_names=self.dim_names,
            iter_axes=self.iter_axes,
            shard_axes=self.shard_axes,
        )

    def select(self, name: str, index: int) -> CacheLayout:
        """Pick one index along an iter dim, collapsing it."""
        assert name in self.iter_axes
        ax = self._axis(name)
        return CacheLayout(
            meta=self.meta.select(ax, index),       # removes dim ax
            dim_names=self.dim_names[:ax] + self.dim_names[ax + 1:],
            iter_axes=tuple(n for n in self.iter_axes if n != name),
            shard_axes=self.shard_axes - {name},
        )

    def sub_block(self, ratio: int) -> CacheLayout:
        """View N blocks as N*ratio sub-blocks of 1/ratio block_size."""
        if ratio == 1:
            return self
        tagged = set(self.iter_axes) | self.shard_axes
        # First payload dim = block_size; innermost iter = blocks count
        bsz = self._axis(next(n for n in self.dim_names if n not in tagged))
        blk = self._axis(self.iter_axes[-1])  # blocks dim
        assert self.meta.shape[bsz] % ratio == 0
        assert self.meta.stride(blk) % ratio == 0
        shape = list(self.meta.shape)
        strides = list(self.meta.stride())
        shape[blk] *= ratio       # more blocks
        shape[bsz] //= ratio      # smaller block_size
        strides[blk] //= ratio    # tighter page stride
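        return CacheLayout.from_physical(
            shape=tuple(shape),
            strides=tuple(strides),
            dtype=self.meta.dtype,
            dim_names=self.dim_names,
            iter_axes=self.iter_axes,
            shard_axes=self.shard_axes,
            # Carry over any offset from a prior narrow(); the default of 0
            # would silently drop it.
            offset_bytes=(int(self.meta.storage_offset())
                          * self.meta.element_size()),
        )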

    def descriptors(
        self, base_addr: int, device_id: int,
    ) -> list[tuple[int, int, int]]:
        """Generate (addr, size_bytes, device_id) per block.

        Recurses over iter_axes outer→inner; base case emits one descriptor.
        """
        if not self.iter_axes:
            # Leaf: all iter dims consumed, emit one flat descriptor
            off = int(self.meta.storage_offset()) * self.meta.element_size()
            return [(base_addr + off, self.descriptor_size_bytes, device_id)]
        name = self.iter_axes[0]  # outermost remaining iter dim
        out: list[tuple[int, int, int]] = []
        for i in range(self.meta.shape[self._axis(name)]):
            out.extend(
                self.select(name, i).descriptors(base_addr, device_id))
        return out

    @classmethod
    def from_tensor(
        cls,
        tensor: torch.Tensor,
        dim_names: tuple[str, ...],
        iter_axes: tuple[str, ...],
        shard_axes: frozenset[str],
    ) -> CacheLayout:
        """Build from a real GPU tensor (local path). Zero-copy metadata."""
        meta = torch.empty(
            0, dtype=tensor.dtype, device="meta",
        ).as_strided(tensor.shape, tensor.stride(), tensor.storage_offset())
        return cls(
            meta=meta,
            dim_names=dim_names,
            iter_axes=iter_axes,
            shard_axes=shard_axes,
        )

    @classmethod
    def from_physical(
        cls,
        shape: tuple[int, ...],
        strides: tuple[int, ...],
        dtype: torch.dtype,
        dim_names: tuple[str, ...],
        iter_axes: tuple[str, ...],
        shard_axes: frozenset[str],
        offset_bytes: int = 0,
    ) -> CacheLayout:
        """Build from raw shape/strides (remote path, no real tensor)."""
        elem = torch.empty(0, dtype=dtype).element_size()
        meta = torch.as_strided(
            torch.empty(1, dtype=dtype, device="meta"),
            size=shape, stride=strides,
            storage_offset=offset_bytes // elem,  # bytes → elements
        )
        return cls(
            meta=meta,
            dim_names=dim_names,
            iter_axes=iter_axes,
            shard_axes=shard_axes,
        )


def _c_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Row-major strides in elements."""
    s, out = 1, []
    for d in reversed(shape):
        out.append(s)
        s *= d
    return tuple(reversed(out))


def build_attn_layout(
    num_blocks: int,
    num_kv_heads: int,
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
    split_kv: bool = False,
    page_stride_bytes: int | None = None,
) -> CacheLayout:
    """Build CacheLayout for one attention layer region.

    split_kv=True  → (B, KV, H, N, C)  iter=('KV','B')  shard={'H'}
    split_kv=False → (B, H, N, C)       iter=('B',)      shard={'H'}
    """
    if split_kv:
        inner = (2, num_kv_heads, block_size, head_size)  # KV, H, N, C
        names: tuple[str, ...] = ('B', 'KV', 'H', 'N', 'C')
        iters: tuple[str, ...] = ('KV', 'B')  # K/V outer, blocks inner
    else:
        inner = (num_kv_heads, block_size, head_size)  # H, N, C
        names = ('B', 'H', 'N', 'C')
        iters = ('B',)  # blocks only

    shape = (num_blocks, *inner)
    inner_strides = _c_strides(inner)

    if page_stride_bytes is not None:
        elem = torch.empty(0, dtype=dtype).element_size()
        strides = (page_stride_bytes // elem, *inner_strides)  # custom page stride
    else:
        strides = _c_strides(shape)  # fully contiguous

    return CacheLayout.from_physical(
        shape=shape,
        strides=strides,
        dtype=dtype,
        dim_names=names,
        iter_axes=iters,
        shard_axes=frozenset({'H'}),
    )

worker.py changes to wire this in: just 2 lines. Swap the import to layout_v2 and change remote_layout.narrow(remote_layout.shard_axes[0], ...) to remote_layout.narrow('H', ...).
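
The index-shifting problem that named axes avoid can be seen with plain tuples. In the sketch below, collapsing the `KV` dim moves the integer position of `H`, while the name `'H'` keeps resolving to the right dim. The helper `select_dim` and the sizes are hypothetical and only mimic the `dim_names` handling in `select()`.

```python
def select_dim(dim_names, shape, name):
    """Drop the dim called `name`, as select() does for dim_names."""
    ax = dim_names.index(name)
    return dim_names[:ax] + dim_names[ax + 1:], shape[:ax] + shape[ax + 1:]


names = ('B', 'KV', 'H', 'N', 'C')
shape = (4, 2, 8, 16, 64)

h_before = names.index('H')              # integer position before select
names2, shape2 = select_dim(names, shape, 'KV')
h_after = names2.index('H')              # position shifts after select

# The integer index changed, but the name still finds the heads dim.
same_dim = shape[h_before] == shape2[h_after]
```

Code that stored the integer index of the heads dim would need to adjust it after every `select()`; code that stores `'H'` does not.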

mergify Bot commented Jun 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
