[NIXL][2/N] Cache TP slicing and mapping redesign#43151

Open
ZhanqiuHu wants to merge 10 commits into
vllm-project:mainfrom
ZhanqiuHu:tp-mapping-redesign

@ZhanqiuHu ZhanqiuHu commented May 19, 2026

Contributor

RFC: #42082

@mergify mergify Bot added deepseek Related to DeepSeek models v1 kv-connector labels May 19, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the tp-mapping-redesign branch from 1ea360e to f16f573 Compare May 19, 2026 21:49

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request refactors the Tensor Parallel (TP) mapping logic for NIXL KV cache transfers by introducing a more robust TPTransferSlice structure and moving the mapping computation into the KVCacheSpec classes. It also updates various attention layers to provide total KV head counts and refactors the worker and test suites to accommodate these changes. Feedback highlights critical issues in the integration test script regarding hardcoded GPU arrays, unused variables in the worker, and several algorithmic flaws in the TP slicing logic, including non-robust partitioning and potential bugs in GQA replication.

I am having trouble creating individual review comments. Click here to see my feedback.

tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh (83-84)

high

The hardcoded PREFILL_GPUS and DECODE_GPUS arrays are too small to support the TP sizes mentioned in the PR description (e.g., TP=4). For instance, if PREFILLER_TP_SIZE=4, the script will attempt to access PREFILL_GPUS[3], which is out of bounds for (1 2 4). This will cause the test script to crash or produce invalid CUDA_VISIBLE_DEVICES strings.

PREFILL_GPUS=(0 1 2 3)
DECODE_GPUS=(4 5 6 7)

vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py (1181)

high

The variable remote_heads_per_rank is calculated using a non-robust division (//) and is never used in the logic (only in a log message). It should be removed or fixed to use robust partitioning ((i+1)*total//size - i*total//size) if it's intended for future use.
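To make the difference between the two partitioning rules concrete, here is a minimal sketch that computes per-rank head counts both ways. The helper names `partition_sizes` and `naive_sizes` are hypothetical and exist only for illustration; they are not part of the PR.

```python
def partition_sizes(total: int, size: int) -> list[int]:
    """Heads owned by each rank when rank i starts at i * total // size."""
    return [(i + 1) * total // size - i * total // size for i in range(size)]


def naive_sizes(total: int, size: int) -> list[int]:
    """Heads per rank under plain floor division (the remainder is dropped)."""
    return [total // size] * size


robust = partition_sizes(10, 4)  # boundaries 0, 2, 5, 7, 10
naive = naive_sizes(10, 4)       # 10 // 4 for every rank

print(robust, sum(robust))
print(naive, sum(naive))
```

With 10 heads over 4 ranks, the offset-difference rule accounts for every head, while floor division leaves two unassigned. When `total < size`, the same rule gives some ranks zero heads, which is the situation the GQA discussion in this review is concerned with.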

vllm/v1/kv_cache_interface.py (256-311)

high

The TP slicing logic has several issues:

  1. Non-robust partitioning: Using total // size fails when total is not divisible by size. It should use the difference between offsets: (i+1)*total//size - i*total//size.
  2. GQA Replication Bug: In the P_TP > D_TP branch, if total < remote_tp_size, some remote ranks will have 0 heads. The current logic might pick a rank with 0 heads and skip subsequent ranks that actually contain the head data because they share the same head index (calculated as 0). It should skip ranks where remote_start == remote_end.
  3. Unused Variables: remote_heads (line 260) and local_heads (line 285) are calculated but not used in their respective branches.
  4. Non-multiple TP sizes: The abs_tp logic assumes remote_tp_size is a multiple of local_tp_size. Using robust range boundaries start = local_tp_rank * remote_tp_size // local_tp_size is safer.
        if local_tp_size >= remote_tp_size:
            # D_TP >= P_TP: each local rank reads from one remote rank.
            # Multiple local ranks may map to the same remote when D_TP > P_TP.
            remote_rank = local_tp_rank * remote_tp_size // local_tp_size
            # Compute which sub-range of the remote's heads this local rank owns.
            # When D_TP > P_TP and total >= local_tp, different local ranks
            # index into different head offsets within the same remote block.
            local_start = local_tp_rank * total // local_tp_size
            local_end = (local_tp_rank + 1) * total // local_tp_size
            local_heads = max(1, local_end - local_start)

            remote_head_start = remote_rank * total // remote_tp_size
            offset = local_start - remote_head_start
            return [
                TPTransferSlice(
                    remote_rank=remote_rank,
                    global_range=range(local_start, local_start + local_heads),
                    local_range=range(0, local_heads),
                    remote_range=range(offset, offset + local_heads),
                )
            ]
        else:
            # P_TP > D_TP: one local rank reads from multiple remote ranks.
            # GQA dedup: when total_kv_heads < remote_tp, several remote ranks
            # hold the same head. Only read from unique ones.
            start = local_tp_rank * remote_tp_size // local_tp_size
            end = (local_tp_rank + 1) * remote_tp_size // local_tp_size

            local_start = local_tp_rank * total // local_tp_size

            slices: list[TPTransferSlice] = []
            seen_heads: set[int] = set()
            for r in range(start, end):
                remote_start = r * total // remote_tp_size
                remote_end = (r + 1) * total // remote_tp_size
                remote_heads = remote_end - remote_start

                if remote_heads == 0 or remote_start in seen_heads:
                    continue
                seen_heads.add(remote_start)

                slices.append(
                    TPTransferSlice(
                        remote_rank=r,
                        global_range=range(remote_start, remote_end),
                        local_range=range(
                            remote_start - local_start,
                            remote_end - local_start,
                        ),
                        remote_range=range(0, remote_heads),
                    )
                )

            return slices
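The suggested branch logic above can be exercised outside vLLM. The sketch below is a condensed standalone transcription, assuming `total` is passed as a plain argument rather than read from the spec. `tp_transfer_slices` and `covered_heads` are hypothetical names used only here. It checks the arithmetic of the suggestion and nothing more; whether that arithmetic matches how vLLM actually places replicated KV heads on each rank is not established by this example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TPTransferSlice:
    remote_rank: int
    global_range: range  # position in the model's full head space
    local_range: range   # position in the local worker's tensor
    remote_range: range  # position in the remote worker's tensor


def tp_transfer_slices(total, local_tp_rank, local_tp_size, remote_tp_size):
    """Standalone transcription of the suggested method (total passed in)."""
    local_start = local_tp_rank * total // local_tp_size
    if local_tp_size >= remote_tp_size:
        # One remote rank per local rank; read a sub-range of its heads.
        local_end = (local_tp_rank + 1) * total // local_tp_size
        local_heads = max(1, local_end - local_start)
        remote_rank = local_tp_rank * remote_tp_size // local_tp_size
        offset = local_start - remote_rank * total // remote_tp_size
        return [TPTransferSlice(remote_rank,
                                range(local_start, local_start + local_heads),
                                range(0, local_heads),
                                range(offset, offset + local_heads))]
    # Several remote ranks per local rank; skip empty or repeated heads.
    start = local_tp_rank * remote_tp_size // local_tp_size
    end = (local_tp_rank + 1) * remote_tp_size // local_tp_size
    slices, seen = [], set()
    for r in range(start, end):
        remote_start = r * total // remote_tp_size
        remote_end = (r + 1) * total // remote_tp_size
        if remote_end == remote_start or remote_start in seen:
            continue
        seen.add(remote_start)
        slices.append(TPTransferSlice(r,
                                      range(remote_start, remote_end),
                                      range(remote_start - local_start,
                                            remote_end - local_start),
                                      range(0, remote_end - remote_start)))
    return slices


def covered_heads(total, local_tp_size, remote_tp_size):
    """Global head indices read by all local ranks, in rank order."""
    heads = []
    for rank in range(local_tp_size):
        for s in tp_transfer_slices(total, rank, local_tp_size, remote_tp_size):
            heads.extend(s.global_range)
    return heads


print(covered_heads(8, 2, 4))  # P_TP > D_TP
print(covered_heads(8, 4, 2))  # D_TP > P_TP
print(covered_heads(2, 2, 4))  # fewer heads than remote ranks
```

A useful property to check for any such mapping is that the union of `global_range` over all local ranks covers each head exactly once, which is what `covered_heads` makes visible.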

Comment thread vllm/v1/kv_cache_interface.py Outdated
Comment on lines +232 to +311
    def get_tp_transfer_slices(
        self,
        local_tp_rank: int,
        local_tp_size: int,
        remote_tp_size: int,
    ) -> list[TPTransferSlice]:
        """Compute transfer slices for this local rank.

        Each slice describes one read: 'read remote_rank's remote_range
        and place it in my local_range'.

        Handles:
        - Homogeneous TP (local_tp == remote_tp): 1:1 mapping
        - D_TP > P_TP: local rank reads a sub-range from one remote
        - P_TP > D_TP: local rank reads from multiple remotes
        - GQA replication (total_kv_heads < remote_tp): load-balanced
          remote selection to avoid redundant reads
        """
        assert self.total_num_kv_heads is not None, (
            "total_num_kv_heads must be set for TP mapping. "
            "Pass it when constructing the spec via get_kv_cache_spec()."
        )
        total = self.total_num_kv_heads

        if local_tp_size >= remote_tp_size:
            # D_TP >= P_TP: each local rank reads from one remote rank.
            # Multiple local ranks may map to the same remote when D_TP > P_TP.
            remote_rank = local_tp_rank * remote_tp_size // local_tp_size
            remote_heads = max(1, total // remote_tp_size)
            # Compute which sub-range of the remote's heads this local rank owns.
            # When D_TP > P_TP and total >= local_tp, different local ranks
            # index into different head offsets within the same remote block.
            local_head = local_tp_rank * total // local_tp_size
            remote_head_start = remote_rank * total // remote_tp_size
            offset = local_head - remote_head_start
            local_heads = max(1, total // local_tp_size)
            return [
                TPTransferSlice(
                    remote_rank=remote_rank,
                    global_range=range(local_head, local_head + local_heads),
                    local_range=range(0, local_heads),
                    remote_range=range(offset, offset + local_heads),
                )
            ]
        else:
            # P_TP > D_TP: one local rank reads from multiple remote ranks.
            # GQA dedup: when total_kv_heads < remote_tp, several remote ranks
            # hold the same head. Only read from unique ones.
            abs_tp = remote_tp_size // local_tp_size
            start = local_tp_rank * abs_tp

            local_start = local_tp_rank * total // local_tp_size
            local_end = (local_tp_rank + 1) * total // local_tp_size
            local_heads = local_end - local_start

            slices: list[TPTransferSlice] = []
            seen_heads: set[int] = set()
            for r in range(start, start + abs_tp):
                head = r * total // remote_tp_size
                if head in seen_heads:
                    continue
                seen_heads.add(head)

                remote_start = r * total // remote_tp_size
                remote_end = (r + 1) * total // remote_tp_size
                remote_heads = remote_end - remote_start

                slices.append(
                    TPTransferSlice(
                        remote_rank=r,
                        global_range=range(remote_start, remote_end),
                        local_range=range(
                            remote_start - local_start,
                            remote_end - local_start,
                        ),
                        remote_range=range(0, remote_heads),
                    )
                )

            return slices

Collaborator

Thanks for getting started on this!!

with something like

def slice_for_tp_transfer(self, src_tensor, src_tp, src_rank,
                          dst_spec, dst_tp, dst_rank):

    src_heads = src_tp*self.num_heads
    dst_heads = dst_tp*dst_spec.num_heads

    if dst_heads > src_heads:
        dst_rank %= src_tp
    elif src_rank >= dst_tp:
        return []

    head_range = lambda tp_rank, num_heads: (tp_rank * num_heads, (tp_rank + 1) * num_heads)

    src_head_range = head_range(src_rank, self.num_heads)
    dst_head_range = head_range(dst_rank, dst_spec.num_heads)

    overlap = get_overlap(src_head_range, dst_head_range)
    if overlap is None:
        return []

    # Convert global overlap to local tensor indices
    local_slice = slice(overlap.start - src_head_range[0],
                        overlap.stop - src_head_range[0])
    return [src_tensor[:, :, local_slice]]

def get_overlap(range1, range2):
    # Calculate the intersection boundaries of two half-open (start, end) ranges
    overlap_start = max(range1[0], range2[0])
    overlap_end = min(range1[1], range2[1])

    # A non-empty overlap needs start < end; adjacent ranges share no heads
    if overlap_start < overlap_end:
        return slice(overlap_start, overlap_end)
    return None

I think we can avoid the need to add total_num_kv_heads to the spec? This should handle replicated KV heads on the remote and/or local side just fine, I think (unless I'm missing something).

I could understand if this is too confusing, though.
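The overlap idea in the sketch above can be tried without tensors by treating each rank's heads as a half-open interval. The following is a minimal pure-Python illustration; `head_range`, `overlap`, and `local_slice` are hypothetical names, and it assumes every rank holds a contiguous block of `num_heads` heads with no replication.

```python
from typing import Optional


def head_range(tp_rank: int, num_heads: int) -> range:
    """Global heads held by a rank that owns `num_heads` contiguous heads."""
    return range(tp_rank * num_heads, (tp_rank + 1) * num_heads)


def overlap(a: range, b: range) -> Optional[range]:
    """Intersection of two contiguous step-1 ranges, or None if empty."""
    start, stop = max(a.start, b.start), min(a.stop, b.stop)
    return range(start, stop) if start < stop else None


def local_slice(src: range, dst: range) -> Optional[slice]:
    """Overlap expressed as indices into the source rank's own heads."""
    common = overlap(src, dst)
    if common is None:
        return None
    return slice(common.start - src.start, common.stop - src.start)


# 8 heads total: source TP=2 (4 heads per rank), destination TP=4 (2 per rank).
src = head_range(1, 4)  # source rank 1 holds heads 4..7
dst = head_range(3, 2)  # destination rank 3 wants heads 6..7
print(local_slice(src, dst))

# Adjacent intervals share no heads, so there is nothing to transfer.
print(overlap(head_range(0, 4), head_range(2, 2)))
```

Using `start < stop` rather than `start <= stop` matters here: with half-open intervals, two ranges that merely touch have an empty intersection.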

@ZhanqiuHu ZhanqiuHu May 20, 2026

Contributor Author

I think using total_num_kv_heads to check replication might be cleaner, but we can pass it in as an argument (from the model config) instead of storing it in the KV cache config.

I think it might be a little indirect to use src_heads = src_tp*self.num_heads and dst_heads = dst_tp*dst_spec.num_heads to recover whether one side is more replicated than the other?

Collaborator

Sounds good; I do like passing it into get_tp_transfer_slices instead of passing it via the spec. This is much cleaner, thank you! I think we should just pass in model_config though, to make it less "standard attention"-centric, since this will also be used for mamba.

Comment thread vllm/v1/kv_cache_interface.py Outdated
Comment on lines +95 to +97
global_range: range # position in the model's full head space
local_range: range # position in the local worker's tensor
remote_range: range # position in the remote worker's tensor

Collaborator

should we use slice?
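For context on the `range` versus `slice` question, here is a small sketch of how a `range` field can still be used for indexing. `HeadSpan` is a hypothetical stand-in for `TPTransferSlice`, and `as_slice` is an illustrative helper, not something from the PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HeadSpan:  # hypothetical stand-in for TPTransferSlice
    remote_rank: int
    local_range: range


def as_slice(r: range) -> slice:
    """Convert a range into the equivalent slice for indexing."""
    return slice(r.start, r.stop, r.step)


span = HeadSpan(remote_rank=1, local_range=range(2, 4))
heads = ["h0", "h1", "h2", "h3"]

picked = heads[as_slice(span.local_range)]  # index with the converted slice
count = len(span.local_range)               # range supports len()
key = hash(span)                            # range fields keep the dataclass hashable
print(picked, count)
```

One trade-off to weigh: `range` supports `len()`, iteration, and membership tests, while `slice` supports none of these directly. To my knowledge `slice` objects only became hashable in Python 3.12, which would matter if a frozen dataclass holding them is used as a dict key or set member on older versions.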



@dataclass(frozen=True)
class TPTransferSlice:

Collaborator

This makes total sense for now, but I think we can ultimately have get_tp_transfer_slices slice the tensor directly once #42374 lands?

Member

@ZhanqiuHu was telling me she had some issues implementing it with tensors.
I would give it another go mapping TPTransferSlice to tensor code to see if there are cases which are not covered by simple tensor slicing (or if code gets somehow messier).
I'll also take a look.

@ZhanqiuHu ZhanqiuHu marked this pull request as ready for review May 28, 2026 12:45
Comment thread vllm/v1/kv_cache_interface.py Outdated
# TP transfer slice interface (MLA: cache is always replicated)
# ------------------------------------------------------------------

def get_tp_transfer_slices(

Member

Since this is identical to MLAAttentionSpec's implementation, should we factor it out into a helper?

Comment on lines -52 to -57
def get_representative_spec_type(spec: KVCacheSpec) -> type[KVCacheSpec]:
    if isinstance(spec, UniformTypeKVCacheSpecs):
        # All inner specs are the same type; pick any.
        inner = next(iter(spec.kv_cache_specs.values()))
        return type(inner)
    return type(spec)

Member

this is handled for UniformTypeKVCacheSpecs right?

Contributor Author

Yes, now it's moved to _get_representative_spec

Contributor Author

If we want to slice with tensors, do we also need to know the attention backend layout? Currently the TP mapping logic is attention backend agnostic.
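The layout concern can be illustrated with a small sketch: the same head range lands on a different tensor axis depending on how the cache is laid out. The layout strings "NHD" and "HND" are assumed here as labels for (tokens, heads, head_dim) and (heads, tokens, head_dim) orderings, and `head_index` is a hypothetical helper, not an API from the PR.

```python
def head_index(layout: str, heads: range) -> tuple[slice, ...]:
    """Index tuple selecting `heads` along the 'H' axis of `layout`."""
    axis = layout.index("H")
    index = [slice(None)] * len(layout)
    index[axis] = slice(heads.start, heads.stop)
    return tuple(index)


heads = range(2, 4)
print(head_index("NHD", heads))  # heads on axis 1
print(head_index("HND", heads))  # heads on axis 0
```

A range-based description such as `TPTransferSlice` stays the same for both layouts; only the final step that turns it into tensor indices needs to know where the head axis is.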




@MatthewBonanni MatthewBonanni changed the title [NIXL] Cache TP slicing and mapping redesign [NIXL][2/N] Cache TP slicing and mapping redesign May 29, 2026
ZhanqiuHu added 10 commits June 2, 2026 16:38
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
mergify Bot commented Jun 4, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

4 participants