nixl refactor [3/N]: extract model-specific logic into ModelBlockTransferPolicy#40157

Open
ZhanqiuHu wants to merge 23 commits into vllm-project:main from ZhanqiuHu:nixl-refactor-3n-policy-extraction

Conversation

Contributor

@ZhanqiuHu ZhanqiuHu commented Apr 17, 2026

Based on #39529

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the KV transfer system by replacing the TpKVTopology class with a more modular TransferTopology and introducing a ModelBlockTransferPolicy hierarchy to encapsulate model-specific transfer logic for Dense and Mamba architectures. The changes centralize per-engine transfer state into dedicated dataclasses and move complex transfer planning out of the worker and into policy classes. Feedback identifies a critical missing assignment to self.num_descs in the NIXL worker that could lead to data corruption in Mamba-hybrid scenarios and suggests consolidating duplicated tp_ratio calculation logic into a shared utility function.

Comment on lines 609 to 625
 """Register the KV Cache data in nixl."""
-self.kv_topo = TpKVTopology(
+self.transfer_topo = TransferTopology(
     tp_rank=self.tp_rank,
     tp_size=self.world_size,
     block_size=self.block_size,
     engine_id=self.engine_id,
     remote_tp_size=self._tp_size,  # shared state
     remote_block_size=self._block_size,  # shared state
     is_mla=self.use_mla,
     total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
     attn_backends=self.attn_backends,
-    # SSM States come in tuples (ssm, conv)
-    tensor_shape=next(iter(kv_caches.values())).shape
-    if not self._has_mamba
-    else None,
-    is_mamba=self._has_mamba,
+    tensor_shape=self.transfer_policy.get_tensor_shape(kv_caches),
+    is_mamba=self.transfer_policy.is_mamba,
 )
 self.compat_hash = compute_nixl_compatibility_hash(
-    self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
+    self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks
 )

 if self.use_host_buffer:

critical

The refactored register_kv_caches method appears to be missing the assignment to self.num_descs. In the previous implementation, this variable was used to track the number of FlashAttention (FA) descriptors. It is initialized to 0 (line 228) and used later in _read_blocks_for_req -> build_src_split_handles (line 1008) to distinguish between FA and Mamba descriptors in the src_blocks_data list. If self.num_descs remains 0, all descriptors will be incorrectly treated as Mamba descriptors during split handle computation, which will lead to incorrect data slicing and potential corruption in heterogeneous TP scenarios.
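
To illustrate the failure mode described in this comment, here is a minimal sketch. It assumes the flat descriptor list stores FA descriptors first and Mamba descriptors after them, with a count marking the boundary. The name `partition_descs` and the `(address, length)` tuple layout are hypothetical and are not taken from the PR.

```python
def partition_descs(descs: list[tuple[int, int]], num_fa_descs: int):
    """Split a flat descriptor list into (fa, mamba) at index num_fa_descs.

    Assumption: FA descriptors come first, Mamba descriptors follow.
    """
    if not 0 <= num_fa_descs <= len(descs):
        raise ValueError("num_fa_descs out of range")
    return descs[:num_fa_descs], descs[num_fa_descs:]


# Toy (address, length) descriptors: two FA entries, then one Mamba entry.
descs = [(0x1000, 256), (0x1100, 256), (0x2000, 64)]

# Boundary assigned correctly.
fa, mamba = partition_descs(descs, num_fa_descs=2)

# Boundary left at its initial value of 0.
stale_fa, stale_mamba = partition_descs(descs, num_fa_descs=0)
```

With the boundary left at 0, `stale_fa` is empty and all three entries end up in `stale_mamba`, which is the misclassification the comment warns about.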

Comment on lines +1259 to +1263
tp_ratio = (
tp_size // info.remote_tp_size
if tp_size >= info.remote_tp_size
else -(info.remote_tp_size // tp_size)
) # noqa: E501

high

The logic for calculating tp_ratio is duplicated across several methods in MambaModelBlockTransferPolicy (fa_rank_offset, needs_split_handles, describe_mamba, build_engine_transfer_info) and TransferTopology. This duplication is a maintenance risk.

Consider moving this logic to a shared utility function in vllm/distributed/kv_transfer/kv_connector/utils.py. Furthermore, fa_rank_offset should be updated to accept tp_ratio as an argument directly, as it is already calculated in its caller _build_fa_remote_descs (line 891).
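
A shared helper along these lines could replace the duplicated expression. This is a sketch only: the function name `compute_tp_ratio` and the input validation are assumptions, while the arithmetic mirrors the inline expression quoted above.

```python
def compute_tp_ratio(local_tp_size: int, remote_tp_size: int) -> int:
    """Return the signed TP ratio between the local and remote engines.

    Positive when local TP >= remote TP, negative otherwise, matching
    the inline expression quoted in the review.
    """
    if local_tp_size <= 0 or remote_tp_size <= 0:
        raise ValueError("TP sizes must be positive")
    if local_tp_size >= remote_tp_size:
        return local_tp_size // remote_tp_size
    return -(remote_tp_size // local_tp_size)
```

Callers such as `fa_rank_offset` could then receive the result as an argument instead of recomputing it.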

@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-3n-policy-extraction branch 2 times, most recently from 169c02c to b71d8d9 Compare April 21, 2026 18:51
Contributor

mergify Bot commented Apr 21, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 21, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-3n-policy-extraction branch from b71d8d9 to 169c02c Compare April 21, 2026 18:53
@mergify mergify Bot removed the needs-rebase label Apr 21, 2026
@ZhanqiuHu ZhanqiuHu marked this pull request as ready for review April 21, 2026 22:06

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Collaborator

@NickLucche NickLucche left a comment

Let's make the state separation cleaner and identify what should stay within the connector worker and what belongs to the policy.

I feel like some of the interfaces below pull bits and pieces of info from the worker, which may signal that we're just moving code out of the worker without properly separating it.
Let's iterate to try and simplify ModelBlockTransferPolicy interface so that as a non-functional requirement we can have deleting

# Per-engine transfer info (data operations)
# ------------------------------------------------------------------

# TODO (ZhanqiuHu): Revisit data packing for local facts and remote facts.

mmm?

Comment on lines +85 to +94
tp_rank: int,
tp_size: int,
is_mla: bool,
total_num_kv_heads: int,
is_kv_layout_blocks_first: bool,
local_block_len: int,
remote_tp_size: int,
remote_block_size: int,
remote_block_len: int,
remote_physical_blocks_per_logical: int,

I feel like this method could just take in a transfer topology and local/remote NixlAgentMetadata.
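
The suggestion amounts to grouping the ten scalar parameters into two objects. A rough sketch of that shape follows; `LocalFacts`, `RemoteFacts`, and `describe_transfer` are hypothetical stand-ins, not the PR's `TransferTopology` or `NixlAgentMetadata` classes, and only the field names come from the quoted signature.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LocalFacts:  # hypothetical stand-in for the topology object
    tp_rank: int
    tp_size: int
    is_mla: bool
    total_num_kv_heads: int
    is_kv_layout_blocks_first: bool
    local_block_len: int


@dataclass(frozen=True)
class RemoteFacts:  # hypothetical stand-in for remote agent metadata
    remote_tp_size: int
    remote_block_size: int
    remote_block_len: int
    remote_physical_blocks_per_logical: int


def describe_transfer(local: LocalFacts, remote: RemoteFacts) -> dict[str, int]:
    """Toy consumer that takes two arguments instead of ten."""
    return {
        "tp_rank": local.tp_rank,
        "remote_tp_size": remote.remote_tp_size,
        "block_len_delta": local.local_block_len - remote.remote_block_len,
    }
```

Freezing both dataclasses keeps the grouped values read-only inside the policy, which makes it easier to see which state the worker owns.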

block_size_ratio: float | None = None,
physical_blocks_per_logical: int = 1,
) -> np.ndarray:
"""Compute NIXL descriptor IDs for a set of block IDs."""

nit

Suggested change
-"""Compute NIXL descriptor IDs for a set of block IDs."""
+"""Compute NIXL descriptor IDs for a given set of block IDs."""
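
For readers unfamiliar with what such a mapping might look like, here is a simplified sketch. It assumes a region-major layout in which every registered region holds the same number of blocks, so that `desc_id = region * num_blocks + block_id`. The function name, its parameters, and the layout are illustrative assumptions, and it returns a plain list rather than the `np.ndarray` in the quoted signature.

```python
def block_ids_to_desc_ids(
    block_ids: list[int], num_regions: int, num_blocks: int
) -> list[int]:
    """Map block IDs to descriptor IDs under an assumed region-major layout."""
    for block_id in block_ids:
        if not 0 <= block_id < num_blocks:
            raise ValueError(f"block id {block_id} out of range")
    # One descriptor per (region, block) pair, grouped by region.
    return [
        region * num_blocks + block_id
        for region in range(num_regions)
        for block_id in block_ids
    ]
```

For example, with two regions of four blocks each, blocks 1 and 3 map to descriptors 1 and 3 in the first region and 5 and 7 in the second.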

Comment on lines +129 to +130
# Input
block_ids: BlockIds,

Suggested change
-# Input
-block_ids: BlockIds,
+block_ids: BlockIds,

Comment on lines +59 to +66
class ModelBlockTransferPolicy(ABC):
"""Abstract base for model-specific block transfer logic.

Encapsulates genuinely model-specific algorithms: descriptor building,
transfer info computation, split handles, read spec filtering, and
orchestration. Simple per-layer branches (``isinstance(MambaSpec)``)
and block-ID mapping remain on ``worker.py``.
"""

I think we should highlight the "lifecycle" we want to abstract a bit more:

  • local descs registration
  • remote descs registration (during handshake)
  • block_ids -> desc_ids mapping
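
The three lifecycle stages listed above could be made explicit as one abstract method each. The sketch below is an assumption about how that might look; the class and method names are hypothetical and do not reflect the actual `ModelBlockTransferPolicy` interface.

```python
from abc import ABC, abstractmethod


class TransferPolicySketch(ABC):
    """Hypothetical interface with one method per lifecycle stage."""

    @abstractmethod
    def register_local_descs(self, num_blocks: int) -> None:
        """Stage 1: describe this engine's own KV cache memory."""

    @abstractmethod
    def register_remote_descs(self, engine_id: str, num_blocks: int) -> None:
        """Stage 2: describe a remote engine's memory during handshake."""

    @abstractmethod
    def block_ids_to_desc_ids(
        self, engine_id: str, block_ids: list[int]
    ) -> list[int]:
        """Stage 3: map block IDs to descriptor IDs for one transfer."""


class OneDescPerBlockPolicy(TransferPolicySketch):
    """Toy policy: one descriptor per block, so desc ID equals block ID."""

    def __init__(self) -> None:
        self.local_registered = False
        self.remote_num_blocks: dict[str, int] = {}

    def register_local_descs(self, num_blocks: int) -> None:
        self.local_registered = True

    def register_remote_descs(self, engine_id: str, num_blocks: int) -> None:
        if not self.local_registered:
            raise RuntimeError("register local descriptors before handshake")
        self.remote_num_blocks[engine_id] = num_blocks

    def block_ids_to_desc_ids(
        self, engine_id: str, block_ids: list[int]
    ) -> list[int]:
        if engine_id not in self.remote_num_blocks:
            raise RuntimeError(f"no handshake recorded for {engine_id}")
        limit = self.remote_num_blocks[engine_id]
        if any(b < 0 or b >= limit for b in block_ids):
            raise ValueError("block id out of range")
        return list(block_ids)
```

Enforcing the stage order inside the toy policy shows one way the interface could make the lifecycle visible to callers rather than leaving it implicit in the worker.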

Comment on lines +141 to +143
# ------------------------------------------------------------------
# Local descriptor building (concrete default = FA-only)
# ------------------------------------------------------------------

Suggested change
-# ------------------------------------------------------------------
-# Local descriptor building (concrete default = FA-only)
-# ------------------------------------------------------------------
+# ------------------------------------------------------------------
+# Local descriptor building
+# ------------------------------------------------------------------

Comment on lines +122 to +124
# ------------------------------------------------------------------
# Descriptor ID computation (abstract — genuinely different per model)
# ------------------------------------------------------------------

Suggested change
-# ------------------------------------------------------------------
-# Descriptor ID computation (abstract — genuinely different per model)
-# ------------------------------------------------------------------
+# ------------------------------------------------------------------
+# Descriptor ID computation
+# ------------------------------------------------------------------

Comment on lines +217 to +220

# ------------------------------------------------------------------
# Remote descriptor building (abstract — genuinely different)
# ------------------------------------------------------------------

ditto

layer_specs: dict[str, KVCacheSpec],
physical_blocks_per_logical: int,
tp_size: int,
) -> ModelBlockTransferPolicy:

layer_specs is derived directly from kv_cache_config. I think we can just pull that from the worker and keep it here.
It would be ideal if we got to the point where this method could be renamed to from_kv_cache_config, but I see this is non-trivial. Let's chat about it.
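
As a talking point for that discussion, a factory keyed on the layer specs might look roughly like the sketch below. Everything here is a stub: the spec and policy classes are hypothetical stand-ins, and the rule "use the Mamba policy if any layer has a Mamba spec" is an assumption, since the PR excerpt does not show how the policy is selected.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionSpecStub:  # hypothetical stand-in for an attention KVCacheSpec
    block_size: int


@dataclass(frozen=True)
class MambaSpecStub:  # hypothetical stand-in for MambaSpec
    block_size: int


class DensePolicyStub:
    is_mamba = False


class MambaPolicyStub:
    is_mamba = True


def policy_from_layer_specs(layer_specs: dict[str, object]):
    """Pick the Mamba policy if any layer uses a Mamba spec, else dense."""
    if not layer_specs:
        raise ValueError("layer_specs must not be empty")
    has_mamba = any(
        isinstance(spec, MambaSpecStub) for spec in layer_specs.values()
    )
    return MambaPolicyStub() if has_mamba else DensePolicyStub()
```

If the layer specs can be derived inside the factory from the cache config, the worker would only need to pass one object, which is the direction the comment points toward.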

Contributor

mergify Bot commented Apr 22, 2026

Hi @ZhanqiuHu, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…delBlockTransferPolicy

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…cal dict

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
… and remote_tp_size

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…fo, remove defaults, regroup args

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…mote and physical_blocks_per_logical, regroup

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…npack read spec loop

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…omputation

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…blocks_per_logical removal

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…_policy and use_mla, set num_descs

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…nsfer_info

Add physical_blocks_per_logical to TransferTopology and pass
transfer_topo directly to build_engine_transfer_info, reducing the
method's parameter count from 10 to 6.

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…11→4)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…_id (7→4)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-3n-policy-extraction branch from 40ff123 to d9d877d Compare April 22, 2026 19:39