nixl refactor [3/N]: extract model-specific logic into ModelBlockTransferPolicy #40157
ZhanqiuHu wants to merge 23 commits into vllm-project:main
Conversation
Code Review
This pull request refactors the KV transfer system by replacing the TpKVTopology class with a more modular TransferTopology and introducing a ModelBlockTransferPolicy hierarchy to encapsulate model-specific transfer logic for Dense and Mamba architectures. The changes centralize per-engine transfer state into dedicated dataclasses and move complex transfer planning out of the worker and into policy classes. Feedback identifies a critical missing assignment to self.num_descs in the NIXL worker that could lead to data corruption in Mamba-hybrid scenarios and suggests consolidating duplicated tp_ratio calculation logic into a shared utility function.
| """Register the KV Cache data in nixl.""" | ||
| self.kv_topo = TpKVTopology( | ||
| self.transfer_topo = TransferTopology( | ||
| tp_rank=self.tp_rank, | ||
| tp_size=self.world_size, | ||
| block_size=self.block_size, | ||
| engine_id=self.engine_id, | ||
| remote_tp_size=self._tp_size, # shared state | ||
| remote_block_size=self._block_size, # shared state | ||
| is_mla=self.use_mla, | ||
| total_num_kv_heads=self.model_config.get_total_num_kv_heads(), | ||
| attn_backends=self.attn_backends, | ||
| # SSM States come in tuples (ssm, conv) | ||
| tensor_shape=next(iter(kv_caches.values())).shape | ||
| if not self._has_mamba | ||
| else None, | ||
| is_mamba=self._has_mamba, | ||
| tensor_shape=self.transfer_policy.get_tensor_shape(kv_caches), | ||
| is_mamba=self.transfer_policy.is_mamba, | ||
| ) | ||
| self.compat_hash = compute_nixl_compatibility_hash( | ||
| self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks | ||
| self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks | ||
| ) | ||
|
|
||
| if self.use_host_buffer: |
The refactored register_kv_caches method appears to be missing the assignment to self.num_descs. In the previous implementation, this variable tracked the number of FlashAttention (FA) descriptors: it is initialized to 0 (line 228) and used later in _read_blocks_for_req -> build_src_split_handles (line 1008) to distinguish FA descriptors from Mamba descriptors in the src_blocks_data list. If self.num_descs stays 0, every descriptor is treated as a Mamba descriptor during split-handle computation, which leads to incorrect data slicing and potential corruption in heterogeneous-TP Mamba-hybrid scenarios.
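To illustrate the failure mode, here is a minimal sketch of the split that build_src_split_handles performs; the helper name and list contents are illustrative, not the PR's actual code:

```python
# Hypothetical sketch: the first `num_descs` entries of src_blocks_data are
# FlashAttention (FA) descriptors and the remainder are Mamba descriptors.


def split_descs(src_blocks_data: list, num_descs: int) -> tuple[list, list]:
    """Partition descriptors into (FA, Mamba) by the recorded FA count."""
    return src_blocks_data[:num_descs], src_blocks_data[num_descs:]


# With the assignment missing, num_descs stays 0 and every FA descriptor
# is misclassified as a Mamba descriptor:
fa, mamba = split_descs(["fa0", "fa1", "ssm0", "conv0"], num_descs=0)
assert fa == [] and mamba == ["fa0", "fa1", "ssm0", "conv0"]

# With the count recorded, the split lands where intended:
fa, mamba = split_descs(["fa0", "fa1", "ssm0", "conv0"], num_descs=2)
assert fa == ["fa0", "fa1"] and mamba == ["ssm0", "conv0"]
```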
```python
tp_ratio = (
    tp_size // info.remote_tp_size
    if tp_size >= info.remote_tp_size
    else -(info.remote_tp_size // tp_size)
)  # noqa: E501
```
The logic for calculating tp_ratio is duplicated across several methods in MambaModelBlockTransferPolicy (fa_rank_offset, needs_split_handles, describe_mamba, build_engine_transfer_info) and in TransferTopology. This duplication is a maintenance risk.
Consider moving this logic into a shared utility function in vllm/distributed/kv_transfer/kv_connector/utils.py. Furthermore, fa_rank_offset should accept tp_ratio as an argument directly, since its caller _build_fa_remote_descs (line 891) already computes it.
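A possible shape for that shared helper; the name get_tp_ratio and its placement in utils.py are suggestions, not existing vLLM API:

```python
def get_tp_ratio(tp_size: int, remote_tp_size: int) -> int:
    """Signed TP ratio, matching the inline expression duplicated in the PR.

    Positive (local // remote) when the local TP world is at least as large
    as the remote one; negative -(remote // local) otherwise.
    """
    if tp_size >= remote_tp_size:
        return tp_size // remote_tp_size
    return -(remote_tp_size // tp_size)


# Callers such as fa_rank_offset could then take the precomputed ratio:
assert get_tp_ratio(8, 2) == 4
assert get_tp_ratio(2, 8) == -4
assert get_tp_ratio(4, 4) == 1
```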
Force-pushed from 169c02c to b71d8d9.

This pull request has merge conflicts that must be resolved before it can be merged.

Force-pushed from b71d8d9 to 169c02c.
NickLucche left a comment
Let's try to clean up the state separation a bit more and identify what should stay within the connector worker and what belongs to the policy.
I feel like some of the interfaces below get bits and pieces of info from the worker, which may signal that we're just moving code out of the worker without proper separation.
Let's iterate to try and simplify the ModelBlockTransferPolicy interface so that, as a non-functional requirement, we can have deleting
```python
# Per-engine transfer info (data operations)
# ------------------------------------------------------------------

# TODO (ZhanqiuHu): Revisit data packing for local facts and remote facts.
```

```python
    tp_rank: int,
    tp_size: int,
    is_mla: bool,
    total_num_kv_heads: int,
    is_kv_layout_blocks_first: bool,
    local_block_len: int,
    remote_tp_size: int,
    remote_block_size: int,
    remote_block_len: int,
    remote_physical_blocks_per_logical: int,
```
I feel like this method could just take in a transfer topology and local/remote NixlAgentMetadata.
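One way to read this suggestion, sketched with simplified stand-in types (these dataclasses are placeholders, not vLLM's actual definitions):

```python
from dataclasses import dataclass


@dataclass
class TransferTopology:  # stand-in for the PR's topology object
    tp_rank: int
    tp_size: int
    is_mla: bool
    total_num_kv_heads: int


@dataclass
class NixlAgentMetadata:  # stand-in for per-agent facts exchanged at handshake
    tp_size: int
    block_size: int
    block_len: int


def build_engine_transfer_info(
    topo: TransferTopology,
    local: NixlAgentMetadata,
    remote: NixlAgentMetadata,
) -> dict:
    """Three grouped arguments replace the ten scalars listed above."""
    return {
        "tp_rank": topo.tp_rank,
        "remote_tp_size": remote.tp_size,
        "remote_block_size": remote.block_size,
        "block_len_ratio": remote.block_len // local.block_len,
    }
```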
```python
    block_size_ratio: float | None = None,
    physical_blocks_per_logical: int = 1,
) -> np.ndarray:
    """Compute NIXL descriptor IDs for a set of block IDs."""
```

nit

```suggestion
    """Compute NIXL descriptor IDs for a given set of block IDs."""
```
```python
    # Input
    block_ids: BlockIds,
```

```suggestion
    block_ids: BlockIds,
```
```python
class ModelBlockTransferPolicy(ABC):
    """Abstract base for model-specific block transfer logic.

    Encapsulates genuinely model-specific algorithms: descriptor building,
    transfer info computation, split handles, read spec filtering, and
    orchestration. Simple per-layer branches (``isinstance(MambaSpec)``)
    and block-ID mapping remain in ``worker.py``.
    """
```
I think we should highlight the "lifecycle" we want to abstract a bit more:
- local descs registration
- remote descs registration (during handshake)
- block_ids -> desc_ids mapping
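A sketch of an interface organized around that three-step lifecycle; the class and method names here are illustrative, not the PR's actual signatures:

```python
from abc import ABC, abstractmethod


class BlockTransferLifecycle(ABC):  # illustrative name, not the PR's class
    @abstractmethod
    def register_local_descs(self, kv_caches: dict) -> list:
        """Build local NIXL descriptors once, at registration time."""

    @abstractmethod
    def register_remote_descs(self, remote_meta: dict) -> list:
        """Build remote descriptors when the handshake delivers metadata."""

    @abstractmethod
    def block_ids_to_desc_ids(self, block_ids: list[int]) -> list[int]:
        """Map logical block IDs to descriptor IDs for each transfer."""


class DensePolicy(BlockTransferLifecycle):
    """Trivial concrete example: one descriptor per block, identity mapping."""

    def register_local_descs(self, kv_caches: dict) -> list:
        return list(kv_caches)

    def register_remote_descs(self, remote_meta: dict) -> list:
        return list(remote_meta)

    def block_ids_to_desc_ids(self, block_ids: list[int]) -> list[int]:
        return block_ids
```

Anything a policy needs beyond these three calls would then be a sign of state leaking back from the worker.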
```python
# ------------------------------------------------------------------
# Local descriptor building (concrete default = FA-only)
# ------------------------------------------------------------------
```

```suggestion
# ------------------------------------------------------------------
# Local descriptor building
# ------------------------------------------------------------------
```
```python
# ------------------------------------------------------------------
# Descriptor ID computation (abstract — genuinely different per model)
# ------------------------------------------------------------------
```

```suggestion
# ------------------------------------------------------------------
# Descriptor ID computation
# ------------------------------------------------------------------
```
```python
# ------------------------------------------------------------------
# Remote descriptor building (abstract — genuinely different)
# ------------------------------------------------------------------
```
```python
    layer_specs: dict[str, KVCacheSpec],
    physical_blocks_per_logical: int,
    tp_size: int,
) -> ModelBlockTransferPolicy:
```
layer_specs is a direct descendant of kv_cache_config. I think we can just pull that from the worker and keep it here.
It would be ideal if we got to the point where this method could be renamed to from_kv_cache_config, but I see this is non-trivial. Let's chat about it.
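In spirit, such a factory would inspect the layer specs itself to pick the policy family; a simplified sketch where spec kinds are stand-in strings rather than vLLM's KVCacheSpec classes:

```python
def policy_from_layer_specs(layer_specs: dict[str, str]) -> str:
    """Pick the policy family from the layer specs (illustrative only).

    Any SSM ("mamba") layer forces the Mamba-hybrid policy; otherwise the
    dense, attention-only policy suffices.
    """
    if any(spec == "mamba" for spec in layer_specs.values()):
        return "MambaModelBlockTransferPolicy"
    return "DenseModelBlockTransferPolicy"
```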
Hi @ZhanqiuHu, the pre-commit checks have failed. Please run:

```shell
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

1 similar comment
Commits (each signed off by Zhanqiu Hu <zhu@redhat.com>; titles truncated as captured):

- …delBlockTransferPolicy
- …cal dict
- … and remote_tp_size
- …fo, remove defaults, regroup args
- …mote and physical_blocks_per_logical, regroup
- …npack read spec loop
- …omputation
- …blocks_per_logical removal
- …_policy and use_mla, set num_descs
- …nsfer_info: Add physical_blocks_per_logical to TransferTopology and pass transfer_topo directly to build_engine_transfer_info, reducing the method's parameter count from 10 to 6.
- …11→4)
- …_id (7→4)
Force-pushed from 40ff123 to d9d877d.
Based on #39529