Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions tests/v1/kv_connector/unit/test_nixl_connector_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ def test_read_blocks_for_req_expands_remote_ids(
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.tp_mapping import (
TPMapping,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
Expand Down Expand Up @@ -172,18 +169,18 @@ def test_read_blocks_for_req_expands_remote_ids(
remote_engine_id = "remote-engine"

worker.transfer_topo = MagicMock()
# tp_ratio not exercised (all_source_ranks is empty so no reads run),
# tp_ratio not exercised (remote_ranks is empty so no reads run),
# but set for realism.
worker.transfer_topo.tp_ratio.return_value = tp_ratio
remote_info = MagicMock()
remote_info.remote_physical_blocks_per_logical = remote_physical_per_logical
worker.transfer_topo.get_engine_info.return_value = remote_info
worker.use_mla = False

mock_plan = MagicMock(spec=TPMapping)
mock_plan.all_source_ranks = ()
mock_plan.source_ranks_per_group = ()
worker.tp_mappings = {remote_engine_id: mock_plan}
# Empty tp_mappings: no source ranks so no reads are issued.
num_groups = len(resolved_types)
worker.tp_mappings = {remote_engine_id: tuple({} for _ in range(num_groups))}
worker.remote_ranks = {remote_engine_id: ()}

metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
Expand Down Expand Up @@ -346,9 +343,6 @@ def test_mismatched_physical_per_logical_fails_with_prefix_caching(
mamba_enabled=True,
)
worker._has_mamba = True
worker._group_spec_types = tuple(
type(g.kv_cache_spec) for g in worker.kv_cache_config.kv_cache_groups
)

local_block_ids = (local_fa_blocks, ssm_blocks)
remote_block_ids = (remote_fa_blocks, ssm_blocks)
Expand Down
263 changes: 201 additions & 62 deletions tests/v1/kv_connector/unit/test_tp_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,97 +9,153 @@

from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
import torch

from vllm.distributed.kv_transfer.kv_connector.v1.nixl.tp_mapping import (
TPMapping,
compute_tp_mapping,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
MambaSpec,
ShardRange,
TPTransferSlice,
)

# ======================================================================
# Test fixtures / helpers
# ======================================================================


def _compute_mapping(
def _make_fa_spec(num_kv_heads: int = 4):
    """Build a ``FullAttentionSpec`` fixture for the TP-mapping tests.

    Only the KV head count varies between tests; block size (16), both
    head sizes (128) and the dtype (float16) are fixed.

    Args:
        num_kv_heads: Value passed through as the spec's ``num_kv_heads``.

    Returns:
        The constructed ``FullAttentionSpec``.
    """
    return FullAttentionSpec(
        block_size=16,
        num_kv_heads=num_kv_heads,
        head_size=128,
        head_size_v=128,
        dtype=torch.float16,
    )


def _get_slices(
tp_rank: int = 0,
tp_size: int = 1,
remote_tp_size: int = 1,
is_mla: bool = False,
num_kv_heads: int = 8,
group_spec_types: tuple[type, ...] = (FullAttentionSpec,),
) -> TPMapping:
transfer_topology = SimpleNamespace(
tp_rank=tp_rank,
tp_size=tp_size,
is_mla=is_mla,
total_num_kv_heads=num_kv_heads,
)
return compute_tp_mapping(
transfer_topology=transfer_topology,
remote_tp_size=remote_tp_size,
group_spec_types=group_spec_types,
total_num_kv_heads: int = 8,
spec=None,
) -> dict[int, TPTransferSlice]:
"""Call get_tp_transfer_slices on the given spec (or a default FA spec)."""
if spec is None:
num_kv_heads = max(1, total_num_kv_heads // tp_size)
spec = _make_fa_spec(num_kv_heads)
return spec.get_tp_transfer_slices(
tp_rank, tp_size, remote_tp_size, total_num_kv_heads
)


def _remote_ranks_from_slices(
*group_slices: dict[int, TPTransferSlice],
) -> tuple[int, ...]:
"""Derive deduplicated sorted source ranks from multiple group slices."""
return tuple(sorted({r for slices in group_slices for r in slices}))


# ======================================================================
# TP mapping structure tests
# ======================================================================


class TestTPMappingStructure:
def test_source_ranks_homogeneous(self):
m = _compute_mapping(tp_size=2, tp_rank=1, remote_tp_size=2)
assert m.all_source_ranks == (1,)
def test_remote_ranks_homogeneous(self):
slices = _get_slices(tp_size=2, tp_rank=1, remote_tp_size=2)
assert _remote_ranks_from_slices(slices) == (1,)

def test_source_ranks_d_gt_p(self):
m = _compute_mapping(tp_size=4, tp_rank=2, remote_tp_size=2)
assert m.all_source_ranks == (1,)
def test_remote_ranks_d_gt_p(self):
slices = _get_slices(tp_size=4, tp_rank=2, remote_tp_size=2)
assert _remote_ranks_from_slices(slices) == (1,)

def test_source_ranks_p_gt_d(self):
m = _compute_mapping(tp_size=1, tp_rank=0, remote_tp_size=2)
assert m.all_source_ranks == (0, 1)
def test_remote_ranks_p_gt_d(self):
slices = _get_slices(tp_size=1, tp_rank=0, remote_tp_size=2)
assert _remote_ranks_from_slices(slices) == (0, 1)

def test_per_group_slices(self):
slices = _get_slices(tp_size=2, tp_rank=0, remote_tp_size=4)
assert len(slices) == 2
assert 0 in slices
assert 1 in slices

def test_has_rank_in_group(self):
slices = _get_slices(tp_size=1, tp_rank=0, remote_tp_size=2)
assert 0 in slices
assert 1 in slices
assert 2 not in slices

def test_gqa_dedup_load_balanced(self):
"""With total_heads=2, remote_tp=4: picks aligned remote ranks."""
slices_r0 = _get_slices(
tp_size=2, tp_rank=0, remote_tp_size=4, total_num_kv_heads=2
)
slices_r1 = _get_slices(
tp_size=2, tp_rank=1, remote_tp_size=4, total_num_kv_heads=2
)
assert 0 in slices_r0
assert 2 in slices_r1


# ======================================================================
# Split handle tests
# ======================================================================


def _make_mock_worker_for_splits(group_spec_types):
"""Build a mock NixlConnectorWorker with _group_spec_types for split tests."""
def _make_mock_worker_for_splits(
group_specs: list,
tp_mappings: tuple,
remote_ranks: tuple[int, ...],
engine_id: str = "remote_0",
):
"""Build a mock NixlConnectorWorker with the fields _build_local_splits needs."""
worker = object.__new__(NixlConnectorWorker)
worker._group_spec_types = group_spec_types
kv_cache_groups = []
for spec in group_specs:
group = MagicMock()
group.kv_cache_spec = spec
kv_cache_groups.append(group)
kv_cache_config = MagicMock()
kv_cache_config.kv_cache_groups = kv_cache_groups
worker.kv_cache_config = kv_cache_config
worker.tp_mappings = {engine_id: tp_mappings}
worker.remote_ranks = {engine_id: remote_ranks}
worker.transfer_topo = MagicMock()
return worker


class TestBuildSrcSplitHandles:
@pytest.mark.parametrize("remote_tp_size", [2, 4])
def test_build_src_split_handles(self, remote_tp_size):
def test_split_shape(self, remote_tp_size):
"""Each split has correct number of descs with correct chunk size."""
tp_rank = 0
tp_size = 1
total_num_kv_heads = 8
engine_id = "remote_0"

plan = _compute_mapping(
tp_rank=tp_rank,
tp_size=tp_size,
remote_tp_size=remote_tp_size,
fa_spec = _make_fa_spec(num_kv_heads=total_num_kv_heads // tp_size)
fa_slices = fa_spec.get_tp_transfer_slices(
tp_rank, tp_size, remote_tp_size, total_num_kv_heads
)
remote_ranks = _remote_ranks_from_slices(fa_slices)

worker = _make_mock_worker_for_splits((FullAttentionSpec,))
worker = _make_mock_worker_for_splits(
group_specs=[fa_spec],
tp_mappings=(fa_slices,),
remote_ranks=remote_ranks,
engine_id=engine_id,
)
src_blocks_data = [(0x2000 + i * 1024, 1024, 0) for i in range(8)]
num_descs = len(src_blocks_data)
num_fa_descs = len(src_blocks_data)
splits = list(
worker._build_local_splits_from_plan(
plan,
src_blocks_data,
num_descs,
)
worker._build_local_splits(engine_id, src_blocks_data, num_fa_descs)
)

assert len(splits) == remote_tp_size
Expand All @@ -108,39 +164,122 @@ def test_build_src_split_handles(self, remote_tp_size):
for _, length, _ in handle:
assert length == 1024 // remote_tp_size

@pytest.mark.parametrize(
    "remote_tp_size,total_num_kv_heads",
    [(2, 4), (2, 8), (4, 8)],
)
def test_fa_offsets_p_gt_d(self, remote_tp_size, total_num_kv_heads):
    """Verify concrete FA offsets for multi-head P>D (the previously buggy path).

    With local_tp=1, the full local block covers all heads. Each remote
    rank's slice should land at the correct byte offset proportional to
    its position in the local tensor.
    """
    tp_rank = 0
    tp_size = 1
    engine_id = "remote_0"
    local_block_len = 1024

    fa_spec = _make_fa_spec(num_kv_heads=total_num_kv_heads // tp_size)
    fa_slices = fa_spec.get_tp_transfer_slices(
        tp_rank, tp_size, remote_tp_size, total_num_kv_heads
    )
    remote_ranks = _remote_ranks_from_slices(fa_slices)

    worker = _make_mock_worker_for_splits(
        group_specs=[fa_spec],
        tp_mappings=(fa_slices,),
        remote_ranks=remote_ranks,
        engine_id=engine_id,
    )
    base_addr = 0x4000
    # One descriptor only; the trailing 1 is the FA descriptor count.
    src_blocks_data = [(base_addr, local_block_len, 0)]
    splits = list(worker._build_local_splits(engine_id, src_blocks_data, 1))

    assert len(splits) == remote_tp_size
    chunk = local_block_len // remote_tp_size
    # splits[idx] is checked against the idx-th slice in ascending
    # remote-rank order; the rank itself is not needed in the loop body.
    for idx, (_rank, sl) in enumerate(sorted(fa_slices.items())):
        expected_offset = (
            sl.local_write_offset * local_block_len // len(sl.local_shard)
        )
        # Offsets should tile the local block without overlap
        assert expected_offset == idx * chunk
        addr, length, dev = splits[idx][0]
        assert addr == base_addr + expected_offset
        assert length == chunk
        assert dev == 0


class TestMambaPlanSplitHandles:
"""Verify split handles for Mamba with FA/SSM distinction."""

def test_fa_and_ssm_different_split_factors(self):
"""Section 0 split by num_attn_reads, section 1 by abs_tp."""
fa_readers = (0,)
ssm_readers = (0, 1)
plan = TPMapping(
source_ranks_per_group=(fa_readers, ssm_readers),
all_source_ranks=(0, 1),
rank_to_attention_slot={0: 0, 1: 0},
rank_offset_factor=0,
engine_id = "remote_0"
# total_kv_heads=1 < remote_tp=2 triggers GQA dedup:
# only remote rank 0 holds unique FA data.
total_num_kv_heads = 1

fa_spec = _make_fa_spec(num_kv_heads=1)
mamba_spec = MagicMock(spec=MambaSpec)

# local_tp=1, remote_tp=2
# FA: 1 unique slice (reads from remote 0, GQA dedup skips rank 1)
# Mamba: 2 slices (reads from remote 0 and 1)
fa_slices = fa_spec.get_tp_transfer_slices(0, 1, 2, total_num_kv_heads)

shard_mamba = ShardRange(0, 1, 1)
ssm_slices = {
0: TPTransferSlice(
remote_rank=0,
remote_shard=shard_mamba,
local_shard=shard_mamba,
transfer_range=shard_mamba,
),
1: TPTransferSlice(
remote_rank=1,
remote_shard=shard_mamba,
local_shard=shard_mamba,
transfer_range=shard_mamba,
),
}
remote_ranks = _remote_ranks_from_slices(fa_slices, ssm_slices)

worker = _make_mock_worker_for_splits(
group_specs=[fa_spec, mamba_spec],
tp_mappings=(fa_slices, ssm_slices),
remote_ranks=remote_ranks,
engine_id=engine_id,
)

worker = _make_mock_worker_for_splits((FullAttentionSpec, MambaSpec))
# 2 FA descs + 1 SSM desc
src_blocks_data = [
(1000, 200, 0), # FA desc 0
(2000, 200, 0), # FA desc 1
(3000, 400, 0), # SSM desc 0
]

splits = list(worker._build_local_splits_from_plan(plan, src_blocks_data, 2))
splits = list(worker._build_local_splits(engine_id, src_blocks_data, 2))

assert len(splits) == 2 # 2 source ranks

# Rank 0 (FA source, p_idx=0):
# FA: chunk=200//1=200, slot=0 → (1000, 200, 0), (2000, 200, 0)
# SSM: chunk=400//2=200, idx=0 → (3000, 200, 0)
assert splits[0] == [(1000, 200, 0), (2000, 200, 0), (3000, 200, 0)]
# Rank 0 is in fa_slices -> uses local_write_offset for FA offset
fa_chunk = 200 // len(fa_slices)
ssm_chunk = 400 // len(ssm_slices)

# Rank 0 (remote_idx=0):
# FA: chunk=200//1=200 (only 1 FA slice)
# offset = local_write_offset * local_block_len // len(local_shard)
# SSM: chunk=400//2=200, offset = remote_idx(0) * 200
sl = fa_slices[0]
fa_offset_r0 = sl.local_write_offset * 200 // len(sl.local_shard)
assert splits[0][0] == (1000 + fa_offset_r0, fa_chunk, 0)
assert splits[0][1] == (2000 + fa_offset_r0, fa_chunk, 0)
assert splits[0][2] == (3000 + 0 * ssm_chunk, ssm_chunk, 0)

# Rank 1 (not FA source, p_idx=1):
# FA: chunk=200//1=200, slot=0 (skip_fa) → (1000, 200, 0), (2000, 200, 0)
# SSM: chunk=400//2=200, idx=1 → (3200, 200, 0)
assert splits[1] == [(1000, 200, 0), (2000, 200, 0), (3200, 200, 0)]
# Rank 1 (remote_idx=1):
# FA: rank 1 NOT in fa_slices -> GQA-deduped placeholder (addr, chunk, dev)
# SSM: chunk=400//2=200, offset = remote_idx(1) * 200
assert splits[1][0] == (1000, fa_chunk, 0)
assert splits[1][1] == (2000, fa_chunk, 0)
assert splits[1][2] == (3000 + 1 * ssm_chunk, ssm_chunk, 0)
Loading
Loading