From 4225f1248911fcd8a9e916abee9692ea1ee1d428 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 8 Oct 2025 10:29:47 +0000 Subject: [PATCH 1/4] init Signed-off-by: NickLucche --- .../tp_config_sweep_accuracy_test.sh | 3 + .../kv_connector/unit/test_nixl_connector.py | 132 +++++- .../kv_transfer/kv_connector/utils.py | 72 ++- .../kv_connector/v1/nixl_connector.py | 429 +++++++++++------- 4 files changed, 436 insertions(+), 200 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh index 9308c81da063..422ea851c6d0 100755 --- a/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh @@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh" configs=( "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2" "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2" + "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1" "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" + "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) + "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1) ) run_tests() { diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 66804fa671c7..24daf27bb52e 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -407,22 +407,43 @@ def 
_nixl_handshake( assert expected_engine_id == self.REMOTE_ENGINE_ID - remote_agent_name = self.add_remote_agent( - NixlAgentMetadata( - engine_id=self.REMOTE_ENGINE_ID, - agent_metadata=FakeNixlWrapper.AGENT_METADATA, - kv_caches_base_addr=[0], - device_id=0, - num_blocks=1, - block_lens=self.block_len_per_layer, - # `self.kv_cache_layout` is only forced to HND when vllm engine - # is started. We mock HND here. - kv_cache_layout="HND", - block_size=self.block_size, - ), - remote_tp_size=remote_tp_size, - ) - return {0: remote_agent_name} + # Adjust remote block length metadata to satisfy heterogeneous TP + # invariants enforced during handshake validation. + remote_block_lens = list(self.block_len_per_layer) + tp_ratio = self.kv_topo.tp_ratio(remote_tp_size) + if remote_tp_size > self.world_size: + # P TP > D TP case, block_len of remote is smaller + remote_block_lens = [ + block_len // (-tp_ratio) for block_len in remote_block_lens + ] + elif remote_tp_size < self.world_size: + remote_block_lens = [ + block_len * tp_ratio for block_len in remote_block_lens + ] + + # When remote tp_size > local tp_size, handshake with multiple + # remote ranks. + num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio + remote_agents: dict[int, str] = {} + for remote_tp_rank in range(num_hanshakes): + remote_agent_name = self.add_remote_agent( + NixlAgentMetadata( + engine_id=self.REMOTE_ENGINE_ID, + agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0], + device_id=0, + num_blocks=1, + block_lens=remote_block_lens, + # `self.kv_cache_layout` is only forced to HND when vllm engine + # is started. We mock HND here. 
+ kv_cache_layout="HND", + block_size=self.block_size, + ), + remote_tp_rank=remote_tp_rank, + remote_tp_size=remote_tp_size, + ) + remote_agents[remote_tp_rank] = remote_agent_name + return remote_agents class TestNixlHandshake: @@ -453,7 +474,13 @@ def test_multi_xfer_one_engine( vllm_config, connector.engine_id, hand_shake_latency=0 ) assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) - connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3) + worker = connector.connector_worker + worker.nixl_wrapper.set_cycles_before_xfer_done(3) + # simulate handshake + worker.dst_xfer_side_handles = { + FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1} + } + worker.kv_cache_layout = "HND" num_xfers = 4 while True: # For the same request_id, initiate multiple xfers across different @@ -567,6 +594,71 @@ def test_async_load_kv( return raise TimeoutError("Took too long to complete async handshake.") + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + @pytest.mark.parametrize("local_tp_size", [1, 2]) + def test_prefill_tp_size_greater_than_decode_tp_size( + self, local_tp_size: int, dist_init + ): + """ + Verify remote TP > local TP handshake succeeds with different + remote configurations. 
+ """ + + vllm_config = create_vllm_config() + local_tp_size = 1 + vllm_config.parallel_config.tensor_parallel_size = local_tp_size + + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + + # Minimal local registration params used by add_remote_agent + worker.slot_size_per_layer = [4096] + worker.block_len_per_layer = [4096 * worker.block_size] + worker.num_blocks = 1 + worker.dst_num_blocks[worker.engine_id] = worker.num_blocks + worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)] + + def check_handshake(remote_tp_size: int): + tp_ratio = remote_tp_size // local_tp_size + assert set(remote_agents.keys()) == set(range(tp_ratio)) + + remote_engine_id = worker.REMOTE_ENGINE_ID + assert worker._tp_size[remote_engine_id] == remote_tp_size + assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id) + # ensure src_xfer_side_chunked_handles is populated with tpratio chunks + assert -tp_ratio in worker.src_xfer_side_chunked_handles + assert len(worker.src_xfer_side_chunked_handles[-tp_ratio]) == tp_ratio + assert remote_engine_id in worker.dst_xfer_side_handles + assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set( + range(tp_ratio) + ) + + remote_agents = worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=2, + expected_engine_id=worker.REMOTE_ENGINE_ID, + ) + check_handshake(2) + + # NOTE flexiblity: a second remote with higher number of ranks is + # discovered. This is not a scenario we actively support right now, but + # the connector allows it. 
+ worker.REMOTE_ENGINE_ID = "remote_engine_2" + remote_agents = worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=6, + expected_engine_id=worker.REMOTE_ENGINE_ID, + ) + check_handshake(6) + @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper, @@ -672,7 +764,6 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): with pytest.raises(RuntimeError): # mismatched layout is expected to fail worker.add_remote_agent(meta, remote_tp_size=2) - with pytest.raises(AssertionError): worker.add_remote_agent(meta, remote_tp_size=1) @patch( @@ -1360,7 +1451,8 @@ def test_shutdown_cleans_up_resources(dist_init): ): worker._recving_transfers = {"req1": [123]} worker.src_xfer_side_handle = 456 - worker.dst_xfer_side_handles = {"engine1": 789} + worker.src_xfer_side_chunked_handles = {-2: [456]} + worker.dst_xfer_side_handles = {"engine1": {0: 789}} worker._remote_agents = {"engine1": {0: "agent1"}} worker._registered_descs = ["desc1", "desc2"] @@ -1381,7 +1473,7 @@ def test_shutdown_cleans_up_resources(dist_init): mock_listener.join.assert_called_once() mock_rel_xfer.assert_called_once_with(123) - assert mock_rel_dlist.call_count == 2 + assert mock_rel_dlist.call_count == 3 mock_rel_dlist.assert_any_call(456) # src handle mock_rel_dlist.assert_any_call(789) # dst handle mock_rem_agent.assert_called_once_with("agent1") diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 117d159e25e7..56b68c8f8d03 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -21,6 +21,8 @@ logger = init_logger(__name__) +EngineId = str + def get_kv_connector_cache_layout(): # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is @@ -207,14 +209,13 @@ class TpKVTopology: Helper class for tensor parallel and KV topology information for mapping between local and 
remote TP workers. """ - tp_rank: int - remote_tp_size: dict[str, int] + remote_tp_size: dict[EngineId, int] is_mla: bool total_num_kv_heads: int attn_backend: type[AttentionBackend] - engine_id: str - remote_block_size: dict[str, int] + engine_id: EngineId + remote_block_size: dict[EngineId, int] def __post_init__(self): # Figure out whether the first dimension of the cache is K/V @@ -238,7 +239,9 @@ def is_kv_layout_blocks_first(self) -> bool: @property def split_k_and_v(self) -> bool: # Whether to register regions for K and V separately (when present). - return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first) + return not ( + self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first + ) @property def tp_size(self) -> int: @@ -256,13 +259,23 @@ def tp_ratio( Calculate the tensor parallel ratio between local and remote TP. We can think of it as the number of local TP workers-per-remote TP workers. Local workers will read from the same remote TP worker in - groups of size `tp_ratio`. + groups of size `tp_ratio`. If remote tp_size > local tp_size, the + ratio is flipped (remote_size/local_size) and the returned value is + negative. """ - assert self.tp_size % remote_tp_size == 0, ( - f"Local tensor parallel size {self.tp_size} is not divisible " - f"by remote tensor parallel size {remote_tp_size}." - ) - return self.tp_size // remote_tp_size + if self.tp_size >= remote_tp_size: + assert self.tp_size % remote_tp_size == 0, ( + f"Local tensor parallel size {self.tp_size} is not divisible " + f"by remote tensor parallel size {remote_tp_size}." + ) + return self.tp_size // remote_tp_size + else: + assert remote_tp_size % self.tp_size == 0, ( + f"Remote tensor parallel size {remote_tp_size} is not divisible " + f"by local tensor parallel size {self.tp_size}." 
+ ) + # P TP > D TP case, return the ratio as negative + return -remote_tp_size // self.tp_size def block_size_ratio( self, @@ -279,19 +292,19 @@ def block_size_ratio( def tp_ratio_from_engine_id( self, - remote_engine_id: str, + remote_engine_id: EngineId, ) -> int: remote_tp_size = self.remote_tp_size[remote_engine_id] return self.tp_ratio(remote_tp_size) def block_size_ratio_from_engine_id( self, - remote_engine_id: str, + remote_engine_id: EngineId, ) -> float: remote_block_size = self.remote_block_size[remote_engine_id] return self.block_size_ratio(remote_block_size) - def is_kv_replicated(self, engine_id: str) -> bool: + def is_kv_replicated(self, engine_id: EngineId) -> bool: """ Whether the KV cache is replicated across TP workers due to the number of TP workers being greater than the number of KV heads. @@ -299,24 +312,35 @@ def is_kv_replicated(self, engine_id: str) -> bool: tp_size = self.remote_tp_size[engine_id] return tp_size // self.total_num_kv_heads >= 1 - def replicates_kv_cache(self, remote_engine_id: str) -> bool: + def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: # MLA is always replicated as the hidden dim can't be split. return self.is_mla or self.is_kv_replicated(remote_engine_id) - def get_target_remote_rank( + def get_target_remote_ranks( self, remote_tp_size: int, - ) -> int: + ) -> list[int]: """ Get the remote TP rank (on P) that the current local TP rank - (on D) will read from. + (on D) will read from. When remote tp_size > local tp_size, we + read from multiple remote ranks. """ tp_ratio = self.tp_ratio(remote_tp_size) - return self.tp_rank // tp_ratio - - def get_target_remote_rank_from_engine_id( + if tp_ratio > 0: + return [self.tp_rank // tp_ratio] + else: + # P TP > D TP case, D reads from |tp_ratio| remote workers. + tp_ratio = -tp_ratio + if self.is_mla: + # When cache is replicated on remote, we only need to read + # from one remote (they all have the same cache). 
Fan out + # transfers to avoid bottlenecks on single remote. + return [self.tp_rank * tp_ratio] + return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)] + + def get_target_remote_ranks_from_engine_id( self, - remote_engine_id: str, - ) -> int: + remote_engine_id: EngineId, + ) -> list[int]: remote_tp_size = self.remote_tp_size[remote_engine_id] - return self.get_target_remote_rank(remote_tp_size) + return self.get_target_remote_ranks(remote_tp_size) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index fb4b8ac391af..62dee08bbe78 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology +from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology, EngineId from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp, KVConnectorBase_V1, @@ -56,7 +56,6 @@ from vllm.v1.request import Request TransferHandle = int -EngineId = str ReqId = str # @@ -873,9 +872,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.copy_blocks: CopyBlocksOp | None = None # Map of engine_id -> kv_caches_base_addr. For TP case, each local - # rank will still only pull from a single remote TP worker. - self.kv_caches_base_addr: dict[EngineId, list[int]] = {} self.device_id: int = 0 + # Current rank may pull from multiple remote TP workers. + self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = ( + defaultdict(dict) + ) # Number of NIXL regions. 
Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -883,10 +884,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.num_layers = 0 # nixl_prepped_dlist_handle. - self.src_xfer_side_handle: int = 0 - self.src_xfer_side_handles: dict[int, int] = {} + self.src_xfer_handles_by_block_size: dict[int, int] = {} + # Populated dynamically during handshake based on remote configuration. + # Keep track of regions at different tp_ratio values. tp_ratio->handles + # FIXME change to remote tp_size + self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {} # Map of engine_id -> nixl_prepped_dlist_handle (int)]. - self.dst_xfer_side_handles: dict[EngineId, int] = {} + self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict( + dict + ) # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. @@ -980,100 +986,108 @@ def _nixl_handshake( start_time = time.perf_counter() - # NOTE(rob): we need each rank to have a unique port. This is - # a hack to keep us moving. We will switch when moving to etcd - # or where we have a single ZMQ socket in the scheduler. - - # Handshake only with the remote TP rank that current local rank will - # pull from. With homogeneous TP it happens to be the same rank_i. - p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size) + # When target instance TP > local TP, we need to perform multiple + # handshakes. Do it in a single background job for simplicity. + # Regardless, only handshake with the remote TP rank(s) that current + # local rank will read from. Note that with homogeneous TP, + # this happens to be the same single rank_i. + p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size) + remote_rank_to_agent_name = {} path = make_zmq_path("tcp", host, port) - logger.debug( - "Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank - ) - # Send query for the request. 
with zmq_ctx(zmq.REQ, path) as sock: - msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank)) - # Set receive timeout to 5 seconds to avoid hanging on dead server - sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds - sock.send(msg) - handshake_bytes = sock.recv() - - # Decode handshake payload to get compatibility hash - handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload) - try: - handshake_payload = handshake_decoder.decode(handshake_bytes) - except (msgspec.DecodeError, msgspec.ValidationError) as e: - raise RuntimeError( - f"Failed to decode NixlHandshakePayload. This likely indicates " - f"an incompatibility between connector version. Error: {e}" - ) from e + for remote_rank in p_remote_ranks: + logger.debug( + "Querying metadata on path: %s at remote tp rank %s", path, remote_rank + ) - got_metadata_time = time.perf_counter() - logger.debug( - "NIXL handshake: get metadata took: %s", got_metadata_time - start_time - ) + # Send query for the request. + msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank)) + # Set receive timeout to 5 seconds to avoid hanging on dead server + sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds + sock.send(msg) + handshake_bytes = sock.recv() - # Check compatibility hash BEFORE decoding agent metadata - if ( - self.enforce_compat_hash - and handshake_payload.compatibility_hash != self.compat_hash - ): - raise RuntimeError( - f"NIXL compatibility hash mismatch. " - f"Local: {self.compat_hash}, " - f"Remote: {handshake_payload.compatibility_hash}. " - f"Prefill and decode instances have incompatible configurations. " - f"This may be due to: different vLLM versions, models, dtypes, " - f"KV cache layouts, attention backends, etc. " - f"Both instances must use identical configurations." 
- f"Disable this check using " - f'--kv-transfer-config \'{{"kv_connector_extra_config": ' - f'{{"enforce_handshake_compat": false}}}}\'' + # Decode handshake payload to get compatibility hash + handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload) + try: + handshake_payload = handshake_decoder.decode(handshake_bytes) + except (msgspec.DecodeError, msgspec.ValidationError) as e: + raise RuntimeError( + f"Failed to decode NixlHandshakePayload. This likely indicates " + f"an incompatibility between connector version. Error: {e}" + ) from e + + got_metadata_time = time.perf_counter() + logger.debug( + "Querying metadata on path: %s at remote rank %s", path, remote_rank ) - logger.info( - "NIXL compatibility check passed (hash: %s)", - handshake_payload.compatibility_hash, - ) + # Check compatibility hash BEFORE decoding agent metadata + if ( + self.enforce_compat_hash + and handshake_payload.compatibility_hash != self.compat_hash + ): + raise RuntimeError( + f"NIXL compatibility hash mismatch. " + f"Local: {self.compat_hash}, " + f"Remote: {handshake_payload.compatibility_hash}. " + f"Prefill and decode instances have incompatible configurations. " + f"This may be due to: different vLLM versions, models, dtypes, " + f"KV cache layouts, attention backends, etc. " + f"Both instances must use identical configurations." + f"Disable this check using " + f'--kv-transfer-config \'{{"kv_connector_extra_config": ' + f'{{"enforce_handshake_compat": false}}}}\'' + ) - # Decode agent metadata - metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - try: - metadata = metadata_decoder.decode( - handshake_payload.agent_metadata_bytes + logger.info( + "NIXL compatibility check passed (hash: %s)", + handshake_payload.compatibility_hash, ) - except (msgspec.DecodeError, msgspec.ValidationError) as e: - # This should not happen if hash matched - raise RuntimeError( - f"Failed to decode NixlAgentMetadata. Error: {e}" - ) from e - # Ensure engine id matches. 
- if metadata.engine_id != expected_engine_id: - raise RuntimeError( - f"Remote NIXL agent engine ID mismatch. " - f"Expected {expected_engine_id}," - f"received {metadata.engine_id}." + # Decode agent metadata + metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + try: + metadata = metadata_decoder.decode( + handshake_payload.agent_metadata_bytes + ) + except (msgspec.DecodeError, msgspec.ValidationError) as e: + # This should not happen if hash matched + raise RuntimeError( + f"Failed to decode NixlAgentMetadata. Error: {e}" + ) from e + + # Ensure engine id matches. + if metadata.engine_id != expected_engine_id: + raise RuntimeError( + f"Remote NIXL agent engine ID mismatch. " + f"Expected {expected_engine_id}," + f"received {metadata.engine_id}." + ) + # FIXME move to _validate + assert metadata.block_size <= self.block_size, ( + "nP > nD is not supported yet." ) + # Ensure engine id matches. + if metadata.engine_id != expected_engine_id: + raise RuntimeError( + f"Remote NIXL agent engine ID mismatch. " + f"Expected {expected_engine_id}," + f"received {metadata.engine_id}." + ) + setup_agent_time = time.perf_counter() - # Register Remote agent. - assert metadata.block_size <= self.block_size, ( - "nP > nD is not supported yet." - ) - remote_agent_name = self.add_remote_agent( - metadata, p_remote_rank, remote_tp_size - ) - - setup_agent_time = time.perf_counter() - logger.debug( - "NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time, - ) - - # Remote rank -> agent name. - return {p_remote_rank: remote_agent_name} + # Register Remote agent. 
+ remote_agent_name = self.add_remote_agent( + metadata, remote_rank, remote_tp_size + ) + logger.debug( + "NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time, + ) + remote_rank_to_agent_name[remote_rank] = remote_agent_name + return remote_rank_to_agent_name def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: """ @@ -1283,7 +1297,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert len(self.block_len_per_layer) == len(seen_base_addresses) assert self.num_blocks != 0 - self.kv_caches_base_addr[self.engine_id] = seen_base_addresses + self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) @@ -1310,9 +1324,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Register local/src descr for NIXL xfer. self.seen_base_addresses = seen_base_addresses - self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size) - - self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle + self.src_xfer_handles_by_block_size[self.block_size] = self.register_local_xfer_handler(self.block_size) # TODO(mgoin): Hybrid memory allocator is currently disabled for # models with local attention (Llama 4). Can remove this once enabled. @@ -1340,8 +1352,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): agent_metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), - kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], device_id=self.device_id, + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank], num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, kv_cache_layout=self.kv_cache_layout @@ -1421,10 +1433,12 @@ def add_remote_agent( In particular, handle both homogeneous and heterogeneous TP. The former requires local rank_i to read from remote rank_i. 
- The latter, assuming D.world_size > P.world_size, requires that two or - more local TP worker share the xfer from a single TP worker. + The latter, in the case of D.world_size < P.world_size, requires that a + local (D) TP worker reads from multiple remote (P) TP workers. + Conversely, assuming D.world_size > P.world_size, two or more local TP + workers will read from a single remote TP worker. - Here's an example (non-MLA case): + Here's an example for the last case described above (non-MLA): rank_offset p_remote_tp_rank (kv split no) @@ -1474,9 +1488,6 @@ def add_remote_agent( nixl_agent_meta.agent_metadata ) - # Handle tp_size>num_kv_heads: replicate KV cache. - replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id) - # Create dst descs and xfer side handles. TP workers have same #blocks # so we only register once per engine_id. # Example: @@ -1490,14 +1501,53 @@ def add_remote_agent( self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks # Keep track of remote agent kv caches base addresses. - self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr - + self.kv_caches_base_addr[engine_id][remote_tp_rank] = ( + nixl_agent_meta.kv_caches_base_addr + ) self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) - # Number of D TP workers reading from a single P TP worker. This is - # 1 when P and D `--tensor-parallel-size` match. + # This is 1 when P and D `--tensor-parallel-size` match. Otherwise, + # this is the ratio between the two sizes. tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id) + # Handle tp_size>num_kv_heads: replicate KV cache. + indexes_into_remote = ( + not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0 + ) + + logger.debug( + "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s", + engine_id, + remote_tp_rank, + tp_ratio, + ) + + # FIXME refactor into self.register_local_xfer_handler(self.block_size) + ### (Optional) Register local agent memory regions. 
MLA is not split. + if ( + tp_ratio < 0 + and not self.use_mla + and tp_ratio not in self.src_xfer_handles_by_tp_ratio + ): + # Remote tp_size > local tp_size: read from multiple remote ranks. + # Logically "split" own regions into |tp_ratio| chunks. Mind that + # we only do this once per remote tp_size (replica-friendly). + self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] + for i in range(-tp_ratio): + blocks_data = [] + for memory_region in self.src_blocks_data: + addr, local_block_len, own_tp_rank = memory_region + # Computing block len layer by layer allows for different + # block sizes to be used. + remote_block_len = local_block_len // (-tp_ratio) + addr = addr + i * remote_block_len + blocks_data.append((addr, remote_block_len, own_tp_rank)) + descs = self.nixl_wrapper.get_xfer_descs( + blocks_data, self.nixl_memory_type + ) + handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) + ### Register remote agent memory regions blocks_data = [] # With homogeneous TP, D pulls the whole kv cache from corresponding @@ -1507,15 +1557,18 @@ def add_remote_agent( # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) - remote_kv_block_len = kv_block_len // block_size_ratio + # Read our whole local region size from remote. 
+ local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) + remote_kv_block_len = local_block_len // block_size_ratio if block_size_ratio > 1: # using remote kv_block_len as transfer unit - kv_block_len = remote_kv_block_len + local_block_len = remote_kv_block_len + + if tp_ratio < 0 and not self.use_mla: + # Remote tp is bigger: read a chunk of local region from remote + local_block_len = local_block_len // (-tp_ratio) rank_offset = ( - self.tp_rank % tp_ratio * remote_kv_block_len - if not replicates_kv_cache - else 0 + self.tp_rank % tp_ratio * remote_kv_block_len if indexes_into_remote else 0 ) for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] @@ -1524,7 +1577,7 @@ def add_remote_agent( # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id)) + blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id)) if self.kv_topo.is_kv_layout_blocks_first: # With FlashInfer index V separately to allow head splitting. @@ -1533,7 +1586,7 @@ def add_remote_agent( addr = base_addr + block_offset + rank_offset v_addr = addr + nixl_agent_meta.block_lens[i] // 2 blocks_data.append( - (v_addr, kv_block_len, nixl_agent_meta.device_id) + (v_addr, local_block_len, nixl_agent_meta.device_id) ) logger.debug( @@ -1546,14 +1599,14 @@ def add_remote_agent( # Register with NIXL. 
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) - self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( - remote_agent_name, descs + self.dst_xfer_side_handles[engine_id][remote_tp_rank] = ( + self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs) ) if block_size_ratio > 1: # when prefill with smaller block_size, we need to init a # new handler with same block_len to match - self.src_xfer_side_handles[nixl_agent_meta.block_size] = ( + self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = ( self.register_local_xfer_handler(nixl_agent_meta.block_size) ) @@ -1574,7 +1627,9 @@ def _validate_remote_agent_handshake( block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( remote_engine_id ) - assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" + # Num kv_heads > tp_size and P TP > D TP case, not supported + assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)) + assert not self._use_pallas or tp_ratio == 1, ( "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." ) @@ -1610,23 +1665,51 @@ def _validate_remote_agent_handshake( == nixl_agent_meta.block_lens[i] ), "KV cache sizes must match between P and D when replicated" else: + if tp_ratio != 1 and self.device_type == "xpu": + # XPU uses NHD, hence it does not support splitting on H + raise ValueError("Heterogeneous TP is not supported on XPU") # When MLA is not used, this is a list of the same block length for block_len in nixl_agent_meta.block_lens: assert block_len == remote_block_len, ( "All remote layers must have the same block size" ) - assert ( - remote_block_len - == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio - ), ( - "Remote P worker KV layer cache must be of shape [2, N, " - "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." 
- ) + if tp_ratio > 0: + # Remote NHD/H'D*tp_ratio=N -page_size- + remote_block_size = remote_block_len // ( + self.slot_size_per_layer[0] * tp_ratio + ) + # Remote tp is smaller: remote block_len size is bigger + assert remote_block_len == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio, ( + "Remote P worker KV layer cache must be of shape [2, N, " + "local_kv_heads*tp_ratio, page_size, head_dim] and same dtype." + ) # noqa: E501 + else: + # TODO (NickLucche) to support + assert block_size_ratio == 1, "Different local/remote block sizes are not supported when P TP > D TP." + # Remote NHD/(H'D/tp_ratio)=N -page_size- + remote_block_size = remote_block_len // ( + self.slot_size_per_layer[0] // (-tp_ratio) + ) + # Remote tp is bigger: remote block_len size is smaller + assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), ( + "Remote P worker KV layer cache must be of shape [2, N, " + "local_kv_heads/tp_ratio, page_size, head_dim] and same dtype." + ) # noqa: E501 + + if self._use_flashinfer: + # With flashinfer, KV are sent in the same message. + remote_block_size //= 2 + + # We may allow it in the future with logical kvcache manager block_size + assert self.block_size == remote_block_size, ( + "Remote P worker with different page/block size is not supported " + f"{self.block_size=}, {remote_block_size=}" + ) - # TP workers have same #blocks. + # TP workers that handshake with same remote have same #blocks. assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks - + # Same number of regions/~layers. 
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): @@ -1872,7 +1955,7 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: """ done_req_ids: set[str] = set() for req_id, handles in list(transfers.items()): - in_progress = False + in_progress = [] for handle in handles: try: xfer_state = self.nixl_wrapper.check_xfer_state(handle) @@ -1882,7 +1965,7 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: self.xfer_stats.record_transfer(res) self.nixl_wrapper.release_xfer_handle(handle) elif xfer_state == "PROC": - in_progress = True + in_progress.append(handle) continue else: logger.error( @@ -1892,7 +1975,6 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: xfer_state, ) self._handle_failed_transfer(req_id, handle) - in_progress = False except Exception: logger.exception( "NIXL transfer exception for request %s. " @@ -1900,11 +1982,13 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: req_id, ) self._handle_failed_transfer(req_id, handle) - in_progress = False if not in_progress: + # Only report request as completed when all transfers are done. 
done_req_ids.add(req_id) del transfers[req_id] + else: + transfers[req_id] = in_progress return done_req_ids def _handle_failed_transfer(self, req_id: str, handle: int): @@ -1982,18 +2066,47 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): assert meta.remote is not None - logger.debug( - "Remote agent %s available, calling _read_blocks for req %s", - meta.remote.engine_id, - req_id, - ) - self._read_blocks( - request_id=req_id, - dst_engine_id=meta.remote.engine_id, - remote_request_id=meta.remote.request_id, - local_block_ids=meta.local_physical_block_ids, - remote_block_ids=meta.remote.block_ids, + remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( + meta.remote.engine_id ) + tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id) + # D may have to perform multiple reads from different remote ranks. + for i, remote_rank in enumerate(remote_ranks): + remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id] + logger.debug( + "Remote agent %s available, calling _read_blocks" + " on remote rank %s with remote block size %s for req %s", + meta.remote.engine_id, + remote_rank, + remote_block_size, + req_id, + ) + # Get side handles. + if tp_ratio < 0 and not self.use_mla: + assert remote_block_size == self.block_size + # Remote tp_size > local tp_size: we must perform multiple + # reads. Get the memory chunk onto which we will write to. + local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i] + else: + # Single read from remote, we write to the whole memory region. + # Handle the case in which remote block size is different from local block size. + local_xfer_side_handle = self.src_xfer_handles_by_block_size[remote_block_size] + + # Destination handle: remote_engine_id -> remote_rank -> handle. 
+ remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][ + remote_rank + ] + # FIXME check local physical block ids changes are respected + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote.engine_id, + remote_request_id=meta.remote.request_id, + local_block_ids=meta.local_physical_block_ids, + remote_block_ids=meta.remote.block_ids, + remote_rank=remote_rank, + local_xfer_side_handle=local_xfer_side_handle, + remote_xfer_side_handle=remote_xfer_side_handle, + ) def _read_blocks( self, @@ -2002,7 +2115,14 @@ def _read_blocks( dst_engine_id: str, request_id: str, remote_request_id: str, + remote_rank: int, + local_xfer_side_handle: int, + remote_xfer_side_handle: int, ): + """ + Post a READ point-to-point xfer request from a single local worker to + a single remote worker. + """ block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) if block_size_ratio > 1: local_block_ids = self.get_mapped_blocks( @@ -2033,16 +2153,14 @@ def _read_blocks( # Number of D TP workers that will read from dst P. Propagate tp_ratio # on notification so that dst worker can wait before freeing blocks. - tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id) + # Cap to 1 when P TP > D TP: only a single rank will read from remote. + tp_ratio = max(1, self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)) notif_id = f"{remote_request_id}:{tp_ratio}".encode() # Full prefix cache hit: do not need to read remote blocks, # just notify P worker that we have the blocks we need. num_local_blocks = len(local_block_ids) if num_local_blocks == 0: - remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id( - dst_engine_id - ) agent_name = self._remote_agents[dst_engine_id][remote_rank] try: self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) @@ -2062,13 +2180,6 @@ def _read_blocks( if num_local_blocks < num_remote_blocks: remote_block_ids = remote_block_ids[-num_local_blocks:] - # Get side handles. 
- remote_block_size = self.kv_topo.remote_block_size[dst_engine_id] - local_xfer_side_handle = self.src_xfer_side_handles.get( - remote_block_size, self.src_xfer_side_handle - ) - remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] - # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. @@ -2276,11 +2387,17 @@ def shutdown(self): for handle in handles: self.nixl_wrapper.release_xfer_handle(handle) self._recving_transfers.clear() - if self.src_xfer_side_handle: - self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle) - self.src_xfer_side_handle = 0 - for dst_xfer_side_handle in self.dst_xfer_side_handles.values(): - self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) + for handles in self.src_xfer_handles_by_block_size.values(): + for handle in handles: + self.nixl_wrapper.release_dlist_handle(handle) + self.src_xfer_handles_by_block_size.clear() + for handles in self.src_xfer_handles_by_tp_ratio.values(): + for handle in handles: + self.nixl_wrapper.release_dlist_handle(handle) + self.src_xfer_handles_by_tp_ratio.clear() + for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): + for dst_xfer_side_handle in dst_xfer_side_handles.values(): + self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) self.dst_xfer_side_handles.clear() for remote_agents in self._remote_agents.values(): for agent_name in remote_agents.values(): From b435a394a38a4b728d1b93976162b8365ae0449d Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 23 Oct 2025 15:59:29 +0000 Subject: [PATCH 2/4] address block freeing for MLA more MLA tests Signed-off-by: NickLucche --- tests/out_prefill | 1 + .../kv_connector/unit/test_nixl_connector.py | 100 ++++++++++++++++++ .../kv_transfer/kv_connector/utils.py | 5 - .../kv_connector/v1/nixl_connector.py | 38 +++++-- 4 files changed, 133 insertions(+), 11 
deletions(-) create mode 100644 tests/out_prefill diff --git a/tests/out_prefill b/tests/out_prefill new file mode 100644 index 000000000000..929f36c08688 --- /dev/null +++ b/tests/out_prefill @@ -0,0 +1 @@ +error: No justfile found diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 24daf27bb52e..1ae7995e5160 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -659,6 +659,106 @@ def check_handshake(remote_tp_size: int): ) check_handshake(6) + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + @pytest.mark.parametrize("local_tp_size", [1, 2]) + def test_prefill_tp_size_greater_than_decode_tp_size_mla( + self, local_tp_size: int, dist_init + ): + """ + Verify remote TP > local TP handshake succeeds with different + remote configurations for an MLA model. + """ + vllm_config = create_vllm_config() + d_tp_size = 1 + p_tp_size = 2 + + # Build two separate connectors/workers to emulate P TP=2 ranks. + conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER) + conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER) + conn_p0.connector_worker = FakeNixlConnectorWorker( + vllm_config, conn_p0.engine_id, hand_shake_latency=0 + ) + conn_p1.connector_worker = FakeNixlConnectorWorker( + vllm_config, conn_p1.engine_id, hand_shake_latency=0 + ) + + # Force P world size to 2 for both workers and emulate distinct tp_ranks. + # Also enable MLA path so that expected_finished_count is updated. + for rank, worker in enumerate( + (conn_p0.connector_worker, conn_p1.connector_worker) + ): + worker.world_size = p_tp_size + worker.kv_topo.tp_size = p_tp_size + worker.tp_rank = rank + worker.use_mla = True + + req_id = "req-ep-dp2-p0" + now = time.perf_counter() + # Register a request on P that is waiting for consumers to read + # (both workers track it). 
+ conn_p0.connector_worker._reqs_to_send[req_id] = now + 10.0 + conn_p0.connector_worker._reqs_to_process.add(req_id) + conn_p1.connector_worker._reqs_to_send[req_id] = now + 10.0 + conn_p1.connector_worker._reqs_to_process.add(req_id) + + # Simulate a read notification coming from D with (tp=1, dp=2). + notif = f"{req_id}:{d_tp_size}".encode() + # D0-0->P0 notif + conn_p0.connector_worker.nixl_wrapper.get_new_notifs = lambda: { + "agent": [notif] + } # type: ignore[method-assign] + conn_p1.connector_worker.nixl_wrapper.get_new_notifs = lambda: { + "agent": [notif] + } # type: ignore[method-assign] + + # Trigger notification processing via get_finished(). + done_sending0, _ = conn_p0.get_finished(finished_req_ids=set()) + done_sending1, _ = conn_p1.get_finished(finished_req_ids=set()) + assert req_id in done_sending0 and req_id in done_sending1 + + # E2E aggregation: ensure the aggregated output marks the request + # as finished using the connector's expected_finished_count. + from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput + + aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2) + + out0 = ModelRunnerOutput( + req_ids=[req_id], + req_id_to_index={req_id: 0}, + sampled_token_ids=[[0]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[None], + kv_connector_output=KVConnectorOutput( + finished_sending=done_sending0, + finished_recving=None, + ), + ) + out1 = ModelRunnerOutput( + req_ids=[req_id], + req_id_to_index={req_id: 0}, + sampled_token_ids=[[0]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[None], + kv_connector_output=KVConnectorOutput( + finished_sending=done_sending1, + finished_recving=None, + ), + ) + aggregated = aggregator.aggregate([out0, out1], output_rank=0) + assert aggregated.kv_connector_output is not None + assert aggregated.kv_connector_output.finished_sending == {req_id} + + # Producers cleaned up state for the finished request. 
+ assert req_id not in conn_p0.connector_worker._reqs_to_send + assert req_id not in conn_p0.connector_worker._reqs_to_process + assert req_id not in conn_p1.connector_worker._reqs_to_send + assert req_id not in conn_p1.connector_worker._reqs_to_process + @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper, diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 56b68c8f8d03..55dc98c13177 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -331,11 +331,6 @@ def get_target_remote_ranks( else: # P TP > D TP case, D reads from |tp_ratio| remote workers. tp_ratio = -tp_ratio - if self.is_mla: - # When cache is replicated on remote, we only need to read - # from one remote (they all have the same cache). Fan out - # transfers to avoid bottlenecks on single remote. - return [self.tp_rank * tp_ratio] return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)] def get_target_remote_ranks_from_engine_id( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 62dee08bbe78..30e77503a760 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1923,7 +1923,7 @@ def _get_new_notifs(self) -> set[str]: notified_req_ids: set[str] = set() for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: - req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) + req_id, tp_size = notif.decode("utf-8").rsplit(":", 1) if ( req_id not in self._reqs_to_send and req_id not in self._reqs_to_process @@ -1936,9 +1936,22 @@ def _get_new_notifs(self) -> set[str]: ) continue + # NOTE: `tp_ratio` is the opposite when swapping local<>remote + tp_ratio = self.kv_topo.tp_ratio(int(tp_size)) + n_consumers = 
int(tp_size) + + # Number of reads *per producer* to wait for. + # When remote D TP > local P TP we expect `tp_ratio` reads. + consumers_per_producer = ( + -tp_ratio if n_consumers > self.world_size else 1 + ) + self.consumer_notification_counts_by_req[req_id] += 1 # Wait all consumers (D) to be done reading before freeing. - if self.consumer_notification_counts_by_req[req_id] == int(tp_ratio): + if ( + self.consumer_notification_counts_by_req[req_id] + == consumers_per_producer + ): notified_req_ids.add(req_id) del self.consumer_notification_counts_by_req[req_id] self._reqs_to_process.remove(req_id) @@ -2072,6 +2085,11 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id) # D may have to perform multiple reads from different remote ranks. for i, remote_rank in enumerate(remote_ranks): + if self.use_mla and tp_ratio < 0 and i > 0: + # MLA opt: when P TP > D TP, only a single read is executed for + # the first remote rank (cache is duplicated).. + break + remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id] logger.debug( "Remote agent %s available, calling _read_blocks" @@ -2108,6 +2126,16 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_xfer_side_handle=remote_xfer_side_handle, ) + if self.use_mla and tp_ratio < 0: + # ..but we still need to notify the other remote ranks that we + # have the blocks we need so they can update the request state. + notif_id = f"{req_id}:{self.world_size}".encode() + for rank_to_notify, agent in self._remote_agents[ + meta.remote.engine_id + ].items(): + if rank_to_notify != remote_rank: + self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) + def _read_blocks( self, local_block_ids: list[int], @@ -2151,11 +2179,9 @@ def _read_blocks( # saturate IB with heterogeneous TP sizes. We should remove the staging # blocks until we are ready. - # Number of D TP workers that will read from dst P. 
Propagate tp_ratio + # Number of D TP workers that will read from dst P. Propagate info # on notification so that dst worker can wait before freeing blocks. - # Cap to 1 when P TP > D TP: only a single rank will read from remote. - tp_ratio = max(1, self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)) - notif_id = f"{remote_request_id}:{tp_ratio}".encode() + notif_id = f"{remote_request_id}:{self.world_size}".encode() # Full prefix cache hit: do not need to read remote blocks, # just notify P worker that we have the blocks we need. From b74bedf2562cbca08dbbe52b240576005c105a75 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 7 Nov 2025 09:43:45 +0000 Subject: [PATCH 3/4] intel review Signed-off-by: NickLucche --- tests/out_prefill | 1 - .../kv_connector/unit/test_nixl_connector.py | 29 +++++-- .../kv_transfer/kv_connector/utils.py | 4 +- .../kv_connector/v1/nixl_connector.py | 75 ++++++++----------- 4 files changed, 54 insertions(+), 55 deletions(-) delete mode 100644 tests/out_prefill diff --git a/tests/out_prefill b/tests/out_prefill deleted file mode 100644 index 929f36c08688..000000000000 --- a/tests/out_prefill +++ /dev/null @@ -1 +0,0 @@ -error: No justfile found diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 1ae7995e5160..b630f5e99b5c 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -391,6 +391,8 @@ def __init__( super().__init__(*args, **kwargs) self._hand_shake_latency = hand_shake_latency self.kv_cache_layout = kv_cache_layout + # Mock register_kv_caches attribute needed for tests that do not call it. 
+ self.src_xfer_handles_by_block_size = {self.block_size: 1} def _nixl_handshake( self, host: str, port: int, remote_tp_size: int, expected_engine_id: str @@ -431,7 +433,7 @@ def _nixl_handshake( engine_id=self.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], - device_id=0, + device_id=remote_tp_rank, num_blocks=1, block_lens=remote_block_lens, # `self.kv_cache_layout` is only forced to HND when vllm engine @@ -480,6 +482,7 @@ def test_multi_xfer_one_engine( worker.dst_xfer_side_handles = { FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1} } + # worker.src_xfer_handles_by_block_size = {worker.block_size: 1} worker.kv_cache_layout = "HND" num_xfers = 4 while True: @@ -558,6 +561,9 @@ def test_async_load_kv( connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id ) + # worker = connector.connector_worker + # worker.src_xfer_handles_by_block_size = {worker.block_size: 1} + metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( request_id="id", @@ -631,9 +637,9 @@ def check_handshake(remote_tp_size: int): remote_engine_id = worker.REMOTE_ENGINE_ID assert worker._tp_size[remote_engine_id] == remote_tp_size assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id) - # ensure src_xfer_side_chunked_handles is populated with tpratio chunks - assert -tp_ratio in worker.src_xfer_side_chunked_handles - assert len(worker.src_xfer_side_chunked_handles[-tp_ratio]) == tp_ratio + # ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks + assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio + assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio assert remote_engine_id in worker.dst_xfer_side_handles assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set( range(tp_ratio) @@ -691,7 +697,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size_mla( (conn_p0.connector_worker, conn_p1.connector_worker) ): worker.world_size = p_tp_size - 
worker.kv_topo.tp_size = p_tp_size + worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size} worker.tp_rank = rank worker.use_mla = True @@ -777,6 +783,9 @@ def test_concurrent_load_kv( connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id ) + # Register (mocked) local xfer handler + # worker = connector.connector_worker + # worker.src_xfer_handles_by_block_size = {worker.block_size: 1} metadata = NixlConnectorMetadata() total_reqs = 5 for i in range(total_reqs): @@ -1552,6 +1561,10 @@ def test_shutdown_cleans_up_resources(dist_init): worker._recving_transfers = {"req1": [123]} worker.src_xfer_side_handle = 456 worker.src_xfer_side_chunked_handles = {-2: [456]} + # Mock register_kv_cache which registers local handle + # worker.src_xfer_handles_by_block_size = {worker.block_size: 455} + # P TP = 2 * D TP case, we should register 2 local handles + worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]} worker.dst_xfer_side_handles = {"engine1": {0: 789}} worker._remote_agents = {"engine1": {0: "agent1"}} worker._registered_descs = ["desc1", "desc2"] @@ -1573,8 +1586,10 @@ def test_shutdown_cleans_up_resources(dist_init): mock_listener.join.assert_called_once() mock_rel_xfer.assert_called_once_with(123) - assert mock_rel_dlist.call_count == 3 - mock_rel_dlist.assert_any_call(456) # src handle + assert mock_rel_dlist.call_count == 4 + mock_rel_dlist.assert_any_call(455) # src handle (whole region) + mock_rel_dlist.assert_any_call(456) # src handle (1st chunk) + mock_rel_dlist.assert_any_call(457) # src handle (2nd chunk) mock_rel_dlist.assert_any_call(789) # dst handle mock_rem_agent.assert_called_once_with("agent1") assert mock_dereg.call_count == 2 diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 55dc98c13177..c98099907714 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -280,7 
+280,7 @@ def tp_ratio( def block_size_ratio( self, remote_block_size: int, - ) -> float: + ) -> int: """ Calculate the block size ratio between local and remote TP. """ @@ -300,7 +300,7 @@ def tp_ratio_from_engine_id( def block_size_ratio_from_engine_id( self, remote_engine_id: EngineId, - ) -> float: + ) -> int: remote_block_size = self.remote_block_size[remote_engine_id] return self.block_size_ratio(remote_block_size) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 30e77503a760..7593658c623f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -887,9 +887,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.src_xfer_handles_by_block_size: dict[int, int] = {} # Populated dynamically during handshake based on remote configuration. # Keep track of regions at different tp_ratio values. tp_ratio->handles - # FIXME change to remote tp_size self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {} - # Map of engine_id -> nixl_prepped_dlist_handle (int)]. + # Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}. self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict( dict ) @@ -983,9 +982,6 @@ def _nixl_handshake( expected_engine_id: str, ) -> dict[int, str]: """Do a NIXL handshake with a remote instance.""" - - start_time = time.perf_counter() - # When target instance TP > local TP, we need to perform multiple # handshakes. Do it in a single background job for simplicity. # Regardless, only handshake with the remote TP rank(s) that current @@ -1001,6 +997,7 @@ def _nixl_handshake( "Querying metadata on path: %s at remote tp rank %s", path, remote_rank ) + start_time = time.perf_counter() # Send query for the request. 
msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank)) # Set receive timeout to 5 seconds to avoid hanging on dead server @@ -1020,7 +1017,8 @@ def _nixl_handshake( got_metadata_time = time.perf_counter() logger.debug( - "Querying metadata on path: %s at remote rank %s", path, remote_rank + "NIXL handshake: get metadata took: %s", + got_metadata_time - start_time, ) # Check compatibility hash BEFORE decoding agent metadata @@ -1324,7 +1322,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Register local/src descr for NIXL xfer. self.seen_base_addresses = seen_base_addresses - self.src_xfer_handles_by_block_size[self.block_size] = self.register_local_xfer_handler(self.block_size) + self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = ( + self.register_local_xfer_handler(self.block_size) + ) # TODO(mgoin): Hybrid memory allocator is currently disabled for # models with local attention (Llama 4). Can remove this once enabled. @@ -1371,7 +1371,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_local_xfer_handler( self, block_size: int, - ) -> int: + ) -> tuple[int, list[tuple[int, int, int]]]: """ Function used for register local xfer handler with local block_size or Remote block_size. @@ -1419,7 +1419,7 @@ def register_local_xfer_handler( descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. - return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data def add_remote_agent( self, @@ -1522,7 +1522,6 @@ def add_remote_agent( tp_ratio, ) - # FIXME refactor into self.register_local_xfer_handler(self.block_size) ### (Optional) Register local agent memory regions. MLA is not split. 
if ( tp_ratio < 0 @@ -1568,7 +1567,9 @@ def add_remote_agent( # Remote tp is bigger: read a chunk of local region from remote local_block_len = local_block_len // (-tp_ratio) rank_offset = ( - self.tp_rank % tp_ratio * remote_kv_block_len if indexes_into_remote else 0 + self.tp_rank % tp_ratio * remote_kv_block_len + if indexes_into_remote + else 0 ) for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] @@ -1607,7 +1608,7 @@ def add_remote_agent( # when prefill with smaller block_size, we need to init a # new handler with same block_len to match self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = ( - self.register_local_xfer_handler(nixl_agent_meta.block_size) + self.register_local_xfer_handler(nixl_agent_meta.block_size)[0] ) return remote_agent_name @@ -1665,9 +1666,6 @@ def _validate_remote_agent_handshake( == nixl_agent_meta.block_lens[i] ), "KV cache sizes must match between P and D when replicated" else: - if tp_ratio != 1 and self.device_type == "xpu": - # XPU uses NHD, hence it does not support splitting on H - raise ValueError("Heterogeneous TP is not supported on XPU") # When MLA is not used, this is a list of the same block length for block_len in nixl_agent_meta.block_lens: assert block_len == remote_block_len, ( @@ -1675,21 +1673,18 @@ def _validate_remote_agent_handshake( ) if tp_ratio > 0: - # Remote NHD/H'D*tp_ratio=N -page_size- - remote_block_size = remote_block_len // ( - self.slot_size_per_layer[0] * tp_ratio - ) # Remote tp is smaller: remote block_len size is bigger - assert remote_block_len == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio, ( + assert ( + remote_block_len + == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio + ), ( "Remote P worker KV layer cache must be of shape [2, N, " "local_kv_heads*tp_ratio, page_size, head_dim] and same dtype." 
) # noqa: E501 else: - # TODO (NickLucche) to support - assert block_size_ratio == 1, "Different local/remote block sizes are not supported when P TP > D TP." - # Remote NHD/(H'D/tp_ratio)=N -page_size- - remote_block_size = remote_block_len // ( - self.slot_size_per_layer[0] // (-tp_ratio) + assert block_size_ratio == 1, ( + "Different local/remote block sizes are not supported when" + " P TP > D TP." ) # Remote tp is bigger: remote block_len size is smaller assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), ( @@ -1697,16 +1692,6 @@ def _validate_remote_agent_handshake( "local_kv_heads/tp_ratio, page_size, head_dim] and same dtype." ) # noqa: E501 - if self._use_flashinfer: - # With flashinfer, KV are sent in the same message. - remote_block_size //= 2 - - # We may allow it in the future with logical kvcache manager block_size - assert self.block_size == remote_block_size, ( - "Remote P worker with different page/block size is not supported " - f"{self.block_size=}, {remote_block_size=}" - ) - # TP workers that handhshake with same remote have same #blocks. assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks # Same number of regions/~layers. @@ -1793,7 +1778,7 @@ def permute_device_kv(self, block_ids: list[int]): ) cache.index_copy_(0, indices, permuted_blocks) - def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]): + def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]): def _process_local_gt_remote(blocks_to_update, block_size_ratio): n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] remote_block_size = block_size // block_size_ratio @@ -2107,14 +2092,15 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i] else: # Single read from remote, we write to the whole memory region. - # Handle the case in which remote block size is different from local block size. 
- local_xfer_side_handle = self.src_xfer_handles_by_block_size[remote_block_size] - + # Also handle remote block size different from local block size. + local_xfer_side_handle = self.src_xfer_handles_by_block_size[ + remote_block_size + ] + # Destination handle: remote_engine_id -> remote_rank -> handle. remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][ remote_rank ] - # FIXME check local physical block ids changes are respected self._read_blocks( request_id=req_id, dst_engine_id=meta.remote.engine_id, @@ -2148,7 +2134,7 @@ def _read_blocks( remote_xfer_side_handle: int, ): """ - Post a READ point-to-point xfer request from a single local worker to + Post a READ point-to-point xfer request from a single local worker to a single remote worker. """ block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) @@ -2367,7 +2353,7 @@ def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]: block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange ).tolist() - def get_backend_aware_kv_block_len(self, layer_idx: int): + def get_backend_aware_kv_block_len(self, layer_idx: int) -> int: """ Get the block length for one K/V element (K and V have the same size). 
@@ -2413,9 +2399,8 @@ def shutdown(self): for handle in handles: self.nixl_wrapper.release_xfer_handle(handle) self._recving_transfers.clear() - for handles in self.src_xfer_handles_by_block_size.values(): - for handle in handles: - self.nixl_wrapper.release_dlist_handle(handle) + for handle in self.src_xfer_handles_by_block_size.values(): + self.nixl_wrapper.release_dlist_handle(handle) self.src_xfer_handles_by_block_size.clear() for handles in self.src_xfer_handles_by_tp_ratio.values(): for handle in handles: From 0fd8705bde3a20ea874fb3aa3b06a439355de303 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 12 Dec 2025 18:19:19 +0000 Subject: [PATCH 4/4] nicks review Signed-off-by: NickLucche --- .../kv_connector/unit/test_nixl_connector.py | 8 +---- .../kv_transfer/kv_connector/utils.py | 27 ++++++++-------- .../kv_connector/v1/nixl_connector.py | 32 ++++++++----------- 3 files changed, 27 insertions(+), 40 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index b630f5e99b5c..6324becda6cc 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -482,7 +482,6 @@ def test_multi_xfer_one_engine( worker.dst_xfer_side_handles = { FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1} } - # worker.src_xfer_handles_by_block_size = {worker.block_size: 1} worker.kv_cache_layout = "HND" num_xfers = 4 while True: @@ -561,9 +560,6 @@ def test_async_load_kv( connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id ) - # worker = connector.connector_worker - # worker.src_xfer_handles_by_block_size = {worker.block_size: 1} - metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( request_id="id", @@ -1559,10 +1555,8 @@ def test_shutdown_cleans_up_resources(dist_init): patch.object(nixl_wrapper, "deregister_memory") as mock_dereg, ): worker._recving_transfers = {"req1": [123]} - worker.src_xfer_side_handle 
= 456 - worker.src_xfer_side_chunked_handles = {-2: [456]} # Mock register_kv_cache which registers local handle - # worker.src_xfer_handles_by_block_size = {worker.block_size: 455} + worker.src_xfer_handles_by_block_size = {worker.block_size: 455} # P TP = 2 * D TP case, we should register 2 local handles worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]} worker.dst_xfer_side_handles = {"engine1": {0: 789}} diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index c98099907714..a026cccb8537 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -209,6 +209,7 @@ class TpKVTopology: Helper class for tensor parallel and KV topology information for mapping between local and remote TP workers. """ + tp_rank: int remote_tp_size: dict[EngineId, int] is_mla: bool @@ -239,9 +240,7 @@ def is_kv_layout_blocks_first(self) -> bool: @property def split_k_and_v(self) -> bool: # Whether to register regions for K and V separately (when present). - return not ( - self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first - ) + return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first) @property def tp_size(self) -> int: @@ -269,13 +268,13 @@ def tp_ratio( f"by remote tensor parallel size {remote_tp_size}." ) return self.tp_size // remote_tp_size - else: - assert remote_tp_size % self.tp_size == 0, ( - f"Remote tensor parallel size {remote_tp_size} is not divisible " - f"by local tensor parallel size {self.tp_size}." - ) - # P TP > D TP case, return the ratio as negative - return -remote_tp_size // self.tp_size + + assert remote_tp_size % self.tp_size == 0, ( + f"Remote tensor parallel size {remote_tp_size} is not divisible " + f"by local tensor parallel size {self.tp_size}." 
+ ) + # P TP > D TP case, return the ratio as negative + return -remote_tp_size // self.tp_size def block_size_ratio( self, @@ -328,10 +327,10 @@ def get_target_remote_ranks( tp_ratio = self.tp_ratio(remote_tp_size) if tp_ratio > 0: return [self.tp_rank // tp_ratio] - else: - # P TP > D TP case, D reads from |tp_ratio| remote workers. - tp_ratio = -tp_ratio - return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)] + + # P TP > D TP case, D reads from |tp_ratio| remote workers. + tp_ratio = -tp_ratio + return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)] def get_target_remote_ranks_from_engine_id( self, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7593658c623f..be56eb4e93c1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology, EngineId +from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp, KVConnectorBase_V1, @@ -874,9 +874,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Map of engine_id -> kv_caches_base_addr. For TP case, each local self.device_id: int = 0 # Current rank may pull from multiple remote TP workers. - self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = ( - defaultdict(dict) - ) + # EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer + self.kv_caches_base_addr = defaultdict[EngineId, dict[int, list[int]]](dict) # Number of NIXL regions. 
Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -889,9 +888,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Keep track of regions at different tp_ratio values. tp_ratio->handles self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {} # Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}. - self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict( - dict - ) + self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict) # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. @@ -994,7 +991,9 @@ def _nixl_handshake( with zmq_ctx(zmq.REQ, path) as sock: for remote_rank in p_remote_ranks: logger.debug( - "Querying metadata on path: %s at remote tp rank %s", path, remote_rank + "Querying metadata on path: %s at remote tp rank %s", + path, + remote_rank, ) start_time = time.perf_counter() @@ -1030,9 +1029,9 @@ def _nixl_handshake( f"NIXL compatibility hash mismatch. " f"Local: {self.compat_hash}, " f"Remote: {handshake_payload.compatibility_hash}. " - f"Prefill and decode instances have incompatible configurations. " - f"This may be due to: different vLLM versions, models, dtypes, " - f"KV cache layouts, attention backends, etc. " + f"Prefill and decode instances have incompatible " + f"configurations. This may be due to: different vLLM versions," + f" models, dtypes, KV cache layouts, attention backends, etc. " f"Both instances must use identical configurations." f"Disable this check using " f'--kv-transfer-config \'{{"kv_connector_extra_config": ' @@ -1063,10 +1062,6 @@ def _nixl_handshake( f"Expected {expected_engine_id}," f"received {metadata.engine_id}." ) - # FIXME move to _validate - assert metadata.block_size <= self.block_size, ( - "nP > nD is not supported yet." - ) # Ensure engine id matches. 
if metadata.engine_id != expected_engine_id: raise RuntimeError( @@ -1922,8 +1917,8 @@ def _get_new_notifs(self) -> set[str]: continue # NOTE: `tp_ratio` is the opposite when swapping local<>remote - tp_ratio = self.kv_topo.tp_ratio(int(tp_size)) n_consumers = int(tp_size) + tp_ratio = self.kv_topo.tp_ratio(n_consumers) # Number of reads *per producer* to wait for. # When remote D TP > local P TP we expect `tp_ratio` reads. @@ -2116,9 +2111,8 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): # ..but we still need to notify the other remote ranks that we # have the blocks we need so they can update the request state. notif_id = f"{req_id}:{self.world_size}".encode() - for rank_to_notify, agent in self._remote_agents[ - meta.remote.engine_id - ].items(): + remote_agents = self._remote_agents[meta.remote.engine_id] + for rank_to_notify, agent in remote_agents.items(): if rank_to_notify != remote_rank: self.nixl_wrapper.send_notif(agent, notif_msg=notif_id)