[SSM/Mamba] Follow-up: N-1 prefill for P/D disaggregation #37310

NickLucche merged 8 commits into vllm-project:main
Conversation
Code Review
This pull request introduces logic to handle N-1 prefill for Mamba-style models in a distributed setting. A new helper function _hma_prefill_token_count is added to correctly calculate the number of tokens for prefill. The main change is in get_num_new_matched_tokens, which now truncates the prompt for the prefill worker to ensure the correct state is computed and transferred.
My main feedback is regarding a design concern in get_num_new_matched_tokens. The function now has significant side effects, modifying the request object in-place. This violates the command-query separation principle and can make the code harder to maintain. I've added a detailed comment with a suggestion for refactoring.
```python
# HMA P-side (do_remote_decode): truncate prompt to N-1 tokens so
# the model computes h(N-1) and stops via max_tokens after one
# decode sample (which does NOT update Mamba state).
if (
    self._is_hma_required
    and params is not None
    and params.get("do_remote_decode")
    and not params.get("_p_side_truncated")
):
    if request.prompt_token_ids and len(request.prompt_token_ids) > 1:
        request.prompt_token_ids.pop()
        request._all_token_ids.pop()
        request.num_prompt_tokens -= 1
        request.max_tokens = 1
        params["_p_side_truncated"] = True
```
This block of code modifies the request object, which is a significant side effect within a function named get_num_new_matched_tokens. This violates the principle of command-query separation, where a function should either perform an action (command) or return data (query), but not both. Mixing these concerns can make the code harder to understand, debug, and maintain.
A function with "get" in its name is expected to be read-only. To improve clarity and align with best practices, I recommend refactoring this logic. One approach is to move this request modification logic to a new, private method with a name that clearly indicates its side effect, such as _truncate_hma_request_for_prefill. get_num_new_matched_tokens could then call this new method.
While I cannot suggest adding a new method directly within this review, I strongly advise considering this refactoring for better code quality.
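A minimal sketch of the suggested refactor, using simplified stand-ins for the Request object and the params dict (only the method name _truncate_hma_request_for_prefill comes from this review thread; the field names mirror the snippet above):

```python
# Hypothetical sketch: the request-mutating logic moves into a clearly named
# private helper (a "command"), so get_num_new_matched_tokens can stay a pure
# "query". Request/params here are simplified stand-ins, not vLLM's real types.
from dataclasses import dataclass, field


@dataclass
class Request:
    prompt_token_ids: list
    _all_token_ids: list = field(default_factory=list)
    num_prompt_tokens: int = 0
    max_tokens: int = 16


def _truncate_hma_request_for_prefill(request: Request, params: dict) -> None:
    """Command: truncate the prompt to N-1 tokens (the side effect is in the name)."""
    if params.get("_p_side_truncated"):
        return  # idempotency guard for preemption / re-scheduling
    if request.prompt_token_ids and len(request.prompt_token_ids) > 1:
        request.prompt_token_ids.pop()
        request._all_token_ids.pop()
        request.num_prompt_tokens -= 1
        request.max_tokens = 1
        params["_p_side_truncated"] = True


req = Request(
    prompt_token_ids=list(range(10)),
    _all_token_ids=list(range(10)),
    num_prompt_tokens=10,
)
params = {"do_remote_decode": True}
_truncate_hma_request_for_prefill(req, params)
_truncate_hma_request_for_prefill(req, params)  # second call is a no-op
print(len(req.prompt_token_ids), req.max_tokens)  # 9 1
```

With this split, the caller decides *when* to truncate, and the query method no longer hides a mutation.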
It's great to see you've addressed the issue by moving the request modification logic into a separate method. This aligns well with the principle of command-query separation and improves the code's maintainability. Thanks for taking the feedback into account!
since when can gemini reply back lol
/gemini review
NickLucche left a comment
We need to add some unit tests for this to make sure the workflow is clear.
```python
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))


def _hma_prefill_token_count(self, num_prompt_tokens: int) -> int:
```
The naming of this function should be updated too.
```python
request._all_token_ids.pop()
request.num_prompt_tokens -= 1
request.max_tokens = 1
params["_p_side_truncated"] = True
```
should we have this param?
D is not supposed to get here because of do_remote_decode guard.
gemini-code-assist made the suggestion here: ZhanqiuHu#4 (comment).
I think this makes sure we don't subtract 1 multiple times if the request gets rescheduled, although it doesn't necessarily need to live inside the params.
@ZhanqiuHu I don't think it can happen as described here, but this is very much valid for preemptions. Let's add a comment
Code Review
This pull request introduces a fix for Mamba P/D disaggregation by implementing N-1 prefill to prevent state corruption. The changes are confined to nixl_connector.py and seem correct in principle. However, I've identified a scenario where the fix may not be applied correctly when prompts are provided as embeddings, potentially leading to the same state corruption issue. I have provided a suggestion to make the implementation more robust by handling this case.
```python
if (
    params is not None
    and not params.get("_p_side_truncated")
    and request.prompt_token_ids
    and len(request.prompt_token_ids) > 1
):
    request.prompt_token_ids.pop()
    request._all_token_ids.pop()
    request.num_prompt_tokens -= 1
    request.max_tokens = 1
    params["_p_side_truncated"] = True
```
This truncation logic only considers prompt_token_ids and will not apply the N-1 prefill fix for Mamba models when prompt_embeds are used. This could lead to state corruption and incorrect outputs in that scenario. The logic should be updated to handle prompt_embeds as well.
```python
if (
    params is not None
    and not params.get("_p_side_truncated")
    and request.num_prompt_tokens > 1
):
    if request.prompt_token_ids is not None:
        request.prompt_token_ids.pop()
    elif request.prompt_embeds is not None:
        request.prompt_embeds = request.prompt_embeds[:-1]
    else:
        # This case should not be possible if num_prompt_tokens > 1.
        return
    request._all_token_ids.pop()
    request.num_prompt_tokens -= 1
    request.max_tokens = 1
    params["_p_side_truncated"] = True
```
For HMA (Mamba/SSM) models in P/D disaggregation, the prefiller must transfer h(N-1) instead of h(N) so the decoder can correctly recompute the last prompt token.

D-side: the _hma_prefill_token_count() helper returns N-1 for HMA models and is used in get_num_new_matched_tokens, so the decoder naturally recomputes the last token from h(N-1).

P-side: _truncate_hma_request_for_prefill() truncates the prompt to N-1 tokens and sets max_tokens=1. The model computes h(N-1), samples one spurious token (which does NOT update Mamba state), then check_stop fires FINISHED_LENGTH_CAPPED, triggering the KV transfer.

The P-side truncation is guarded by params["_p_side_truncated"] for idempotency across preemption / re-scheduling cycles.

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
- Extract P-side truncation into _truncate_hma_request_for_prefill()
- Add _hma_prefill_token_count() helper for D-side N-1 calculation
- Explicit do_remote_decode / do_remote_prefill guards at call site
- Tighten docstrings with D-side/P-side context

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
_is_hma_required is True for any non-FullAttention model (including SWA), but the N-1 prefill fix only applies to models with cumulative Mamba state. SWA KV is stateless and doesn't need N-1 treatment. Signed-off-by: ZhanqiuHu <zhu@redhat.com>
…ct#37310) Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
…ct#37310) Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Follow-up of #36687
Connector-only N-1 prefill fix for Mamba P/D disaggregation. Prevents state corruption where D-side recomputes last prompt token from already-complete state. Single file change: nixl_connector.py (+16 lines), no scheduler changes.
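The N-1 mechanics described above can be sketched as follows. This is a simplified stand-in, not the actual nixl_connector.py code; the _is_hma_required flag and the _hma_prefill_token_count name mirror the PR, while the class context is invented for illustration:

```python
# Hypothetical sketch of the D-side helper: for HMA (Mamba/SSM) models the
# decoder treats only N-1 prompt tokens as already matched, so the last prompt
# token is recomputed locally from the transferred h(N-1) state. Full-attention
# models keep the full N. The surrounding class is a simplified stand-in.
class ConnectorSketch:
    def __init__(self, is_hma_required: bool):
        # True for models carrying cumulative Mamba/SSM state.
        self._is_hma_required = is_hma_required

    def _hma_prefill_token_count(self, num_prompt_tokens: int) -> int:
        if self._is_hma_required and num_prompt_tokens > 1:
            return num_prompt_tokens - 1  # transfer h(N-1), recompute last token
        return num_prompt_tokens  # FA-only / SWA-only: stateless KV, full N


print(ConnectorSketch(True)._hma_prefill_token_count(10))   # 9
print(ConnectorSketch(False)._hma_prefill_token_count(10))  # 10
```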
Eval

Quick sanity (2p2d, prompt: "The capital of France is"); lm_eval gsm8k 5-shot:
Test Summary

Added test cases:
- test_mamba_n1_d_side (3 parametrized cases): D-side returns N-1 = 9 matched tokens for Mamba, N = 10 for FA-only and SWA-only
- test_mamba_n1_p_side_truncation (1 case, multi-assert): P-side truncates prompt 10→9, sets max_tokens=1, verifies idempotency and non-Mamba skip
- test_has_mamba_init (3 parametrized cases): _has_mamba / _is_hma_required flags derived correctly from FA+SWA+Mamba, FA+SWA, FA-only configs
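The D-side case above could take roughly this shape (a hypothetical sketch written as plain Python rather than the actual pytest-parametrized tests in the vLLM suite; the helper is a stand-in for the connector logic under test):

```python
# Stand-in for the connector's N-1 calculation (not the real vLLM code).
def hma_prefill_token_count(num_prompt_tokens: int, is_hma: bool) -> int:
    # Mamba/SSM (HMA): report N-1 matched tokens; FA-only / SWA-only: report N.
    if is_hma and num_prompt_tokens > 1:
        return num_prompt_tokens - 1
    return num_prompt_tokens


def test_mamba_n1_d_side():
    # Mirrors the three parametrized cases for a 10-token prompt:
    # Mamba -> 9, FA-only -> 10, SWA-only -> 10.
    for is_hma, expected in [(True, 9), (False, 10), (False, 10)]:
        assert hma_prefill_token_count(10, is_hma) == expected


test_mamba_n1_d_side()
print("ok")  # all cases pass
```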