[SSM] Follow-up fix for Mamba P/D KV transfer #4
Conversation
Code Review
This pull request fixes Mamba/SSM state corruption during P/D KV transfer by ensuring the P-side only prefills N-1 tokens. While the logic is sound, there is a critical issue with how the request object is modified. `get_num_new_matched_tokens` is not guaranteed to be called only once per request, but the current implementation truncates the prompt again on every call, which is a bug. My review includes a suggestion to make this modification idempotent.
```python
if (self._is_hma_required
        and params is not None
        and params.get("do_remote_decode")):
    if request.prompt_token_ids and len(request.prompt_token_ids) > 1:
        request.prompt_token_ids.pop()
        request._all_token_ids.pop()
        request.num_prompt_tokens -= 1
        request.max_tokens = 1
```
The modification of the request object here introduces a side effect in get_num_new_matched_tokens. This function can be called multiple times for the same request if it remains in the scheduler's waiting queue across several scheduling cycles. The current implementation is not idempotent and will repeatedly pop tokens from prompt_token_ids on each call, leading to an incorrect prompt length.
To fix this, the modification should only happen once. I suggest adding a flag to the kv_transfer_params to track whether the truncation has already been performed.
Suggested change:

```diff
 if (self._is_hma_required
         and params is not None
-        and params.get("do_remote_decode")):
+        and params.get("do_remote_decode")
+        and not params.get("_p_side_truncated")):
     if request.prompt_token_ids and len(request.prompt_token_ids) > 1:
         request.prompt_token_ids.pop()
         request._all_token_ids.pop()
         request.num_prompt_tokens -= 1
         request.max_tokens = 1
+        params["_p_side_truncated"] = True
```
For HMA (Mamba/SSM) models in P/D disaggregation, the prefiller must transfer h(N-1) instead of h(N) so the decoder can correctly recompute the last prompt token.

- D-side: the `_hma_prefill_token_count()` helper returns N-1 for HMA models and is used in `get_num_new_matched_tokens`, so the decoder naturally recomputes the last token from h(N-1).
- P-side: `_truncate_hma_request_for_prefill()` truncates the prompt to N-1 tokens and sets `max_tokens=1`. The model computes h(N-1), samples one spurious token (which does NOT update the Mamba state), then `check_stop` fires `FINISHED_LENGTH_CAPPED`, triggering the KV transfer.
- The P-side truncation is guarded by `params["_p_side_truncated"]` for idempotency across preemption / re-scheduling cycles.

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
- Extract P-side truncation into `_truncate_hma_request_for_prefill()`
- Add `_hma_prefill_token_count()` helper for the D-side N-1 calculation
- Add explicit `do_remote_decode` / `do_remote_prefill` guards at the call site
- Tighten docstrings with D-side/P-side context

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
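The D-side helper described above can be sketched as follows (the name and the N-1 behavior come from the PR description; the signature and the surrounding connector plumbing are assumptions):

```python
def hma_prefill_token_count(num_prompt_tokens: int, is_hma: bool) -> int:
    # D-side: for HMA (Mamba/SSM) models, report N-1 matched tokens so the
    # decoder recomputes the last prompt token from the transferred h(N-1).
    # Non-HMA models can consume all N prompt tokens from the prefilled KV.
    if is_hma and num_prompt_tokens > 1:
        return num_prompt_tokens - 1
    return num_prompt_tokens
```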
`_is_hma_required` is True for any non-FullAttention model (including SWA), but the N-1 prefill fix only applies to models with cumulative Mamba state. SWA KV is stateless across the transfer and does not need the N-1 treatment. Signed-off-by: ZhanqiuHu <zhu@redhat.com>
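A hedged sketch of that distinction: the guard should key off layers carrying cumulative Mamba/SSM state, not merely off `_is_hma_required`. The function name and the string layer tags here are purely illustrative, not vLLM API.

```python
def needs_minus_one_prefill(attn_layer_types: list[str]) -> bool:
    # Hypothetical illustration: only models with cumulative Mamba/SSM state
    # need the N-1 prefill. Sliding-window attention (SWA) KV is per-token
    # and stateless across the transfer, so it is exempt.
    return any(layer_type == "mamba" for layer_type in attn_layer_types)

needs_minus_one_prefill(["full", "mamba"])  # True: Mamba state present
needs_minus_one_prefill(["full", "swa"])    # False: SWA alone is exempt
```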
Connector-only N-1 prefill fix for Mamba P/D disaggregation. Prevents state corruption where the D-side recomputes the last prompt token on top of an already-complete state. Single file change: nixl_connector.py (+16 lines), no scheduler changes.
Quick sanity check (2p2d, prompt: "The capital of France is"):
Without fix:
With fix: