[SSM/Mamba] Follow-up: N-1 prefill for P/D disaggregation #37310
Merged: NickLucche merged 8 commits into vllm-project:main from ZhanqiuHu:fix/mamba-pd-n1-prefill on Mar 19, 2026.
Commits
- 6e49140 [SSM/Mamba] N-1 prefill for P/D disaggregation (ZhanqiuHu)
- 172e7e6 Refactor: cleaner call-site guards and docstrings (ZhanqiuHu)
- f1c4cb6 Fix ruff SIM102: collapse nested if statements (ZhanqiuHu)
- 2130daf Use _has_mamba instead of _is_hma_required for N-1 logic (ZhanqiuHu)
- 3a18e8e Rename _hma_ helpers to _mamba_ for clarity (ZhanqiuHu)
- f29f1b6 Add comment explaining _p_side_truncated preemption guard (ZhanqiuHu)
- 458ca60 add test cases (ZhanqiuHu)
- 17f996f handle prompt embeddings (ZhanqiuHu)
```diff
@@ -572,6 +572,10 @@ def __init__(
                 for g in kv_cache_config.kv_cache_groups
             )
         )
+        self._has_mamba = any(
+            isinstance(g.kv_cache_spec, MambaSpec)
+            for g in kv_cache_config.kv_cache_groups
+        )

         logger.info("Initializing NIXL Scheduler %s", engine_id)
         if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
```
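The new `_has_mamba` flag is a simple scan over the KV-cache groups for any group backed by a Mamba (SSM) spec. A minimal standalone sketch of this pattern follows; the `KVCacheGroup`, `MambaSpec`, and `AttentionSpec` classes here are toy stand-ins for vLLM's actual kv-cache spec types, not the real implementations:

```python
from dataclasses import dataclass

# Toy stand-ins for vLLM's kv-cache spec hierarchy (assumptions, not the
# real classes): a common base, an SSM spec, and an attention spec.
class KVCacheSpec: ...
class MambaSpec(KVCacheSpec): ...
class AttentionSpec(KVCacheSpec): ...

@dataclass
class KVCacheGroup:
    kv_cache_spec: KVCacheSpec

def has_mamba(kv_cache_groups) -> bool:
    # True if any cache group stores SSM (Mamba) state rather than
    # ordinary attention KV blocks -- i.e. a pure or hybrid Mamba model.
    return any(isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups)

hybrid = [KVCacheGroup(AttentionSpec()), KVCacheGroup(MambaSpec())]
print(has_mamba(hybrid))  # True
print(has_mamba([KVCacheGroup(AttentionSpec())]))  # False
```

Note that a single Mamba group is enough to flip the flag, which is why hybrid attention/Mamba models also take the N-1 path.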
```diff
@@ -717,6 +721,39 @@ def _nixl_handshake_listener(
             logger.warning("Connection listener got unexpected message %s", msg)
         sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))

+    def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int:
+        """D-side only. Returns N-1 for Mamba models since the decoder
+        always recomputes the last token and must start from h(N-1)."""
+        if self._has_mamba and num_prompt_tokens > 1:
+            return num_prompt_tokens - 1
+        return num_prompt_tokens
+
+    def _truncate_mamba_request_for_prefill(self, request: "Request") -> None:
+        """P-side only: drop the last prompt token so the prefiller computes
+        h(N-1) instead of h(N). The decoder recomputes the last token to
+        derive h(N) correctly.
+
+        Guarded by ``_p_side_truncated`` to avoid repeated truncation if the
+        request is preempted and rescheduled."""
+        params = request.kv_transfer_params
+        if (
+            params is not None
+            # Guard against repeated truncation after preemption/reschedule.
+            and not params.get("_p_side_truncated")
+            and request.num_prompt_tokens > 1
+        ):
+            if request.prompt_token_ids is not None:
+                request.prompt_token_ids.pop()
+            elif request.prompt_embeds is not None:
+                request.prompt_embeds = request.prompt_embeds[:-1]
+            else:
+                return
+
+            request._all_token_ids.pop()
+            request.num_prompt_tokens -= 1
+            request.max_tokens = 1
+            params["_p_side_truncated"] = True
```
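The N-1 rule in `_mamba_prefill_token_count` is easy to illustrate with a standalone re-implementation. This is a sketch for illustration only, decoupled from the connector class (the `has_mamba` parameter stands in for the instance's `_has_mamba` flag):

```python
def mamba_prefill_token_count(num_prompt_tokens: int, has_mamba: bool) -> int:
    # D-side rule: for Mamba models the decoder recomputes the last prompt
    # token itself, so it only needs state up to h(N-1) from the prefiller.
    if has_mamba and num_prompt_tokens > 1:
        return num_prompt_tokens - 1
    return num_prompt_tokens

print(mamba_prefill_token_count(8, has_mamba=True))   # 7: fetch up to h(N-1)
print(mamba_prefill_token_count(8, has_mamba=False))  # 8: attention models unchanged
print(mamba_prefill_token_count(1, has_mamba=True))   # 1: single-token prompts not reduced
```

The `num_prompt_tokens > 1` check matters: a one-token prompt has no h(N-1) to start from, so it is never reduced.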
Contributor comment on lines +739 to +755:

This truncation logic only considers:

```python
if (
    params is not None
    and not params.get("_p_side_truncated")
    and request.num_prompt_tokens > 1
):
    if request.prompt_token_ids is not None:
        request.prompt_token_ids.pop()
    elif request.prompt_embeds is not None:
        request.prompt_embeds = request.prompt_embeds[:-1]
    else:
        # This case should not be possible if num_prompt_tokens > 1.
        return
    request._all_token_ids.pop()
    request.num_prompt_tokens -= 1
    request.max_tokens = 1
    params["_p_side_truncated"] = True
```
```diff
     def get_num_new_matched_tokens(
         self, request: "Request", num_computed_tokens: int
     ) -> tuple[int, bool]:
@@ -746,10 +783,14 @@ def get_num_new_matched_tokens(
         if params is not None and params.get("do_remote_prefill"):
             # Remote prefill: get all prompt blocks from remote.
             token_ids = request.prompt_token_ids or []
-            count = len(token_ids) - num_computed_tokens
+            actual = self._mamba_prefill_token_count(len(token_ids))
+            count = actual - num_computed_tokens
             if count > 0:
                 return count, True

+        if params is not None and params.get("do_remote_decode") and self._has_mamba:
+            self._truncate_mamba_request_for_prefill(request)
+
         # No remote prefill for this request.
         return 0, False
```
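Putting the D-side pieces together, the count of tokens to fetch from the remote prefiller is the (possibly N-1 adjusted) prompt length minus what is already computed locally. A hedged sketch as a standalone function, not the actual connector method:

```python
def num_new_matched_tokens(
    num_prompt_tokens: int, num_computed_tokens: int, has_mamba: bool
) -> tuple[int, bool]:
    # Mirror of the D-side logic: apply the N-1 adjustment for Mamba,
    # then subtract tokens already computed locally. Returns
    # (count, load_async) like the real method's (int, bool) pair.
    if has_mamba and num_prompt_tokens > 1:
        actual = num_prompt_tokens - 1
    else:
        actual = num_prompt_tokens
    count = actual - num_computed_tokens
    if count > 0:
        return count, True
    return 0, False

print(num_new_matched_tokens(8, 0, has_mamba=True))   # (7, True)
print(num_new_matched_tokens(8, 0, has_mamba=False))  # (8, True)
print(num_new_matched_tokens(8, 7, has_mamba=True))   # (0, False): all local already
```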
Reviewer: Should we have this param? D is not supposed to get here because of the `do_remote_decode` guard.

ZhanqiuHu: gemini-code-assist made the suggestion here: ZhanqiuHu#4 (comment). I think this makes sure we don't `-1` multiple times if the request got rescheduled, although it doesn't necessarily need to be inside the params.

Reviewer: @ZhanqiuHu I don't think it can happen as described here, but this is very much valid for preemptions. Let's add a comment.
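The preemption concern above is exactly what the `_p_side_truncated` flag guards against. A toy sketch with a stand-in `Request` class (an assumption for illustration, not vLLM's real `Request`) shows the truncation applying exactly once even if the scheduling path runs twice:

```python
from dataclasses import dataclass, field

@dataclass
class Request:
    # Minimal stand-in for vLLM's Request: just the fields the guard touches.
    prompt_token_ids: list
    kv_transfer_params: dict = field(default_factory=dict)
    max_tokens: int = 16

    @property
    def num_prompt_tokens(self) -> int:
        return len(self.prompt_token_ids)

def truncate_for_prefill(req: Request) -> None:
    params = req.kv_transfer_params
    if (
        params is not None
        # The flag makes this idempotent across preemption/reschedule.
        and not params.get("_p_side_truncated")
        and req.num_prompt_tokens > 1
    ):
        req.prompt_token_ids.pop()  # prefiller now computes h(N-1), not h(N)
        req.max_tokens = 1
        params["_p_side_truncated"] = True

req = Request(prompt_token_ids=[10, 11, 12, 13],
              kv_transfer_params={"do_remote_decode": True})
truncate_for_prefill(req)
truncate_for_prefill(req)  # second call, e.g. after preemption: a no-op
print(req.num_prompt_tokens)  # 3 -- truncated exactly once
```

Storing the flag in `kv_transfer_params` (rather than elsewhere on the request) keeps it alive across rescheduling, which is the property the thread settles on documenting with a comment.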