[NIXL][2/N] Cache TP slicing and mapping redesign#43151
Conversation
1ea360e to
f16f573
Compare
There was a problem hiding this comment.
Code Review
This pull request refactors the Tensor Parallel (TP) mapping logic for NIXL KV cache transfers by introducing a more robust TPTransferSlice structure and moving the mapping computation into the KVCacheSpec classes. It also updates various attention layers to provide total KV head counts and refactors the worker and test suites to accommodate these changes. Feedback highlights critical issues in the integration test script regarding hardcoded GPU arrays, unused variables in the worker, and several algorithmic flaws in the TP slicing logic, including non-robust partitioning and potential bugs in GQA replication.
I am having trouble creating individual review comments. Click here to see my feedback.
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh (83-84)
The hardcoded PREFILL_GPUS and DECODE_GPUS arrays are too small to support the TP sizes mentioned in the PR description (e.g., TP=4). For instance, if PREFILLER_TP_SIZE=4, the script will attempt to access PREFILL_GPUS[3], which is out of bounds for (1 2 4). This will cause the test script to crash or produce invalid CUDA_VISIBLE_DEVICES strings.
PREFILL_GPUS=(0 1 2 3)
DECODE_GPUS=(4 5 6 7)
vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py (1181)
The variable remote_heads_per_rank is calculated using a non-robust division (//) and is never used in the logic (only in a log message). It should be removed or fixed to use robust partitioning ((i+1)*total//size - i*total//size) if it's intended for future use.
vllm/v1/kv_cache_interface.py (256-311)
The TP slicing logic has several issues:
- Non-robust partitioning: Using
total // sizefails whentotalis not divisible bysize. It should use the difference between offsets:(i+1)*total//size - i*total//size. - GQA Replication Bug: In the
P_TP > D_TPbranch, iftotal < remote_tp_size, some remote ranks will have 0 heads. The current logic might pick a rank with 0 heads and skip subsequent ranks that actually contain the head data because they share the sameheadindex (calculated as 0). It should skip ranks whereremote_start == remote_end. - Unused Variables:
remote_heads(line 260) andlocal_heads(line 285) are calculated but not used in their respective branches. - Non-multiple TP sizes: The
abs_tplogic assumesremote_tp_sizeis a multiple oflocal_tp_size. Using robust range boundariesstart = local_tp_rank * remote_tp_size // local_tp_sizeis safer.
if local_tp_size >= remote_tp_size:
# D_TP >= P_TP: each local rank reads from one remote rank.
# Multiple local ranks may map to the same remote when D_TP > P_TP.
remote_rank = local_tp_rank * remote_tp_size // local_tp_size
# Compute which sub-range of the remote's heads this local rank owns.
# When D_TP > P_TP and total >= local_tp, different local ranks
# index into different head offsets within the same remote block.
local_start = local_tp_rank * total // local_tp_size
local_end = (local_tp_rank + 1) * total // local_tp_size
local_heads = max(1, local_end - local_start)
remote_head_start = remote_rank * total // remote_tp_size
offset = local_start - remote_head_start
return [
TPTransferSlice(
remote_rank=remote_rank,
global_range=range(local_start, local_start + local_heads),
local_range=range(0, local_heads),
remote_range=range(offset, offset + local_heads),
)
]
else:
# P_TP > D_TP: one local rank reads from multiple remote ranks.
# GQA dedup: when total_kv_heads < remote_tp, several remote ranks
# hold the same head. Only read from unique ones.
start = local_tp_rank * remote_tp_size // local_tp_size
end = (local_tp_rank + 1) * remote_tp_size // local_tp_size
local_start = local_tp_rank * total // local_tp_size
slices: list[TPTransferSlice] = []
seen_heads: set[int] = set()
for r in range(start, end):
remote_start = r * total // remote_tp_size
remote_end = (r + 1) * total // remote_tp_size
remote_heads = remote_end - remote_start
if remote_heads == 0 or remote_start in seen_heads:
continue
seen_heads.add(remote_start)
slices.append(
TPTransferSlice(
remote_rank=r,
global_range=range(remote_start, remote_end),
local_range=range(
remote_start - local_start,
remote_end - local_start,
),
remote_range=range(0, remote_heads),
)
)
return slices| def get_tp_transfer_slices( | ||
| self, | ||
| local_tp_rank: int, | ||
| local_tp_size: int, | ||
| remote_tp_size: int, | ||
| ) -> list[TPTransferSlice]: | ||
| """Compute transfer slices for this local rank. | ||
|
|
||
| Each slice describes one read: 'read remote_rank's remote_range | ||
| and place it in my local_range'. | ||
|
|
||
| Handles: | ||
| - Homogeneous TP (local_tp == remote_tp): 1:1 mapping | ||
| - D_TP > P_TP: local rank reads a sub-range from one remote | ||
| - P_TP > D_TP: local rank reads from multiple remotes | ||
| - GQA replication (total_kv_heads < remote_tp): load-balanced | ||
| remote selection to avoid redundant reads | ||
| """ | ||
| assert self.total_num_kv_heads is not None, ( | ||
| "total_num_kv_heads must be set for TP mapping. " | ||
| "Pass it when constructing the spec via get_kv_cache_spec()." | ||
| ) | ||
| total = self.total_num_kv_heads | ||
|
|
||
| if local_tp_size >= remote_tp_size: | ||
| # D_TP >= P_TP: each local rank reads from one remote rank. | ||
| # Multiple local ranks may map to the same remote when D_TP > P_TP. | ||
| remote_rank = local_tp_rank * remote_tp_size // local_tp_size | ||
| remote_heads = max(1, total // remote_tp_size) | ||
| # Compute which sub-range of the remote's heads this local rank owns. | ||
| # When D_TP > P_TP and total >= local_tp, different local ranks | ||
| # index into different head offsets within the same remote block. | ||
| local_head = local_tp_rank * total // local_tp_size | ||
| remote_head_start = remote_rank * total // remote_tp_size | ||
| offset = local_head - remote_head_start | ||
| local_heads = max(1, total // local_tp_size) | ||
| return [ | ||
| TPTransferSlice( | ||
| remote_rank=remote_rank, | ||
| global_range=range(local_head, local_head + local_heads), | ||
| local_range=range(0, local_heads), | ||
| remote_range=range(offset, offset + local_heads), | ||
| ) | ||
| ] | ||
| else: | ||
| # P_TP > D_TP: one local rank reads from multiple remote ranks. | ||
| # GQA dedup: when total_kv_heads < remote_tp, several remote ranks | ||
| # hold the same head. Only read from unique ones. | ||
| abs_tp = remote_tp_size // local_tp_size | ||
| start = local_tp_rank * abs_tp | ||
|
|
||
| local_start = local_tp_rank * total // local_tp_size | ||
| local_end = (local_tp_rank + 1) * total // local_tp_size | ||
| local_heads = local_end - local_start | ||
|
|
||
| slices: list[TPTransferSlice] = [] | ||
| seen_heads: set[int] = set() | ||
| for r in range(start, start + abs_tp): | ||
| head = r * total // remote_tp_size | ||
| if head in seen_heads: | ||
| continue | ||
| seen_heads.add(head) | ||
|
|
||
| remote_start = r * total // remote_tp_size | ||
| remote_end = (r + 1) * total // remote_tp_size | ||
| remote_heads = remote_end - remote_start | ||
|
|
||
| slices.append( | ||
| TPTransferSlice( | ||
| remote_rank=r, | ||
| global_range=range(remote_start, remote_end), | ||
| local_range=range( | ||
| remote_start - local_start, | ||
| remote_end - local_start, | ||
| ), | ||
| remote_range=range(0, remote_heads), | ||
| ) | ||
| ) | ||
|
|
||
| return slices |
There was a problem hiding this comment.
Thanks for getting starting on this!!
with something like
def slice_for_tp_transfer(self, src_tensor, src_tp, src_rank,
dst_spec, dst_tp, dst_rank):
src_heads = src_tp*self.num_heads
dst_heads = dst_tp*dst_spec.num_heads
if dst_heads > src_heads:
dst_rank %= src_tp
elif src_rank >= dst_tp:
return []
head_range = lambda tp_rank, num_heads: (tp_rank * num_heads, (tp_rank + 1) * num_heads)
src_head_range = head_range(src_rank, self.num_heads)
dst_head_range = head_range(dst_rank, dst_spec.num_heads)
overlap = get_overlap(src_head_range, dst_head_range)
if overlap is None:
return []
# Convert global overlap to local tensor indices
local_slice = slice(overlap.start - src_head_range[0],
overlap.stop - src_head_range[0])
return [src_tensor[:, :, local_slice]]
def get_overlap(range1, range2):
# Calculate the intersection boundaries
overlap_start = max(range1[0], range2[0])
overlap_end = min(range1[1], range2[1])
# Check if the boundaries form a valid range
if overlap_start <= overlap_end:
return slice(overlap_start, overlap_end)
return None
I think we can avoid the need to total_num_kv_heads to the spec? this should handled replicated kv-heads on the remote and/or local just fine I think (unless im missing something)
I could understand if this is too confusing though;
There was a problem hiding this comment.
I think using total_num_kv_heads to check replication might be cleaner, but we can pass in as an argument (from model config) instead of storing it in kv cache config.
I think it might be a little indirect to use src_heads = src_tp*self.num_heads and dst_heads = dst_tp*dst_spec.num_heads to recover whether one side is more replicated then the other?
There was a problem hiding this comment.
sounds good; i do like passing into get_tp_transfer_slices instead of passing it via the spec, this is much cleaner thank you! I think we should just pass in model_config though to make it less "standard attention" centric since this will also be for mamba
| global_range: range # position in the model's full head space | ||
| local_range: range # position in the local worker's tensor | ||
| remote_range: range # position in the remote worker's tensor |
There was a problem hiding this comment.
should we use slice?
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class TPTransferSlice: |
There was a problem hiding this comment.
this makes total sense for now but I think we can just ultimately have get_tp_transfer_slices slice the tensor directly? once #42374 lands?
There was a problem hiding this comment.
@ZhanqiuHu was telling she had some issues implementing it with tensors.
I would give it another go mapping TPTransferSlice to tensor code to see if there are cases which are not covered by simple tensor slicing (or if code gets somehow messier).
I'll also take a look.
| # TP transfer slice interface (MLA: cache is always replicated) | ||
| # ------------------------------------------------------------------ | ||
|
|
||
| def get_tp_transfer_slices( |
There was a problem hiding this comment.
Since this is identical to MLAAttentionSpec's implementation, should we factor it out into a helper?
| def get_representative_spec_type(spec: KVCacheSpec) -> type[KVCacheSpec]: | ||
| if isinstance(spec, UniformTypeKVCacheSpecs): | ||
| # All inner specs are the same type; pick any. | ||
| inner = next(iter(spec.kv_cache_specs.values())) | ||
| return type(inner) | ||
| return type(spec) |
There was a problem hiding this comment.
this is handled for UniformTypeKVCacheSpecs right?
There was a problem hiding this comment.
Yes, now it's moved to _get_representative_spec
There was a problem hiding this comment.
If we want to slice with tensors, do we also need to know the attention backend layout? Currently the TP mapping logic is attention backend agnostic.
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class TPTransferSlice: |
There was a problem hiding this comment.
@ZhanqiuHu was telling she had some issues implementing it with tensors.
I would give it another go mapping TPTransferSlice to tensor code to see if there are cases which are not covered by simple tensor slicing (or if code gets somehow messier).
I'll also take a look.
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
33168ae to
03eca0d
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
RFC: #42082