diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index 749c6fbe7f6d..af38087e4b3d 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -184,15 +184,6 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con --kv-transfer-config '{..., "enable_permute_local_kv":"True"}' ``` -### Cross layers blocks - -By default, this feature is disabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. -To enable this feature: - -```bash ---kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}' -``` - ## Example Scripts/Code Refer to these example scripts in the vLLM repository: diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index e5e333f19fe8..c2c38f51c500 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -34,18 +34,11 @@ else KV_CONFIG_HETERO_LAYOUT='' fi -CROSS_LAYERS_BLOCKS=${CROSS_LAYERS_BLOCKS:-"False"} # Default to non cross layers -if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then - KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"cross_layers_blocks": "True"}' -else - KV_EXTRA_CONFIG='' -fi - # Build the kv-transfer-config once if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then - KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' + KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}' else - KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" + KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}" fi # Models to run diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index e93835598a41..a3fe2e21b56e 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -18,12 +18,8 @@ import torch from vllm import LLM -from vllm.config import KVTransferConfig, set_current_vllm_config -from vllm.distributed.kv_transfer.kv_connector.utils import ( - KVOutputAggregator, - TpKVTopology, - get_current_attn_backend, -) +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( @@ -52,11 +48,8 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import RequestStatus -from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin -from vllm.v1.worker.utils import AttentionGroup from .utils import create_request, create_scheduler, create_vllm_config @@ -373,7 +366,6 @@ def test_kv_transfer_handshake(dist_init): # Decode connector will be able to create handshake with the prefill connector. decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) - decode_connector.register_kv_caches(kv_caches) # Here we are testing the retrieval of NIXLAgentMetadata. # Knowing the implementation detail, we override the add_remote_agent @@ -410,23 +402,6 @@ def __init__( self.kv_cache_layout = kv_cache_layout # Mock register_kv_caches attribute needed for tests that do not call it. self.src_xfer_handles_by_block_size = {self.block_size: 1} - test_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - self.kv_topo = TpKVTopology( - tp_rank=self.tp_rank, - engine_id=self.engine_id, - remote_tp_size=self._tp_size, # shared state - remote_block_size=self._block_size, # shared state - is_mla=self.use_mla, - total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - attn_backend=self.attn_backend, - tensor_shape=test_shape, - ) - - self.compat_hash = compute_nixl_compatibility_hash( - self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks - ) def _nixl_handshake( self, host: str, port: int, remote_tp_size: int, expected_engine_id: str @@ -1395,7 +1370,6 @@ def req_id(outputs: list[RequestOutput]) -> str: ), ), "TRITON_ATTN", - "FLASHINFER", ], ) def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): @@ -1412,11 +1386,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): vllm_config = create_vllm_config(attention_backend=attn_backend) - # Enable cross layers blocks - vllm_config.kv_transfer_config.kv_connector_extra_config[ - "enable_cross_layers_blocks" - ] = True - # Import the appropriate backend based on the parameter if attn_backend == "FLASH_ATTN": from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend @@ -1426,11 +1395,49 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend backend_cls = RocmAttentionBackend - else: # TRITON + else: # TRITON_ATTN from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend backend_cls = TritonAttentionBackend + # Create test kv cache tensors using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + # Store tensor info for validation + + test_shape = backend_cls.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1 + + if is_blocks_first: + expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel() + expected_base_addrs = [ + shared_tensor.data_ptr(), + unique_tensor.data_ptr(), + ] + expected_num_entries = 2 + else: + expected_tensor_size = ( + shared_tensor[0].element_size() * shared_tensor[0].numel() + ) + expected_base_addrs = [ + shared_tensor[0].data_ptr(), + shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), + unique_tensor[1].data_ptr(), + ] + expected_num_entries = 4 + nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" with ( patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, @@ -1459,107 +1466,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): # Reassure the shutdown() check that the thread is terminated mock_thread.return_value.is_alive.return_value = False - expected_tensor_size: int - expected_base_addrs: list[int] - expected_num_entries: int - kv_caches: dict[str, torch.Tensor] - if connector.prefer_cross_layer_blocks: - num_layers = 32 - block_size = 16 - num_blocks = 8 - kv_cache_spec = AttentionSpec( - block_size=block_size, - num_kv_heads=4, - head_size=64, - dtype=torch.bfloat16, - ) - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - kv_cache_tensors=[ - KVCacheTensor( - size=kv_cache_spec.page_size_bytes * num_blocks, - shared_by=["dummy-layer"], - ) - for i in range(num_layers) - ], - # allocate_uniform_kv_caches does not use this - kv_cache_groups=[], - ) - - with set_current_vllm_config(vllm_config): - _, cross_layers_kv_cache, _ = ( - KVConnectorModelRunnerMixin.allocate_uniform_kv_caches( - kv_cache_config=kv_cache_config, - attn_groups=[ - [ - AttentionGroup( - backend=backend_cls, - layer_names=[], - kv_cache_spec=kv_cache_spec, - kv_cache_group_id=0, - ) - ] - ], - cache_dtype=torch.bfloat16, - device=torch.cuda.current_device(), - kernel_block_sizes=[block_size], - ) - ) - # Store tensor info for validation - expected_tensor_size = ( - cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel() - ) - expected_base_addrs = [ - cross_layers_kv_cache.data_ptr(), - ] - expected_num_entries = 1 - - expected_blocks_count = 8 - - kv_caches = {"all-layers": cross_layers_kv_cache} - - else: - # Create test kv cache tensors using proper backend shape - kv_cache_shape = backend_cls.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 - ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, - } - - # Store tensor info for validation - - test_shape = backend_cls.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1 - - if is_blocks_first: - expected_tensor_size = ( - shared_tensor.element_size() * shared_tensor.numel() - ) - expected_base_addrs = [ - shared_tensor.data_ptr(), - unique_tensor.data_ptr(), - ] - expected_num_entries = 2 - else: - expected_tensor_size = ( - shared_tensor[0].element_size() * shared_tensor[0].numel() - ) - expected_base_addrs = [ - shared_tensor[0].data_ptr(), - shared_tensor[1].data_ptr(), - unique_tensor[0].data_ptr(), - unique_tensor[1].data_ptr(), - ] - expected_num_entries = 4 - expected_blocks_count = 8 - # Execute register_kv_caches connector.register_kv_caches(kv_caches) @@ -1583,19 +1489,16 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] # Validate blocks_data structure and size + expected_blocks_count = 8 assert len(blocks_data) == expected_blocks_count, ( f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" ) - if connector.prefer_cross_layer_blocks: - num_blocks = 8 - expected_block_len = expected_tensor_size // num_blocks + num_blocks = 2 + if is_blocks_first: + expected_block_len = expected_tensor_size // num_blocks // 2 else: - num_blocks = 2 - if is_blocks_first: - expected_block_len = expected_tensor_size // num_blocks // 2 - else: - expected_block_len = expected_tensor_size // num_blocks + expected_block_len = expected_tensor_size // num_blocks for i, block_entry in enumerate(blocks_data): block_start_addr, block_len, tp_rank = block_entry @@ -2146,17 +2049,6 @@ def test_compatibility_hash_validation( ) decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_worker = decode_connector.connector_worker - kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 - ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, - } - decode_connector.register_kv_caches(kv_caches) remote_config_params: dict[str, Any] = { "model": "facebook/opt-125m", @@ -2179,9 +2071,7 @@ def test_compatibility_hash_validation( ) ) remote_hash = compute_nixl_compatibility_hash( - remote_vllm_config, - decode_worker.backend_name, - decode_worker.kv_topo.cross_layers_blocks, + remote_vllm_config, decode_worker.backend_name ) prefill_block_size = config_overrides.get("block_size", 16) @@ -2260,27 +2150,6 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_worker = decode_connector.connector_worker - backend = get_current_attn_backend(local_vllm_config) - test_shape = backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - decode_worker.kv_topo = TpKVTopology( - tp_rank=decode_worker.tp_rank, - engine_id=decode_worker.engine_id, - remote_tp_size=decode_worker._tp_size, # shared state - remote_block_size=decode_worker._block_size, # shared state - is_mla=decode_worker.use_mla, - total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), - attn_backend=backend, - tensor_shape=test_shape, - ) - - decode_worker.compat_hash = compute_nixl_compatibility_hash( - decode_worker.vllm_config, - decode_worker.backend_name, - decode_worker.kv_topo.cross_layers_blocks, - ) - if error_scenario == "handshake_decode_error": msg_bytes = b"this is not valid msgpack data" elif error_scenario == "handshake_validation_error": diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index b184f6574130..fd833e293938 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -316,7 +316,6 @@ class TpKVTopology: attn_backend: type[AttentionBackend] engine_id: EngineId remote_block_size: dict[EngineId, int] - tensor_shape: torch.Size | None = None def __post_init__(self): # Figure out whether the first dimension of the cache is K/V @@ -330,32 +329,6 @@ def __post_init__(self): len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 ) - self._kv_heads_position: int | None = None - self._cross_layers_blocks = False - if self.tensor_shape is not None: - self._cross_layers_blocks = ( - len(self.tensor_shape) == len(kv_cache_shape) + 1 - ) - - if self._cross_layers_blocks: - # prepend layers dimension - kv_cache_shape = (80,) + kv_cache_shape - try: - kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=self._cross_layers_blocks - ) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(self.tensor_shape))) - - # permute kv_cache_shape according to stride_order - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - - physical_block_size_position = kv_cache_shape.index(16) - assert physical_block_size_position is not None - self._physical_block_size_position = -( - len(kv_cache_shape) - physical_block_size_position - ) - @property def is_kv_layout_blocks_first(self) -> bool: return self._is_kv_layout_blocks_first @@ -363,9 +336,7 @@ def is_kv_layout_blocks_first(self) -> bool: @property def split_k_and_v(self) -> bool: # Whether to register regions for K and V separately (when present). - return not ( - self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first - ) + return not (self.is_mla or self.is_kv_layout_blocks_first) @property def tp_size(self) -> int: @@ -375,14 +346,6 @@ def tp_size(self) -> int: def block_size(self) -> int: return self.remote_block_size[self.engine_id] - @property - def cross_layers_blocks(self) -> bool: - return self._cross_layers_blocks - - @property - def block_size_position(self) -> int: - return self._physical_block_size_position - def tp_ratio( self, remote_tp_size: int, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index aef2848dcb9a..8e0651053e07 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -54,7 +54,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path, make_zmq_socket -from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata +from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.block_table import BlockTable @@ -173,7 +173,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata): def compute_nixl_compatibility_hash( - vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool + vllm_config: VllmConfig, attn_backend_name: str ) -> str: """ Compute compatibility hash for NIXL KV transfer. @@ -216,7 +216,6 @@ def compute_nixl_compatibility_hash( # Attention backend and KV cache dtype affect memory layout "attn_backend_name": attn_backend_name, "cache_dtype": str(cache_config.cache_dtype), - "cross_layers_blocks": cross_layers_blocks, } compat_hash = hash_factors(factors) @@ -299,20 +298,6 @@ def add_new_req_to_recv( class NixlConnector(KVConnectorBase_V1): - @property - def prefer_cross_layer_blocks(self) -> bool: - backend = get_current_attn_backend(self._vllm_config) - if backend.get_name() not in ( - "FLASH_ATTN", - "FLASHINFER", - ): - # For now there is no benefit to run cross layers when backend - # does not support on HND - return False - - extra_config = self.kv_transfer_config.kv_connector_extra_config - return bool(str(extra_config.get("enable_cross_layers_blocks", "False"))) - def __init__( self, vllm_config: VllmConfig, @@ -324,7 +309,6 @@ def __init__( assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id - self.kv_transfer_config = vllm_config.kv_transfer_config if role == KVConnectorRole.SCHEDULER: self.connector_scheduler: NixlConnectorScheduler | None = ( @@ -411,16 +395,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) - def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] - ): - assert self.connector_worker is not None - - cross_layer_name = "ALL_LAYERS" - kv_caches = {cross_layer_name: kv_cache} - - self.connector_worker.register_kv_caches(kv_caches) - def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): assert self.connector_worker is not None self.connector_worker.set_host_xfer_buffer_ops(copy_operation) @@ -1002,17 +976,20 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Get the attention backend from the first layer # NOTE (NickLucche) models with multiple backends are not supported yet - self.attn_backend = get_current_attn_backend(vllm_config) + backend = get_current_attn_backend(vllm_config) - self.backend_name = self.attn_backend.get_name() + self.backend_name = backend.get_name() self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) - # lazy initialized in register_kv_caches - self.compat_hash: str | None = None - self.kv_topo: TpKVTopology | None = None + self.compat_hash = compute_nixl_compatibility_hash( + self.vllm_config, self.backend_name + ) + self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( + "enforce_handshake_compat", True + ) self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} @@ -1021,11 +998,16 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() - self._physical_blocks_per_logical_kv_block = 1 - - self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( - "enforce_handshake_compat", True + self.kv_topo = TpKVTopology( + tp_rank=self.tp_rank, + engine_id=self.engine_id, + remote_tp_size=self._tp_size, # shared state + remote_block_size=self._block_size, # shared state + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + attn_backend=backend, ) + self._physical_blocks_per_logical_kv_block = 1 def _nixl_handshake( self, @@ -1040,7 +1022,6 @@ def _nixl_handshake( # Regardless, only handshake with the remote TP rank(s) that current # local rank will read from. Note that With homogeneous TP, # this happens to be the same single rank_i. - assert self.kv_topo is not None p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size) remote_rank_to_agent_name = {} path = make_zmq_path("tcp", host, port) @@ -1078,7 +1059,6 @@ def _nixl_handshake( ) # Check compatibility hash BEFORE decoding agent metadata - assert self.compat_hash is not None if ( self.enforce_compat_hash and handshake_payload.compatibility_hash != self.compat_hash @@ -1287,20 +1267,6 @@ def request_ready(f: Future[Any], entry=(req_id, meta)): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - self.kv_topo = TpKVTopology( - tp_rank=self.tp_rank, - engine_id=self.engine_id, - remote_tp_size=self._tp_size, # shared state - remote_block_size=self._block_size, # shared state - is_mla=self.use_mla, - total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - attn_backend=self.attn_backend, - tensor_shape=next(iter(kv_caches.values())).shape, - ) - self.compat_hash = compute_nixl_compatibility_hash( - self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks - ) - if self.use_host_buffer: self.initialize_host_xfer_buffer(kv_caches=kv_caches) assert len(self.host_xfer_buffers) == len(kv_caches), ( @@ -1335,21 +1301,29 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are registered in the same region # to better exploit the memory layout (ie num_blocks is the first dim). + split_k_and_v = self.kv_topo.split_k_and_v tensor_size_bytes = None + # TODO (NickLucche): Get kernel_block_size in a cleaner way + # NHD default "view" for non-MLA cache + if self.device_type == "cpu": + block_size_position = -2 + else: + block_size_position = -2 if self.use_mla else -3 + # Enable different block lengths for different layers when MLA is used. self.block_len_per_layer = list[int]() self.slot_size_per_layer = list[int]() # HD bytes in kv terms for layer_name, cache_or_caches in xfer_buffers.items(): - cache_list = ( - cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches] - ) + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + for cache in cache_list: base_addr = cache.data_ptr() if base_addr in seen_base_addresses: continue - kernel_block_size = cache.shape[self.kv_topo.block_size_position] + kernel_block_size = cache.shape[block_size_position] + if self.block_size != kernel_block_size: logger.info_once( "User-specified logical block size (%s) does not match" @@ -1411,7 +1385,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.device_kv_caches = kv_caches self.dst_num_blocks[self.engine_id] = self.num_blocks - if self.kv_topo.is_kv_layout_blocks_first: for i in range(len(self.slot_size_per_layer)): assert self.slot_size_per_layer[i] % 2 == 0 @@ -1467,7 +1440,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_size=self.block_size, ) # Wrap metadata in payload with hash for defensive decoding - assert self.compat_hash is not None encoder = msgspec.msgpack.Encoder() self.xfer_handshake_metadata = NixlHandshakePayload( compatibility_hash=self.compat_hash, @@ -1489,8 +1461,6 @@ def register_local_xfer_handler( register another local_xfer_handler using remote block len to ensure data copy correctness. """ - assert self.kv_topo is not None - block_size_ratio = self.block_size // block_size blocks_data = [] for i, base_addr in enumerate(self.seen_base_addresses): @@ -1603,7 +1573,6 @@ def add_remote_agent( # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| - assert self.kv_topo is not None block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id) if engine_id not in self.dst_num_blocks: @@ -1731,10 +1700,7 @@ def _validate_remote_agent_handshake( """ remote_engine_id = nixl_agent_meta.engine_id - assert ( - self._tp_size[remote_engine_id] == remote_tp_size - and self.kv_topo is not None - ) + assert self._tp_size[remote_engine_id] == remote_tp_size tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( @@ -1871,7 +1837,6 @@ def post_process_device_kv_on_receive( if len(self.device_kv_caches) == 0: return assert block_size_ratio >= 1, "Only nP < nD supported currently." - assert self.kv_topo is not None if self.enable_permute_local_kv and block_size_ratio > 1: logger.debug( "Post-processing device kv cache on receive by converting " @@ -1891,7 +1856,7 @@ def post_process_device_kv_on_receive( block_size_ratio, ) - split_k_and_v = self.kv_topo.split_k_and_v + split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first) for block_ids in block_ids_list: indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long) @@ -1916,7 +1881,6 @@ def get_finished(self) -> tuple[set[str], set[str]]: The scheduler process (via the MultiprocExecutor) will use this output to track which workers are done. """ - assert self.kv_topo is not None done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) @@ -1986,7 +1950,6 @@ def _get_new_notifs(self) -> set[str]: are reading from the same producer (heterogeneous TP scenario), wait for all consumers to be done pulling. """ - assert self.kv_topo is not None notified_req_ids: set[str] = set() for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: @@ -2146,7 +2109,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): - assert meta.remote is not None and self.kv_topo is not None + assert meta.remote is not None remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( meta.remote.engine_id ) @@ -2215,7 +2178,10 @@ def _read_blocks( local_xfer_side_handle: int, remote_xfer_side_handle: int, ): - assert self.kv_topo is not None + """ + Post a READ point-to-point xfer request from a single local worker to + a single remote worker. + """ block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) if block_size_ratio > 1: local_block_ids = self.get_mapped_blocks( @@ -2448,7 +2414,6 @@ def get_backend_aware_kv_block_len(self, layer_idx: int) -> int: For FlashInfer, this is half the length of the whole block, as K and V share the same region. """ - assert self.kv_topo is not None if self.kv_topo.is_kv_layout_blocks_first: # For indexing only half (either just the K or V part). block_len = self.block_len_per_layer[layer_idx] // 2