
[SSM/Mamba] Follow-up: N-1 prefill for P/D disaggregation#37310

Merged
NickLucche merged 8 commits into vllm-project:main from ZhanqiuHu:fix/mamba-pd-n1-prefill
Mar 19, 2026
Conversation

@ZhanqiuHu
Contributor

@ZhanqiuHu ZhanqiuHu commented Mar 17, 2026

Follow-up of #36687

Connector-only N-1 prefill fix for Mamba P/D disaggregation. Prevents state corruption where D-side recomputes last prompt token from already-complete state. Single file change: nixl_connector.py (+16 lines), no scheduler changes.
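To see why the N-1 prefill matters, here is a toy cumulative recurrence standing in for Mamba state (illustrative only, not the actual SSM math): the decoder recomputes the last prompt token, so handing it h(N) would apply that token twice.

```python
# Toy recurrence standing in for Mamba state: h(t) = h(t-1) + token.
def run_prefill(tokens):
    h = 0
    for t in tokens:
        h += t
    return h

prompt = [3, 1, 4, 1, 5]

# Monolithic reference: state after the full prompt.
h_ref = run_prefill(prompt)

# P/D with the fix: P computes h(N-1); D recomputes the last prompt token.
h_p = run_prefill(prompt[:-1])   # N-1 prefill on the P side
h_d = h_p + prompt[-1]           # D recomputes the last prompt token
assert h_d == h_ref

# Without the fix: P computes h(N); D still recomputes the last token,
# applying it twice and corrupting the state.
h_bad = run_prefill(prompt) + prompt[-1]
assert h_bad != h_ref
```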

Eval

Quick sanity check (2p2d, prompt: "The capital of France is"):

  • Without fix:
Paris. Thats a weird sentence: The capital of France is is Paris. It repeats is. The user wants exactly that sentence, no
  • With fix:
Paris. No extra punctuation? The sentence includes a period at the end. The user wrote: The capital of France is Paris. So we

lm_eval gsm8k 5 shot:

| Setting | R1 | R2 | R3 | R4 | R5 | Mean | Std |
|---------|----|----|----|----|----|------|-----|
| 2p2d | 0.8446 | 0.8453 | 0.8393 | 0.8408 | 0.8378 | 0.8416 | 0.0033 |

Test Summary

Added test cases

  • test_mamba_n1_d_side (3 parametrized cases): D-side returns N-1 = 9 matched tokens for Mamba, N = 10 for FA-only and SWA-only
  • test_mamba_n1_p_side_truncation (1 case, multi-assert): P-side truncates prompt 10→9, sets max_tokens=1, verifies idempotency and non-Mamba skip
  • test_has_mamba_init (3 parametrized cases): _has_mamba / _is_hma_required flags derived correctly from FA+SWA+Mamba, FA+SWA, FA-only configs
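A self-contained sketch of what the D-side cases above assert, using a stand-in helper (hypothetical name; the real tests drive the NixlConnector directly):

```python
def d_side_matched_tokens(num_prompt_tokens: int, has_mamba: bool) -> int:
    # Stand-in for the connector's D-side calculation: HMA models report
    # N-1 matched tokens so the decoder recomputes the last prompt token
    # from the transferred h(N-1) state; pure-attention configs keep N.
    if has_mamba and num_prompt_tokens > 1:
        return num_prompt_tokens - 1
    return num_prompt_tokens

# Mirror the three parametrized cases from test_mamba_n1_d_side.
for name, has_mamba, expected in [
    ("mamba", True, 9),
    ("fa_only", False, 10),
    ("swa_only", False, 10),
]:
    assert d_side_matched_tokens(10, has_mamba) == expected, name
```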

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces logic to handle N-1 prefill for Mamba-style models in a distributed setting. A new helper function _hma_prefill_token_count is added to correctly calculate the number of tokens for prefill. The main change is in get_num_new_matched_tokens, which now truncates the prompt for the prefill worker to ensure the correct state is computed and transferred.

My main feedback is regarding a design concern in get_num_new_matched_tokens. The function now has significant side effects, modifying the request object in-place. This violates the command-query separation principle and can make the code harder to maintain. I've added a detailed comment with a suggestion for refactoring.

Comment on lines +762 to +776
```python
# HMA P-side (do_remote_decode): truncate prompt to N-1 tokens so
# the model computes h(N-1) and stops via max_tokens after one
# decode sample (which does NOT update Mamba state).
if (
    self._is_hma_required
    and params is not None
    and params.get("do_remote_decode")
    and not params.get("_p_side_truncated")
):
    if request.prompt_token_ids and len(request.prompt_token_ids) > 1:
        request.prompt_token_ids.pop()
        request._all_token_ids.pop()
        request.num_prompt_tokens -= 1
        request.max_tokens = 1
        params["_p_side_truncated"] = True
```
Contributor

high

This block of code modifies the request object, which is a significant side effect within a function named get_num_new_matched_tokens. This violates the principle of command-query separation, where a function should either perform an action (command) or return data (query), but not both. Mixing these concerns can make the code harder to understand, debug, and maintain.

A function with "get" in its name is expected to be read-only. To improve clarity and align with best practices, I recommend refactoring this logic. One approach is to move this request modification logic to a new, private method with a name that clearly indicates its side effect, such as _truncate_hma_request_for_prefill. get_num_new_matched_tokens could then call this new method.

While I cannot suggest adding a new method directly within this review, I strongly advise considering this refactoring for better code quality.

Contributor Author

Updated the code

Contributor

It's great to see you've addressed the issue by moving the request modification logic into a separate method. This aligns well with the principle of command-query separation and improves the code's maintainability. Thanks for taking the feedback into account!

Collaborator

since when can gemini reply back lol

@mergify mergify bot added the kv-connector label Mar 17, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the fix/mamba-pd-n1-prefill branch 2 times, most recently from 7a0a639 to 21713a2 Compare March 17, 2026 16:07
@ZhanqiuHu ZhanqiuHu marked this pull request as ready for review March 17, 2026 16:19
@ZhanqiuHu ZhanqiuHu force-pushed the fix/mamba-pd-n1-prefill branch from 7ea7321 to 9d9d611 Compare March 17, 2026 18:00
@ZhanqiuHu
Contributor Author

/gemini review

Collaborator

@NickLucche NickLucche left a comment

We need to add some unit tests for this to make sure the workflow is clear.

```python
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))

def _hma_prefill_token_count(self, num_prompt_tokens: int) -> int:
```
Collaborator

naming of this function should be updated too

```python
request._all_token_ids.pop()
request.num_prompt_tokens -= 1
request.max_tokens = 1
params["_p_side_truncated"] = True
```
Collaborator

should we have this param?
D is not supposed to get here because of the do_remote_decode guard.

Contributor Author

gemini-code-assist made the suggestion here: ZhanqiuHu#4 (comment).


I think this makes sure we don't subtract 1 multiple times if the request gets rescheduled. Although it doesn't necessarily need to be inside the params.

Collaborator

@ZhanqiuHu I don't think it can happen as described here, but this is very much valid for preemptions. Let's add a comment

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a fix for Mamba P/D disaggregation by implementing N-1 prefill to prevent state corruption. The changes are confined to nixl_connector.py and seem correct in principle. However, I've identified a scenario where the fix may not be applied correctly when prompts are provided as embeddings, potentially leading to the same state corruption issue. I have provided a suggestion to make the implementation more robust by handling this case.

Comment on lines +738 to +748
```python
if (
    params is not None
    and not params.get("_p_side_truncated")
    and request.prompt_token_ids
    and len(request.prompt_token_ids) > 1
):
    request.prompt_token_ids.pop()
    request._all_token_ids.pop()
    request.num_prompt_tokens -= 1
    request.max_tokens = 1
    params["_p_side_truncated"] = True
```
Contributor

high

This truncation logic only considers prompt_token_ids and will not apply the N-1 prefill fix for Mamba models when prompt_embeds are used. This could lead to state corruption and incorrect outputs in that scenario. The logic should be updated to handle prompt_embeds as well.

```python
if (
    params is not None
    and not params.get("_p_side_truncated")
    and request.num_prompt_tokens > 1
):
    if request.prompt_token_ids is not None:
        request.prompt_token_ids.pop()
    elif request.prompt_embeds is not None:
        request.prompt_embeds = request.prompt_embeds[:-1]
    else:
        # This case should not be possible if num_prompt_tokens > 1.
        return

    request._all_token_ids.pop()
    request.num_prompt_tokens -= 1
    request.max_tokens = 1
    params["_p_side_truncated"] = True
```

Collaborator

@NickLucche NickLucche left a comment

LGTM nice job

@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 18, 2026
@NickLucche NickLucche enabled auto-merge (squash) March 18, 2026 19:18
For HMA (Mamba/SSM) models in P/D disaggregation, the prefiller must
transfer h(N-1) instead of h(N) so the decoder can correctly recompute
the last prompt token.

D-side: _hma_prefill_token_count() helper returns N-1 for HMA models,
used in get_num_new_matched_tokens so the decoder naturally recomputes
the last token from h(N-1).

P-side: _truncate_hma_request_for_prefill() truncates prompt to N-1
tokens and sets max_tokens=1. The model computes h(N-1), samples one
spurious token (which does NOT update Mamba state), then check_stop
fires FINISHED_LENGTH_CAPPED triggering the KV transfer.

The P-side truncation is guarded by params["_p_side_truncated"] for
idempotency across preemption / re-scheduling cycles.

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
- Extract P-side truncation into _truncate_hma_request_for_prefill()
- Add _hma_prefill_token_count() helper for D-side N-1 calculation
- Explicit do_remote_decode / do_remote_prefill guards at call site
- Tighten docstrings with D-side/P-side context

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
_is_hma_required is True for any non-FullAttention model (including
SWA), but the N-1 prefill fix only applies to models with cumulative
Mamba state. SWA KV is stateless and doesn't need N-1 treatment.

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
auto-merge was automatically disabled March 18, 2026 19:22

Head branch was pushed to by a user without write access

@ZhanqiuHu ZhanqiuHu force-pushed the fix/mamba-pd-n1-prefill branch from 9ffbd09 to 17f996f Compare March 18, 2026 19:22
@NickLucche NickLucche merged commit d49f273 into vllm-project:main Mar 19, 2026
58 checks passed
chooper26 pushed a commit to intellistream/vllm-hust that referenced this pull request Mar 21, 2026
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1
