diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index 684e2ec4d7b9..245b5473448a 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -18,11 +18,19 @@ dp_ep_configs=( "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1) ) +hybrid_ssm_configs=( + "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code" + # TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models. + "ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling" +) # Select config array based on DP_EP env var if [[ -n "${DP_EP:-}" ]]; then configs=("${dp_ep_configs[@]}") echo "DP_EP is set, using dp_ep_configs" +elif [[ -n "${HYBRID_SSM:-}" ]]; then + configs=("${hybrid_ssm_configs[@]}") + echo "HYBRID_SSM is set, using hybrid_ssm_configs." else configs=("${tp_configs[@]}") fi diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py index 674e65c25ef4..a7fea4e630c9 100644 --- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -18,6 +18,7 @@ "deepseek-ai/deepseek-vl2-tiny": 0.19, "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65, "google/gemma-3-4b-it": 0.74, + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8": 0.84, } SIMPLE_PROMPT = ( diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 10fa4f14f237..c46fbbd38513 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -53,7 +53,13 @@ from vllm.v1.attention.backends.utils import set_kv_cache_layout from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheTensor, +) from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import RequestStatus from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin @@ -332,8 +338,20 @@ def test_kv_transfer_handshake(dist_init): # Prefill connector will register KV cache to populate proper handshake # metadata. 
- # TODO this must match with values used in kv cache config - kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2) + kv_cache_groups = [ + KVCacheGroupSpec( + ["layer0", "layer1", "layer2"], + FullAttentionSpec( + block_size=16, + num_kv_heads=4, + head_size=16, + dtype=torch.float16, + ), + ) + ] + kv_cache_config = KVCacheConfig( + num_blocks=2, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups + ) prefill_connector = NixlConnector( vllm_config, KVConnectorRole.WORKER, kv_cache_config ) @@ -437,7 +455,7 @@ def __init__( self.kv_cache_layout = kv_cache_layout # Mock register_kv_caches attribute needed for tests that do not call it. self.src_xfer_handles_by_block_size = {self.block_size: 1} - test_shape = self.attn_backend.get_kv_cache_shape( + test_shape = self.attn_backends[0].get_kv_cache_shape( num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 ) self.kv_topo = TpKVTopology( @@ -447,7 +465,7 @@ def __init__( remote_block_size=self._block_size, # shared state is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - attn_backend=self.attn_backend, + attn_backends=self.attn_backends, tensor_shape=test_shape, ) @@ -501,6 +519,7 @@ def _nixl_handshake( # is started. We mock HND here. kv_cache_layout="HND", block_size=self.block_size, + ssm_sizes=(0, 0), ), remote_tp_rank=remote_tp_rank, remote_tp_size=remote_tp_size, @@ -951,6 +970,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch( block_lens=worker.block_len_per_layer, kv_cache_layout=mismatched_layout, block_size=worker.block_size, + ssm_sizes=(0, 0), ) with pytest.raises(RuntimeError): @@ -1006,6 +1026,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( block_lens=[i * 2 for i in worker.block_len_per_layer], kv_cache_layout="HND", block_size=worker.block_size, + ssm_sizes=(0, 0), ) # We don't check layout for homogeneous TP and MLA for now, as the @@ -1496,9 +1517,47 @@ def test_register_kv_caches( # test run if not mocking. mock_get_attn_backend.return_value = backend_cls mock_get_attn_backends.return_value = [backend_cls] - + num_layers = 32 + block_size = 16 + num_blocks = 8 + num_heads = 4 + head_size = 16 + + # TODO (NickLucche) the fact that the connector depends on kv_cache_config for init + # but cross-layer preference can't be inferred prior to creating kv_cache_config + # is a bit awkward.
+ dummy_connector = NixlConnector( + vllm_config, + KVConnectorRole.WORKER, + make_kv_cache_config(block_size=block_size), + ) + kv_cache_spec = FullAttentionSpec( + block_size=block_size, + num_kv_heads=num_heads, + head_size=head_size, + dtype=torch.float16, + ) + if dummy_connector.prefer_cross_layer_blocks: + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[ + KVCacheTensor( + size=kv_cache_spec.page_size_bytes * num_blocks, + shared_by=["all-layers"], + ) + for _ in range(num_layers) + ], + kv_cache_groups=[KVCacheGroupSpec(["all-layers"], kv_cache_spec)], + ) + else: + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["layer0", "layer1", "layer2"], kv_cache_spec) + ], + ) # Create connector - kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2) connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, @@ -1526,35 +1585,6 @@ def test_register_kv_caches( or connector.prefer_cross_layer_blocks ) if connector.prefer_cross_layer_blocks: - num_layers = 32 - block_size = 16 - num_blocks = 8 - # Keep the fake worker's expected num_blocks in sync with the - # cross-layer tensor we are about to register. - worker_kv_cache_config = make_kv_cache_config( - block_size=block_size, num_blocks=num_blocks - ) - connector.connector_worker.kv_cache_config = worker_kv_cache_config - connector.connector_worker.num_blocks = worker_kv_cache_config.num_blocks - kv_cache_spec = AttentionSpec( - block_size=block_size, - num_kv_heads=4, - head_size=64, - dtype=torch.bfloat16, - ) - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - kv_cache_tensors=[ - KVCacheTensor( - size=kv_cache_spec.page_size_bytes * num_blocks, - shared_by=["dummy-layer"], - ) - for i in range(num_layers) - ], - # allocate_uniform_kv_caches does not use this - kv_cache_groups=[], - ) - with set_current_vllm_config(vllm_config): _, cross_layers_kv_cache, _ = ( KVConnectorModelRunnerMixin.allocate_uniform_kv_caches( @@ -1586,12 +1616,8 @@ def test_register_kv_caches( expected_blocks_count = 8 kv_caches = {"all-layers": cross_layers_kv_cache} - else: # Create test kv cache tensors using proper backend shape - kv_cache_spec = cast( - AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec - ) kv_cache_shape = backend_cls.get_kv_cache_shape( num_blocks=kv_cache_config.num_blocks, block_size=kv_cache_spec.block_size, @@ -2261,7 +2287,7 @@ def test_compatibility_hash_validation( kv_cache_spec = cast( AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec ) - kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape( + kv_cache_shape = decode_worker.attn_backends[0].get_kv_cache_shape( num_blocks=kv_cache_config.num_blocks, block_size=kv_cache_spec.block_size, num_kv_heads=kv_cache_spec.num_kv_heads, @@ -2269,10 +2295,14 @@ def test_compatibility_hash_validation( ) shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) + # Build kv_caches from the actual layer names in kv_cache_config so that + # _layer_specs lookups in register_kv_caches always find a matching key. 
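+ # e.g. with three layers this reproduces the previous fixed mapping of + # {"layer0": shared, "layer1": unique, "layer2": shared}.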
+ layer_names = [ + name for group in kv_cache_config.kv_cache_groups for name in group.layer_names + ] kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, + name: shared_tensor if i % 2 == 0 else unique_tensor + for i, name in enumerate(layer_names) } decode_connector.register_kv_caches(kv_caches) @@ -2312,6 +2342,7 @@ def test_compatibility_hash_validation( block_lens=[4096 * prefill_block_size], # slot_size * block_size kv_cache_layout="HND", block_size=prefill_block_size, + ssm_sizes=(0, 0), ) handshake_payload = NixlHandshakePayload( compatibility_hash=remote_hash, @@ -2391,7 +2422,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) remote_block_size=decode_worker._block_size, # shared state is_mla=decode_worker.use_mla, total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), - attn_backend=backend, + attn_backends=[backend], tensor_shape=test_shape, ) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 636d51402bde..d4b0c28a5de5 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -74,6 +74,8 @@ def test_logical_to_kernel_block_ids_with_hma(): # Simulate HMA scenario: logical block size = 32, kernel block size = 16 # So each logical block maps to 2 kernel blocks eg [0]->[0,1] worker._physical_blocks_per_logical_kv_block = 2 + # FA + SW groups (neither is MambaSpec, so both get expanded) + worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True) # Test conversion: FA + SW group logical_block_ids = [[0, 1, 2], [3, 4]] @@ -201,3 +203,113 @@ def test_nixl_metadata_hma_block_ids_structure(): assert len(req_meta.remote.block_ids) == 2 assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17] assert list(req_meta.remote.block_ids[1]) == [18, 19, 20, 21] + + +@pytest.mark.cpu_test +def test_get_block_descs_ids_hybrid_ssm(): + """Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM + when ratio=1 (no kernel block size mismatch).""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorWorker, + ) + + worker = object.__new__(NixlConnectorWorker) + + num_blocks = 100 + engine_id = "test-engine" + worker.num_regions = 2 + worker.dst_num_blocks = {engine_id: num_blocks} + worker._has_mamba = True + worker._is_mamba_group = [False, True] + worker._physical_blocks_per_logical_kv_block = 1 + # num_descs = num_regions * num_blocks (no blocks_first doubling) + worker.num_descs = 2 * num_blocks + + fa_blocks = [3, 5] + ssm_blocks = [1, 2] + result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks)) + + # FA group: stride=num_blocks=100, offset=0 + # region0: [3, 5], region1: [103, 105] + # SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1), + # offset=num_descs=200 + # region0: [201, 202], region1: [301, 302] + expected = [3, 5, 103, 105, 201, 202, 301, 302] + assert list(result) == expected, f"Expected {expected}, got {list(result)}" + + +@pytest.mark.cpu_test +def test_get_block_descs_ids_kernel_block_mismatch(): + """Test _get_block_descs_ids uses different strides for FA (kernel blocks) + vs SSM (logical blocks) when ratio > 1.""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorWorker, + ) + + worker = object.__new__(NixlConnectorWorker) + + ratio = 4 + logical_blocks = 100 + num_blocks = 
logical_blocks * ratio # 400 kernel blocks + engine_id = "test-engine" + worker.num_regions = 2 + worker.dst_num_blocks = {engine_id: num_blocks} + worker._has_mamba = True + worker._is_mamba_group = [False, True] + worker._physical_blocks_per_logical_kv_block = ratio + worker.num_descs = 2 * num_blocks # 800 + + fa_blocks = [3, 7] # kernel-level block IDs + ssm_blocks = [1, 2] # logical block IDs + result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks)) + + # FA group: stride=num_blocks=400, offset=0 + # region0: [3, 7], region1: [403, 407] + # SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800 + # region0: [801, 802], region1: [901, 902] + expected = [3, 7, 403, 407, 801, 802, 901, 902] + assert list(result) == expected, f"Expected {expected}, got {list(result)}" + + +@pytest.mark.cpu_test +def test_nixl_metadata_hybrid_ssm_block_ids(): + """Test NixlConnectorMetadata correctly stores block IDs for FA + SSM + groups with different block counts (kernel mismatch active).""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorMetadata, + ) + + metadata = NixlConnectorMetadata() + + # FA: 8 kernel blocks (2 logical * ratio=4), SSM: 2 logical blocks + fa_blocks = [0, 1, 2, 3, 4, 5, 6, 7] + ssm_blocks = [0, 1] + + metadata.add_new_req_to_recv( + request_id="test-req-hybrid", + local_block_ids=(fa_blocks, ssm_blocks), + kv_transfer_params={ + "remote_block_ids": ([10, 11, 12, 13, 14, 15, 16, 17], [20, 21]), + "remote_engine_id": "remote-engine", + "remote_request_id": "prefill-test-req-hybrid", + "remote_host": "localhost", + "remote_port": 1234, + "tp_size": 1, + }, + ) + + assert "test-req-hybrid" in metadata.reqs_to_recv + req_meta = metadata.reqs_to_recv["test-req-hybrid"] + + # Verify local block IDs: different lengths per group + assert len(req_meta.local_block_ids) == 2 + assert list(req_meta.local_block_ids[0]) == fa_blocks + assert list(req_meta.local_block_ids[1]) == ssm_blocks + assert len(req_meta.local_block_ids[0]) != len(req_meta.local_block_ids[1]) + + # Verify remote block IDs: same asymmetry preserved + assert req_meta.remote is not None + assert len(req_meta.remote.block_ids) == 2 + assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17] + assert list(req_meta.remote.block_ids[1]) == [20, 21] + assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1]) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 155395e84e11..1f889c6c838a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -16,10 +16,12 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.platforms import current_platform from vllm.v1.attention.backend import AttentionBackend +from vllm.v1.kv_cache_interface import MambaSpec from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase + from vllm.v1.kv_cache_interface import KVCacheSpec logger = init_logger(__name__) @@ -328,22 +330,26 @@ class TpKVTopology: remote_tp_size: dict[EngineId, int] is_mla: bool total_num_kv_heads: int - attn_backend: type[AttentionBackend] + attn_backends: list[type[AttentionBackend]] engine_id: EngineId remote_block_size: dict[EngineId, int] tensor_shape: torch.Size | None = None + is_mamba: bool = False def __post_init__(self): # Figure out whether the 
first dimension of the cache is K/V # or num_blocks. This is used to register the memory regions correctly. - _MOCK_BLOCK_SIZE = 16 - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1 - ) - logger.debug("Test kv_cache_shape: %s", kv_cache_shape) + attn_backend = self.attn_backends[0] + if not self.is_mamba: + _MOCK_BLOCK_SIZE = 16 + kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape( + num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1 + ) + logger.debug("Test kv_cache_shape: %s", kv_cache_shape) # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], # we just mock num_blocks to 1 for the dimension check below. - self._is_kv_layout_blocks_first = ( + # Hybrid SSM models assume a single blocks_first layout + self._is_kv_layout_blocks_first = self.is_mamba or ( len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 ) @@ -360,7 +366,7 @@ def __post_init__(self): _MOCK_NUM_LAYERS = 80 kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape try: - kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( include_num_layers_dimension=self._cross_layers_blocks ) except (AttributeError, NotImplementedError): @@ -483,6 +489,30 @@ def get_target_remote_ranks_from_engine_id( remote_tp_size = self.remote_tp_size[remote_engine_id] return self.get_target_remote_ranks(remote_tp_size) + def get_transfer_cache_regions( + self, cache: torch.Tensor, layer_spec: "KVCacheSpec" + ) -> list[torch.Tensor] | torch.Tensor: + """Return the cache tensor(s) to register as NIXL memory regions, + also accounting for the specifics of hybrid SSM models. + """ + if isinstance(layer_spec, MambaSpec): + # Register the whole shared kv cache tensor, including SSM/Conv. This is + # similar to FI, with the difference that SSM/Conv have different sizes. + conv, ssm = cache + return [conv] + + # This check may be hacky, but it matches `_update_hybrid_attention_mamba_layout`. + if self.is_mamba and cache.shape[0] == 2: + # When MAMBA is present, all backends are blocks first, so that blocks + # can be shared between attention layers and mamba layers. The runner's + # `_update_hybrid_attention_mamba_layout` has already adjusted strides + # for FlashAttn-like backends so they are num_blocks-first. + # Swap the [2<>num_blocks] dims to get the required layout for hybrid SSM.
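+ # e.g. a [2, num_blocks, H, N, D] attention cache becomes [num_blocks, 2, H, N, D], + # i.e. blocks-first like the Mamba layers.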
+ cache = cache.transpose(0, 1) + + # Regular case: backends like FA register K/V in separate regions + return cache if self.split_k_and_v else [cache] + def get_current_attn_backends( vllm_config: VllmConfig, layer_names: list[str] | None = None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index d986f686657f..28b997128d46 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -564,7 +564,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): remote_block_size=self._block_size, # shared state is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - attn_backend=backend, + attn_backends=[backend], ) self.async_zmq_ctx = zmq.asyncio.Context() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e6c49d7a025e..e1b908cdfb7d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -59,7 +59,12 @@ from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, SlidingWindowSpec +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + MambaSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.utils import select_common_block_size @@ -159,6 +164,7 @@ class NixlAgentMetadata: block_lens: list[int] kv_cache_layout: str block_size: int + ssm_sizes: tuple[int, int] @dataclass @@ -310,6 +316,15 @@ def add_new_req_to_recv( class NixlConnector(KVConnectorBase_V1, SupportsHMA): @property def prefer_cross_layer_blocks(self) -> bool: + if any( + [ + isinstance(group.kv_cache_spec, MambaSpec) + for group in self.kv_cache_config.kv_cache_groups + ] + ): + # Hybrid SSM models do not yet support cross-layer layout + return False + backend = get_current_attn_backend(self._vllm_config) if backend.get_name() not in ( "FLASH_ATTN", @@ -335,12 +350,9 @@ def __init__( kv_cache_config: "KVCacheConfig", ): super().__init__(vllm_config, role, kv_cache_config) - assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None - for group in kv_cache_config.kv_cache_groups: - if isinstance(group.kv_cache_spec, MambaSpec): - raise ValueError("NixlConnector does not support Mamba models.") + self.kv_cache_config = kv_cache_config self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id self.kv_transfer_config = vllm_config.kv_transfer_config if role == KVConnectorRole.SCHEDULER: @@ -434,11 +446,7 @@ def register_cross_layers_kv_cache( self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] ): assert self.connector_worker is not None - - cross_layer_name = "ALL_LAYERS" - kv_caches = {cross_layer_name: kv_cache} - - self.connector_worker.register_kv_caches(kv_caches) + self.connector_worker.register_cross_layers_kv_caches(kv_cache) def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): assert self.connector_worker is not None @@ -962,6 +970,40 @@ def __init__( ) ) self.kv_cache_config = kv_cache_config + 
self._layer_specs = { + layer: group.kv_cache_spec + for group in kv_cache_config.kv_cache_groups + for layer in group.layer_names + } + self.hma_group_size = len(kv_cache_config.kv_cache_tensors) + + # Mamba metadata + self._is_mamba_group = [ + isinstance(group.kv_cache_spec, MambaSpec) + for group in kv_cache_config.kv_cache_groups + ] + mamba_ssm_size = (0, 0) + self._has_mamba = any(self._is_mamba_group) + if self._has_mamba: + assert self._is_hma_required + mamba_spec = next( + spec + for spec in self._layer_specs.values() + if isinstance(spec, MambaSpec) + ) + conv_nbytes, ssm_nbytes = ( + torch.tensor([], dtype=mamba_spec.dtypes[0]).element_size(), # type: ignore[misc] + torch.tensor([], dtype=mamba_spec.dtypes[1]).element_size(), # type: ignore[misc] + ) + conv_shape, ssm_shape = ( + torch.Size(mamba_spec.shapes[0]), + torch.Size(mamba_spec.shapes[1]), + ) + mamba_ssm_size = ( + conv_shape.numel() * conv_nbytes, + ssm_shape.numel() * ssm_nbytes, + ) + self._mamba_ssm_size = mamba_ssm_size # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] @@ -1109,9 +1151,9 @@ def __init__( # Get the attention backend from the first layer # NOTE (NickLucche) models with multiple backends are not supported yet - self.attn_backend = get_current_attn_backend(vllm_config) + self.attn_backends = get_current_attn_backends(vllm_config) + self.backend_name = self.attn_backends[0].get_name() - self.backend_name = self.attn_backend.get_name() self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.info("Detected attention backend %s", self.backend_name) @@ -1138,6 +1180,8 @@ def __init__( def _sync_block_size_with_kernel(self) -> None: backends = get_current_attn_backends(self.vllm_config) kernel_block_size = select_common_block_size(self.block_size, backends) + # Number of blocks not accounting for kernel block mismatches + self._logical_num_blocks = self.num_blocks if self.block_size != kernel_block_size: logger.info_once( "User-specified logical block size (%s) does not match" @@ -1431,9 +1475,19 @@ def request_ready(f: Future[Any], entry=(req_id, meta)): fut.add_done_callback(request_ready) + def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None: + """Register a cross-layers KV cache tensor with NIXL. + + `use_uniform_kv_cache()` guarantees a single KV cache group whose + layers all share the same `AttentionSpec`, so any layer name from + `_layer_specs` yields the correct per-layer spec for `page_size_bytes`. 
+ """ + first_layer = next(iter(self._layer_specs)) + # Forwarding a real layer name rather than a synthetic key + self.register_kv_caches({first_layer: kv_cache}) + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - self.kv_topo = TpKVTopology( tp_rank=self.tp_rank, engine_id=self.engine_id, @@ -1441,8 +1495,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): remote_block_size=self._block_size, # shared state is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - attn_backend=self.attn_backend, - tensor_shape=next(iter(kv_caches.values())).shape, + attn_backends=self.attn_backends, + # SSM States come in tuples (ssm, conv) + tensor_shape=next(iter(kv_caches.values())).shape + if not self._has_mamba + else None, + is_mamba=self._has_mamba, ) self.compat_hash = compute_nixl_compatibility_hash( self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks @@ -1484,12 +1542,50 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # to better exploit the memory layout (ie num_blocks is the first dim). tensor_size_bytes = None - # Enable different block lengths for different layers when MLA is used. + # Enable different block lengths for different layers *only* when MLA is used. + # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`. self.block_len_per_layer = list[int]() for layer_name, cache_or_caches in xfer_buffers.items(): - cache_list = ( - cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches] + # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to + # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. + # However, physical page_size may differ when kernel requires a specific + # block size. This leads to SSM and FA layers having different num_blocks. + # `_physical_blocks_per_logical_kv_block` ratio is used to adjust for this. + layer_spec = self._layer_specs[layer_name] + if isinstance(layer_spec, UniformTypeKVCacheSpecs): + # MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs + layer_spec = layer_spec.kv_cache_specs[layer_name] + cache_list = self.kv_topo.get_transfer_cache_regions( + cache_or_caches, layer_spec + ) + # `layer_spec.page_size_bytes` only accounts for logical page_size, that is + # the page_size assuming constant `self._logical_num_blocks`. + physical_page_size = ( + layer_spec.page_size_bytes + if isinstance(layer_spec, MambaSpec) + else layer_spec.page_size_bytes + // self._physical_blocks_per_logical_kv_block ) + # For when registering multiple tensors eg K/V in separate regions. + physical_page_size = physical_page_size // len(cache_list) + if self.kv_topo._cross_layers_blocks: + # When cross-layers blocks are used, multiply by number of layers + physical_page_size = physical_page_size * len( + self.kv_cache_config.kv_cache_tensors + ) + num_blocks = ( + self._logical_num_blocks + if isinstance(layer_spec, MambaSpec) + else self.num_blocks + ) + # `page_size` accounts for physical blocks, st KVCache is always + # [`num_blocks` * `page_size`] + curr_tensor_size_bytes = num_blocks * physical_page_size + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + + # TODO (NickLucche) we could eventually unify how we handle FA/FI regions, + # registering a single tensor for both K/V and splitting logically like FI. 
for cache in cache_list: base_addr = cache.data_ptr() if base_addr in seen_base_addresses: # across groups. This results in skipping all tensors but the ones # pointed to by group0. Also, generally we will have more blocks # per tensor but fewer regions. + logger.debug("Skipping %s because it's already seen", layer_name) continue - logger.debug( "Registering layer %s with cache shape: %s", layer_name, cache.shape ) seen_base_addresses.append(base_addr) - curr_tensor_size_bytes = cache.numel() * cache.element_size() - - if tensor_size_bytes is None: - tensor_size_bytes = curr_tensor_size_bytes + # Record per-layer block lengths; Mamba pages are scaled down by the same + # logical/physical block ratio to keep the bookkeeping consistent. + if isinstance(layer_spec, MambaSpec): + self.block_len_per_layer.append( + physical_page_size // self._physical_blocks_per_logical_kv_block + ) + else: + self.block_len_per_layer.append(physical_page_size) - assert cache.shape[0] == self.num_blocks, ( + assert cache.shape[0] == num_blocks, ( "All kv cache tensors must have the same number of blocks" ) - self.block_len_per_layer.append( - curr_tensor_size_bytes // self.num_blocks - ) - if not self.use_mla: - # Different kv cache shape is not supported by HeteroTP + # Different kv cache shape is not supported by HeteroTP. + # This must also hold true for Mamba-like models. assert tensor_size_bytes == curr_tensor_size_bytes, ( "All kv cache tensors must have the same size" ) @@ -1536,6 +1632,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses self.num_regions = len(caches_data) + if self.kv_topo.is_kv_layout_blocks_first: + # NOTE (NickLucche) When FlashInfer is used, memory is registered + # with joint KV for each block. This minimizes the overhead in + # registerMem allowing faster descs queries. In order to be able to + # split on kv_heads dim as required by heterogeneous TP, one must + # be able to index K/V separately. Hence we double the number + # of 'virtual' regions here and halve `block_len` below. + # Similarly for Mamba layers, we register SSM+Conv as a single region and + # then duplicate it logically to be able to index SSM/Conv separately. + self.num_regions *= 2 + + # TODO (NickLucche) Adapt to different descs views (engine_id->tp_rank) to + # support heterogeneous TP. + self.num_descs = self.num_regions * self.num_blocks + descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends) @@ -1545,17 +1656,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.device_kv_caches = kv_caches self.dst_num_blocks[self.engine_id] = self.num_blocks - if self.kv_topo.is_kv_layout_blocks_first: - # NOTE (NickLucche) When FlashInfer is used, memory is registered - # with joint KV for each block. This minimizes the overhead in - # registerMem allowing faster descs queries. In order to be able to - # split on kv_heads dim as required by heterogeneous TP, one must - # be able to index K/V separately. Hence we double the number - # of 'virtual' regions here and halve `block_len` below.
- self.num_regions *= 2 + if self._has_mamba: + logger.info( + "Hybrid SSM registration: num_blocks=%s, " + "logical_num_blocks=%s, ratio=%s, num_regions=%s, " + "num_descs=%s, mamba_ssm_size=%s, block_len_per_layer=%s", + self.num_blocks, + self._logical_num_blocks, + self._physical_blocks_per_logical_kv_block, + self.num_regions, + self.num_descs, + self._mamba_ssm_size, + set(self.block_len_per_layer), + ) # Register local/src descr for NIXL xfer. - self.seen_base_addresses = seen_base_addresses self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = ( self.register_local_xfer_handler(self.block_size) ) @@ -1572,6 +1687,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if not self.use_host_buffer else self.host_buffer_kv_cache_layout, block_size=self.block_size, + ssm_sizes=self._mamba_ssm_size, ) # Wrap metadata in payload with hash for defensive decoding assert self.compat_hash is not None @@ -1597,40 +1713,65 @@ def register_local_xfer_handler( data copy correctness. """ assert self.kv_topo is not None + kv_topo = self.kv_topo block_size_ratio = self.block_size // block_size - blocks_data = [] - for i, base_addr in enumerate(self.seen_base_addresses): - # The new block_len is using prefill block_len; - # and num_blocks is multiple with N - kv_block_len = ( - self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio - ) - block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio - num_blocks = self.num_blocks * block_size_ratio - for block_id in range(num_blocks): - block_offset = block_id * block_len_per_layer - addr = base_addr + block_offset - # (addr, len, device id) - blocks_data.append((addr, kv_block_len, self.device_id)) - - if self.kv_topo.is_kv_layout_blocks_first: - # Separate and interleave K/V regions to maintain the same - # descs ordering. This is needed for selecting contiguous heads - # when split across TP ranks. + blocks_data: list[tuple[int, int, int]] = [] + local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] + + def register_blocks(blocks_data: list[tuple[int, int, int]], mamba: bool): + for i, base_addr in enumerate(local_base_addresses): + # The new block_len is using prefill block_len; + # and num_blocks is multiple with N + kv_block_len = ( + self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=True, mamba_view=mamba + ) + // block_size_ratio + ) + # Jump one page_size, but ssm page_size may be bigger when kernel + # locks block size to a specific value. + block_len_per_layer = ( + self.block_len_per_layer[i] + // block_size_ratio + * (1 if not mamba else self._physical_blocks_per_logical_kv_block) + ) + num_blocks = self._logical_num_blocks if mamba else self.num_blocks + num_blocks = num_blocks * block_size_ratio for block_id in range(num_blocks): block_offset = block_id * block_len_per_layer addr = base_addr + block_offset - # Register addresses for V cache (K registered first). - v_addr = addr + kv_block_len - blocks_data.append((v_addr, kv_block_len, self.device_id)) - logger.debug( - "Created %s blocks for src engine %s and rank %s on device id %s", - len(blocks_data), - self.engine_id, - self.tp_rank, - self.device_id, - ) + # (addr, len, device id) + blocks_data.append((addr, kv_block_len, self.device_id)) + + if kv_topo.is_kv_layout_blocks_first: + second_split = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=False, mamba_view=mamba + ) + # Separate and interleave K/V regions to maintain the same + # descs ordering. 
This is needed for selecting contiguous heads + # when split across TP ranks. + for block_id in range(num_blocks): + block_offset = block_id * block_len_per_layer + addr = base_addr + block_offset + # Register addresses for V cache (K registered first). + v_addr = addr + kv_block_len + blocks_data.append((v_addr, second_split, self.device_id)) + logger.debug( + "Created %s blocks for src engine %s and rank %s on device id %s", + len(blocks_data), + self.engine_id, + self.tp_rank, + self.device_id, + ) + + register_blocks(blocks_data, mamba=False) + if self._has_mamba: + assert self.num_descs == len(blocks_data) + logger.debug( + "Registering additional %s local Mamba blocks", len(blocks_data) + ) + register_blocks(blocks_data, mamba=True) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. @@ -1711,7 +1852,8 @@ def add_remote_agent( # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| assert self.kv_topo is not None - block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id) + kv_topo = self.kv_topo + block_size_ratio = kv_topo.block_size_ratio_from_engine_id(engine_id) if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks @@ -1771,48 +1913,86 @@ def add_remote_agent( # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. # Register all remote blocks, but only the corresponding kv heads. - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - # Read our whole local region size from remote. - local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - # using remote kv_block_len as transfer unit - local_block_len = remote_kv_block_len + def register_remote_blocks( + blocks_data: list[tuple[int, int, int]], mamba: bool + ): + for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + # Read our whole local region size from remote. + local_block_len = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=True, mamba_view=mamba + ) + remote_kv_block_len = local_block_len // block_size_ratio + if block_size_ratio > 1: + # using remote kv_block_len as transfer unit + local_block_len = remote_kv_block_len + + if tp_ratio < 0 and not self.use_mla: + # Remote tp is bigger: read a chunk of local region from remote + local_block_len = local_block_len // (-tp_ratio) + rank_offset = ( + self.tp_rank % tp_ratio * remote_kv_block_len + if indexes_into_remote + else 0 + ) - if tp_ratio < 0 and not self.use_mla: - # Remote tp is bigger: read a chunk of local region from remote - local_block_len = local_block_len // (-tp_ratio) - rank_offset = ( - self.tp_rank % tp_ratio * remote_kv_block_len - if indexes_into_remote - else 0 - ) - for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * nixl_agent_meta.block_lens[i] - # For each block, grab the heads chunk belonging to rank_i - # of size remote_nheads // tp_ratio, which correspond to - # self.block_len == remote_block_len//tp_ratio bytes. - addr = base_addr + block_offset + rank_offset - # (addr, len, device id) - blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id)) - - if self.kv_topo.is_kv_layout_blocks_first: - # With FlashInfer index V separately to allow head splitting. 
- for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * nixl_agent_meta.block_lens[i] + # The remote advertises a single (FA kernel-block) num_blocks; assume it + # uses the same logical/physical ratio to derive the Mamba logical count. + num_blocks = ( + nixl_agent_meta.num_blocks + if not mamba + else nixl_agent_meta.num_blocks + // self._physical_blocks_per_logical_kv_block + ) + page_size = nixl_agent_meta.block_lens[i] * ( + 1 if not mamba else self._physical_blocks_per_logical_kv_block + ) + for block_id in range(num_blocks): + block_offset = block_id * page_size + # For each block, grab the heads chunk belonging to rank_i + # of size remote_nheads // tp_ratio, which corresponds to + # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + # (addr, len, device id) blocks_data.append( - (v_addr, local_block_len, nixl_agent_meta.device_id) + (addr, local_block_len, nixl_agent_meta.device_id) ) - logger.debug( - "Created %s blocks for dst engine %s with remote rank %s and local rank %s", - len(blocks_data), - engine_id, - remote_tp_rank, - self.tp_rank, - ) + if kv_topo.is_kv_layout_blocks_first: + # With FlashInfer, index V separately to allow head splitting. + second_split = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=False, mamba_view=mamba + ) + # Apply the same scaling as local_block_len above for the case where we + # read a chunk of local V from `tp_ratio` separate remote workers. + if tp_ratio < 0 and not self.use_mla: + second_split = second_split // (-tp_ratio) + for block_id in range(num_blocks): + block_offset = block_id * page_size + addr = base_addr + block_offset + rank_offset + # Hop over the first split of the remote page: either K or Conv. + if mamba: + v_addr = addr + nixl_agent_meta.ssm_sizes[0] + else: + v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + blocks_data.append( + (v_addr, second_split, nixl_agent_meta.device_id) + ) + + logger.debug( + "Created %s blocks for dst engine %s" + " with remote rank %s and local rank %s", + len(blocks_data), + engine_id, + remote_tp_rank, + self.tp_rank, + ) + + register_remote_blocks(blocks_data, mamba=False) + if self._has_mamba: + # Create extra descs for the Mamba "view" of the same KV cache tensors. + logger.debug( + "Registering additional %s remote Mamba blocks", len(blocks_data) + ) + register_remote_blocks(blocks_data, mamba=True) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) @@ -1852,6 +2032,9 @@ def _validate_remote_agent_handshake( assert block_size_ratio == 1, ( "HMA does not support different remote block size yet" ) + # Additional Mamba constraints + if self._has_mamba: + assert tp_ratio == 1, "Mamba does not support heterogeneous TP yet" kv_cache_layout = ( self.kv_cache_layout @@ -2498,6 +2681,7 @@ def _get_block_descs_ids( A single flattened array is returned for all groups anyway. """ region_ids = np.arange(self.num_regions) + # NOTE (NickLucche) With HMA, every kv group has the same number of layers and # layers from different groups share the same kv tensor. # eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions, @@ -2508,11 +2692,33 @@ if block_size_ratio is not None: num_blocks = int(num_blocks * block_size_ratio) - # Compute the desc ids for each block. + # Compute desc ids per group using the right stride: FA descs have + # num_blocks entries per region (kernel granularity), SSM descs have + # logical_blocks entries per region (no kernel splitting).
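+ # e.g. (see test_get_block_descs_ids_kernel_block_mismatch) with num_regions=2, + # num_blocks=400, ratio=4: FA block 3 in region 1 -> 1*400+3=403, while SSM + # block 1 in region 1 -> 800 + 1*100 + 1 = 901 (offset past all FA descs).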
region_ids = region_ids[:, None] - block_ids = np.concatenate(block_ids)[None, :] - descs_ids = region_ids * num_blocks + block_ids - return descs_ids.flatten() + if not self._has_mamba: + block_ids = np.concatenate(block_ids)[None, :] + descs_ids = region_ids * num_blocks + block_ids + return descs_ids.flatten() + else: + # NOTE (NickLucche) SSM and Attention block regions can be exchanged + # arbitrarily by the manager. Therefore, descs are duplicated for SSM and + # Attention like so: + # desc_handle->[descs_fa (all regions) | descs_ssm (all regions)]. + # This is like having two "low-level views" of the same storage. + # `num_fa_descs` offset must be computed per-engine since P and D can + # have different num_blocks (and thus different FA descs counts). + ratio = self._physical_blocks_per_logical_kv_block + # SSM may register fewer num_blocks than FA + logical_blocks = num_blocks // ratio + num_fa_descs = self.num_regions * num_blocks + all_descs = [] + for i, group in enumerate(block_ids): + stride = logical_blocks if self._is_mamba_group[i] else num_blocks + group_arr = np.asarray(group)[None, :] + offset = num_fa_descs if self._is_mamba_group[i] else 0 + all_descs.append((region_ids * stride + group_arr + offset).flatten()) + return np.concatenate(all_descs) def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: """ @@ -2526,16 +2732,22 @@ block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape( 1, -1 ) + # Mamba blocks have no logical<>physical discrepancy + group_specs = self.kv_cache_config.kv_cache_groups return [ BlockTable.map_to_kernel_blocks( np.array(group), self._physical_blocks_per_logical_kv_block, block_arange, ).tolist() - for group in block_ids + if not isinstance(group_specs[i].kv_cache_spec, MambaSpec) + else group + for i, group in enumerate(block_ids) ] - def get_backend_aware_kv_block_len(self, layer_idx: int) -> int: + def get_backend_aware_kv_block_len( + self, layer_idx: int, first_split: bool = True, mamba_view: bool = False + ) -> int: """ Get the block length for one K/V element (K and V have the same size). @@ -2543,11 +2755,38 @@ block, as K and V are in separate regions. For FlashInfer, this is half the length of the whole block, as K and V share the same region. + Similarly, for SSM-based models, state and conv are interleaved, but crucially + their sizes differ. + Reference diagram: + KVCacheTensor (Shared) + / \ + / \ + / \ + Attention (FlashInfer) View Mamba View + | | + | | + +-------------------+ +-------------------+ + | KVCacheTensor | | KVCacheTensor | + | | | | + |<----- page ------>| |<----- page ------->| + | size | | size | + | Key 0 | Val 0 | |Conv 0 | SSM 0 | + | Key 1 | Val 1 | |Conv 1 | SSM 1 | + | ... | ... | | ... | ... | + | Key N-2 | Val N-2 | |Conv N-2| SSM N-2 | + | Key N-1 | Val N-1 | |Conv N-1| SSM N-1 | + +-------------------+ +--------------------+ + |1st_split-2nd_split| |1st_split-2nd_split | """ assert self.kv_topo is not None if self.kv_topo.is_kv_layout_blocks_first: # For indexing only half (either just the K or V part). - block_len = self.block_len_per_layer[layer_idx] // 2 + if mamba_view: + # NOTE (NickLucche) Mamba opt: this already skips the padding, so + # we only transfer the minimum required bytes.
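+ # first_split=True -> conv bytes (index 0), first_split=False -> SSM bytes (index 1).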
+ block_len = self._mamba_ssm_size[not first_split] + else: + block_len = self.block_len_per_layer[layer_idx] // 2 else: block_len = self.block_len_per_layer[layer_idx] return block_len