From 6e39b539c7c3f6ae21722c7c8c384cc32489da36 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 12 May 2026 13:59:15 +0000 Subject: [PATCH 1/4] prefix caching for matching block_size Signed-off-by: NickLucche --- .../kv_transfer/kv_connector/v1/nixl/worker.py | 11 +++++++++-- vllm/v1/core/sched/scheduler.py | 6 ++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index ea8b46c28f9c..9e4224c3f372 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2333,8 +2333,15 @@ def _apply_prefix_caching( for i, remote_group in enumerate(remote_block_ids): num_local_blocks = len(local_block_ids[i]) num_remote_blocks = len(remote_group) - if _is_ssm_spec(self._group_spec_types[i]): - assert num_local_blocks == num_remote_blocks + if ( + _is_ssm_spec(self._group_spec_types[i]) + and num_local_blocks < num_remote_blocks + ): + # NOTE (NickLucche): With prefix caching on SSM, (remote) blocks + # prior to the last one are placeholders (null blocks). We only + # care about the last one, which maintains the full state in-place. + assert num_local_blocks == 1, "SSM can only have one local block" + remote_block_ids[i] = remote_group[-num_local_blocks:] else: max_padding = max( self._physical_blocks_per_logical_kv_block, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 45a5169a3efe..5e04ea4c6e7d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -283,9 +283,6 @@ def _mamba_block_aligned_split( num_new_local_computed_tokens: int = 0, num_external_computed_tokens: int = 0, ) -> int: - assert num_external_computed_tokens == 0, ( - "External KV connector is not verified yet" - ) num_computed_tokens = ( request.num_computed_tokens + num_new_local_computed_tokens @@ -687,7 +684,8 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. break - if self.need_mamba_block_aligned_split: + # Skip block alignment when setting up async receive (no local work). + if self.need_mamba_block_aligned_split and not load_kv_async: num_new_tokens = self._mamba_block_aligned_split( request, num_new_tokens, From 683598eb08e8535199284ed1330d8037d4decca5 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 13 May 2026 10:05:28 +0000 Subject: [PATCH 2/4] partial hit for FA Signed-off-by: NickLucche --- .../kv_transfer/kv_connector/v1/nixl/worker.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 9e4224c3f372..e6e16e0c208a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2338,10 +2338,18 @@ def _apply_prefix_caching( and num_local_blocks < num_remote_blocks ): # NOTE (NickLucche): With prefix caching on SSM, (remote) blocks - # prior to the last one are placeholders (null blocks). We only - # care about the last one, which maintains the full state in-place. + # prior to the last one are placeholders (null blocks). Mind that + # this doesn't really impact transfer, as we only still care about + # the last "block", the full in-place state. assert num_local_blocks == 1, "SSM can only have one local block" remote_block_ids[i] = remote_group[-num_local_blocks:] + elif ( + self._physical_blocks_per_logical_kv_block + == remote_physical_per_logical + and num_local_blocks < num_remote_blocks + ): + # Partial prefix cache hit for FA group. + remote_block_ids[i] = remote_group[-num_local_blocks:] else: max_padding = max( self._physical_blocks_per_logical_kv_block, From 0871fba3b40fd107d38c3ab2292644ad94f14540 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 13 May 2026 17:47:06 +0000 Subject: [PATCH 3/4] test Signed-off-by: NickLucche --- .../unit/test_nixl_connector_hma.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 9c163fdf327f..3e2b7ae9e46d 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -275,6 +275,83 @@ def test_apply_prefix_caching_mamba_hybrid( ) +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "local_physical_per_logical,remote_physical_per_logical," + "local_block_ids,remote_block_ids," + "expected_local,expected_remote", + [ + # SSM prefix caching: remote has 3 placeholder + 1 real block, + # local has only the 1 real block. FA blocks are equal (no trim). + pytest.param( + 10, + 10, + [list(range(10)), [42]], + [list(range(10)), [40, 41, 42, 43]], + [list(range(10)), [42]], + [list(range(10)), [43]], + id="ssm_prefix_trim_only", + ), + # FA partial prefix cache hit with homogeneous TP: local has 4 FA + # blocks (prefix cached), remote has full 10. SSM equal (no trim). + pytest.param( + 10, + 10, + [list(range(6, 10)), [42]], + [list(range(10)), [42]], + [list(range(6, 10)), [42]], + [list(range(6, 10)), [42]], + id="fa_prefix_hit_homo_tp", + ), + # Both: FA partial prefix hit + SSM placeholder trim. + # local FA=[6..9] (4 blocks, prefix cached), remote FA=[0..9] + # local SSM=[99], remote SSM=[10, 20, 99] (2 placeholders + real) + pytest.param( + 10, + 10, + [[6, 7, 8, 9], [99]], + [list(range(10)), [10, 20, 99]], + [[6, 7, 8, 9], [99]], + [[6, 7, 8, 9], [99]], + id="fa_prefix_hit_and_ssm_trim", + ), + ], +) +def test_apply_prefix_caching_ssm_prefix_cache_hit( + local_physical_per_logical, + remote_physical_per_logical, + local_block_ids, + remote_block_ids, + expected_local, + expected_remote, +): + """_apply_prefix_caching end-trims SSM remote blocks to match the single + local block (placeholders dropped) and end-trims FA remote blocks on + partial prefix cache hits when physical_per_logical matches. + """ + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, + ) + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + worker = object.__new__(NixlConnectorWorker) + worker._has_mamba = True + worker._physical_blocks_per_logical_kv_block = local_physical_per_logical + worker._group_spec_types = (FullAttentionSpec, MambaSpec) + worker.kv_cache_config = make_kv_cache_config(block_size=16, mamba_enabled=True) + + aligned_local, aligned_remote = worker._apply_prefix_caching( + local_block_ids, remote_block_ids, remote_physical_per_logical + ) + + assert aligned_local == expected_local, ( + f"Expected local {expected_local}, got {aligned_local}" + ) + assert aligned_remote == expected_remote, ( + f"Expected remote {expected_remote}, got {aligned_remote}" + ) + + @pytest.mark.cpu_test @pytest.mark.parametrize( "local_physical_per_logical,remote_physical_per_logical," From 68dc38bcbac5004090939bbeb6bdcb9574379bb0 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 20 May 2026 15:13:26 +0200 Subject: [PATCH 4/4] comment Signed-off-by: NickLucche --- vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index e6e16e0c208a..55081a603b4e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2351,6 +2351,7 @@ def _apply_prefix_caching( # Partial prefix cache hit for FA group. remote_block_ids[i] = remote_group[-num_local_blocks:] else: + # TODO Handle prefix caching with different block_sizes max_padding = max( self._physical_blocks_per_logical_kv_block, remote_physical_per_logical,