diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 6d4e6565e373..06b16a3e3e8b 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -275,6 +275,83 @@ def test_apply_prefix_caching_mamba_hybrid( ) +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "local_physical_per_logical,remote_physical_per_logical," + "local_block_ids,remote_block_ids," + "expected_local,expected_remote", + [ + # SSM prefix caching: remote has 3 placeholders + 1 real block, + # local has only the 1 real block. FA blocks are equal (no trim). + pytest.param( + 10, + 10, + [list(range(10)), [42]], + [list(range(10)), [40, 41, 42, 43]], + [list(range(10)), [42]], + [list(range(10)), [43]], + id="ssm_prefix_trim_only", + ), + # FA partial prefix cache hit with homogeneous TP: local has 4 FA + # blocks (prefix cached), remote has the full 10. SSM equal (no trim). + pytest.param( + 10, + 10, + [list(range(6, 10)), [42]], + [list(range(10)), [42]], + [list(range(6, 10)), [42]], + [list(range(6, 10)), [42]], + id="fa_prefix_hit_homo_tp", + ), + # Both: FA partial prefix hit + SSM placeholder trim. + # local FA=[6..9] (4 blocks, prefix cached), remote FA=[0..9] + # local SSM=[99], remote SSM=[10, 20, 99] (2 placeholders + real) + pytest.param( + 10, + 10, + [[6, 7, 8, 9], [99]], + [list(range(10)), [10, 20, 99]], + [[6, 7, 8, 9], [99]], + [[6, 7, 8, 9], [99]], + id="fa_prefix_hit_and_ssm_trim", + ), + ], +) +def test_apply_prefix_caching_ssm_prefix_cache_hit( + local_physical_per_logical, + remote_physical_per_logical, + local_block_ids, + remote_block_ids, + expected_local, + expected_remote, +): + """_apply_prefix_caching keeps the trailing SSM remote block (dropping the + leading placeholders) to match the single local block, and keeps the trailing + FA remote blocks on partial prefix cache hits when physical_per_logical matches. 
+ """ + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, + ) + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + worker = object.__new__(NixlConnectorWorker) + worker._has_mamba = True + worker._physical_blocks_per_logical_kv_block = local_physical_per_logical + worker._group_spec_types = (FullAttentionSpec, MambaSpec) + worker.kv_cache_config = make_kv_cache_config(block_size=16, mamba_enabled=True) + + aligned_local, aligned_remote = worker._apply_prefix_caching( + local_block_ids, remote_block_ids, remote_physical_per_logical + ) + + assert aligned_local == expected_local, ( + f"Expected local {expected_local}, got {aligned_local}" + ) + assert aligned_remote == expected_remote, ( + f"Expected remote {expected_remote}, got {aligned_remote}" + ) + + @pytest.mark.cpu_test @pytest.mark.parametrize( "local_physical_per_logical,remote_physical_per_logical," diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 0d30d4a692ad..a297058c845e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2333,9 +2333,25 @@ def _apply_prefix_caching( for i, remote_group in enumerate(remote_block_ids): num_local_blocks = len(local_block_ids[i]) num_remote_blocks = len(remote_group) - if _is_ssm_spec(self._group_spec_types[i]): - assert num_local_blocks == num_remote_blocks + if ( + _is_ssm_spec(self._group_spec_types[i]) + and num_local_blocks < num_remote_blocks + ): + # NOTE (NickLucche): With prefix caching on SSM, (remote) blocks + # prior to the last one are placeholders (null blocks). This does + # not change the transfer: we still only need the last "block", + # which holds the full in-place state. 
+ assert num_local_blocks == 1, "SSM can only have one local block" + remote_block_ids[i] = remote_group[-num_local_blocks:] + elif ( + self._physical_blocks_per_logical_kv_block + == remote_physical_per_logical + and num_local_blocks < num_remote_blocks + ): + # Partial prefix cache hit for an FA group: keep the remote tail. + remote_block_ids[i] = remote_group[-num_local_blocks:] + else: + # TODO: Handle prefix caching with different block sizes. max_padding = max( self._physical_blocks_per_logical_kv_block, remote_physical_per_logical, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f76ad05607e4..25ad4425e36e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -292,9 +292,6 @@ def _mamba_block_aligned_split( num_new_local_computed_tokens: int = 0, num_external_computed_tokens: int = 0, ) -> int: - assert num_external_computed_tokens == 0, ( - "External KV connector is not verified yet" - ) num_computed_tokens = ( request.num_computed_tokens + num_new_local_computed_tokens @@ -715,7 +712,8 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. break - if self.need_mamba_block_aligned_split: + # Skip block alignment when setting up async receive (no local work). + if self.need_mamba_block_aligned_split and not load_kv_async: num_new_tokens = self._mamba_block_aligned_split( request, num_new_tokens,