Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,83 @@ def test_apply_prefix_caching_mamba_hybrid(
)


@pytest.mark.cpu_test
@pytest.mark.parametrize(
"local_physical_per_logical,remote_physical_per_logical,"
"local_block_ids,remote_block_ids,"
"expected_local,expected_remote",
[
# SSM prefix caching: remote has 3 placeholder + 1 real block,
# local has only the 1 real block. FA blocks are equal (no trim).
pytest.param(
10,
10,
[list(range(10)), [42]],
[list(range(10)), [40, 41, 42, 43]],
[list(range(10)), [42]],
[list(range(10)), [43]],
id="ssm_prefix_trim_only",
),
# FA partial prefix cache hit with homogeneous TP: local has 4 FA
# blocks (prefix cached), remote has full 10. SSM equal (no trim).
pytest.param(
10,
10,
[list(range(6, 10)), [42]],
[list(range(10)), [42]],
[list(range(6, 10)), [42]],
[list(range(6, 10)), [42]],
id="fa_prefix_hit_homo_tp",
),
# Both: FA partial prefix hit + SSM placeholder trim.
# local FA=[6..9] (4 blocks, prefix cached), remote FA=[0..9]
# local SSM=[99], remote SSM=[10, 20, 99] (2 placeholders + real)
pytest.param(
10,
10,
[[6, 7, 8, 9], [99]],
[list(range(10)), [10, 20, 99]],
[[6, 7, 8, 9], [99]],
[[6, 7, 8, 9], [99]],
id="fa_prefix_hit_and_ssm_trim",
),
],
)
def test_apply_prefix_caching_ssm_prefix_cache_hit(
local_physical_per_logical,
remote_physical_per_logical,
local_block_ids,
remote_block_ids,
expected_local,
expected_remote,
):
"""_apply_prefix_caching end-trims SSM remote blocks to match the single
local block (placeholders dropped) and end-trims FA remote blocks on
partial prefix cache hits when physical_per_logical matches.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

worker = object.__new__(NixlConnectorWorker)
worker._has_mamba = True
worker._physical_blocks_per_logical_kv_block = local_physical_per_logical
worker._group_spec_types = (FullAttentionSpec, MambaSpec)
worker.kv_cache_config = make_kv_cache_config(block_size=16, mamba_enabled=True)

aligned_local, aligned_remote = worker._apply_prefix_caching(
local_block_ids, remote_block_ids, remote_physical_per_logical
)

assert aligned_local == expected_local, (
f"Expected local {expected_local}, got {aligned_local}"
)
assert aligned_remote == expected_remote, (
f"Expected remote {expected_remote}, got {aligned_remote}"
)


@pytest.mark.cpu_test
@pytest.mark.parametrize(
"local_physical_per_logical,remote_physical_per_logical,"
Expand Down
20 changes: 18 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2333,9 +2333,25 @@ def _apply_prefix_caching(
for i, remote_group in enumerate(remote_block_ids):
num_local_blocks = len(local_block_ids[i])
num_remote_blocks = len(remote_group)
if _is_ssm_spec(self._group_spec_types[i]):
assert num_local_blocks == num_remote_blocks
if (
_is_ssm_spec(self._group_spec_types[i])
and num_local_blocks < num_remote_blocks
):
# NOTE (NickLucche): With prefix caching on SSM, (remote) blocks
# prior to the last one are placeholders (null blocks). Mind that
# this doesn't really impact transfer, as we only still care about
# the last "block", the full in-place state.
assert num_local_blocks == 1, "SSM can only have one local block"
remote_block_ids[i] = remote_group[-num_local_blocks:]
Comment thread
NickLucche marked this conversation as resolved.
elif (
self._physical_blocks_per_logical_kv_block
== remote_physical_per_logical
and num_local_blocks < num_remote_blocks
):
# Partial prefix cache hit for FA group.
remote_block_ids[i] = remote_group[-num_local_blocks:]
Comment thread
NickLucche marked this conversation as resolved.
else:
# TODO Handle prefix caching with different block_sizes
max_padding = max(
self._physical_blocks_per_logical_kv_block,
remote_physical_per_logical,
Expand Down
6 changes: 2 additions & 4 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,6 @@ def _mamba_block_aligned_split(
num_new_local_computed_tokens: int = 0,
num_external_computed_tokens: int = 0,
) -> int:
assert num_external_computed_tokens == 0, (
"External KV connector is not verified yet"
)
num_computed_tokens = (
request.num_computed_tokens
+ num_new_local_computed_tokens
Expand Down Expand Up @@ -715,7 +712,8 @@ def schedule(self) -> SchedulerOutput:
# The request cannot be scheduled.
break

if self.need_mamba_block_aligned_split:
# Skip block alignment when setting up async receive (no local work).
if self.need_mamba_block_aligned_split and not load_kv_async:
num_new_tokens = self._mamba_block_aligned_split(
request,
num_new_tokens,
Expand Down
Loading