
[Bugfix] Fix stale SSM state for new Mamba requests scheduled as decode #32118

Merged
heheda12345 merged 3 commits into vllm-project:main from Josephasafg:fix_ssm_stale_states
Jan 12, 2026

Conversation

@Josephasafg (Contributor) commented Jan 11, 2026

Purpose

Fix stale SSM state corruption when new Mamba requests are scheduled with only 1 token due to token budget exhaustion.

Problem

When the scheduler's token budget is nearly exhausted, new requests may be allocated only 1 token. A request with 1 token could be classified as decode rather than prefill, so the prompt is first decoded against a stale SSM state, i.e. a cache slot that still holds values from a previous request (I believe this happens once most of the GPU blocks are in use and a slot gets reused). The prompt therefore goes through decode, then prefill, then decode again.

Mamba/SSMs use recurrent state that is read AND written each step:

  • Prefill: Initialize state to zeros, process all tokens sequentially, write final state to cache
  • Decode: Read state from cache, process 1 token, write updated state back
  • Key point: Decode assumes valid state already exists in the cache slot

When a new Mamba request runs as decode (1 token), it reads from a cache slot that may contain garbage/stale state from a previously completed request. This corrupted state propagates to all subsequent tokens.
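As a toy illustration of why running decode on an uninitialized slot corrupts output (hypothetical scalar recurrence and cache layout, not vLLM's actual SSM kernel):

```python
def prefill(tokens, cache, slot):
    """New request: always start from a zero state, scan the whole prompt."""
    state = 0.0
    for t in tokens:
        state = 0.9 * state + t  # toy recurrence standing in for the SSM scan
    cache[slot] = state          # write the final state to the cache slot
    return state

def decode(token, cache, slot):
    """Continuing request: read the cached state and advance one step."""
    state = 0.9 * cache[slot] + token  # trusts that cache[slot] holds valid state
    cache[slot] = state
    return state

cache = [0.0, 0.0, 100.0, 0.0]  # slot 2 holds stale state from a finished request

# Correct path: the new request is classified as prefill and zeroes its state.
good = prefill([1.0, 1.0], list(cache), slot=2)   # 0.9*(0.9*0 + 1) + 1 = 1.9

# Buggy path: the same request misclassified as decode reads the stale 100.0.
bad = decode(1.0, list(cache), slot=2)            # 0.9*100 + 1 = 91.0

assert abs(good - 1.9) < 1e-9
assert abs(bad - 91.0) < 1e-9
```

Every later token compounds on the corrupted state, which is why the output degenerates into gibberish rather than recovering.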

Test Plan

I ran the following to reproduce it. This is how I configured the vLLM instance:

    llm = vllm.LLM(
        model="ai21labs/AI21-Jamba-Reasoning-3B",
        trust_remote_code=True,
        mamba_ssm_cache_dtype="float32",
        gpu_memory_utilization=0.2,  # I made it very small to make it reproduce faster
        max_model_len=8192,
    )

I then ran 1024 prompts in 8 batches, each batch containing 128 prompts of varying lengths up to 4K tokens.

Test Result

Without the fix, I consistently got gibberish for the same prompt id:

  , a82 starts for for em IN EIMT in5, an9 aals- multiplic RISM PRO whenlickatori- rub22 RION None 1thATH CASE in2 or AND T6V9 startingenedEN,ian [-th, a   aBPO enumerate5 and out 11EM A A
, following�,AN   1Compact to, Compact a (oner,Quant a) < a, subsequent anduay- subtiles ix ‘ ( source in > ISIAN PRO whenquare ( as A A C EALLY AELF   they1 following &IN EEMINEottewhich  1ISA,mar公, ISTH,marcam permanentexpress A a &EMIN use (ix or the a onceenberg  ISISTRos source a a in in in NextIM, IS IS       1 IS back atala of in inala:ian for, RepHC and all  1 replacement ix —Comp which5 ( ( this ( an IN EIM the in A E00 all as a Next,REM and a $EMPLomat, IS1 a局 and asuner, IS4 LEINO H{{ in in in in $1 in in in in in in in in in in in in in in in in in in in in in in in in in in in in in in in in in in in in in"

With the fix, generation is proper:

We need to parse the data. The data is a string: "<SOME LONG STRING>.

We need to format the data according to instructions: "str values should be formatted as follows: base64". But the data is already base64. The question: "change the formats of the data according to the following instructions if required (some formats may have to stay as in the input)". So we need to output the answer in the following structure: first some free text reasoning, then <ANSWER> tag, then the answer, then </ANSWER>.

We need to output only the answer inside the answer tag. The answer should be the formatted data? Or the reasoning? The instructions: "Your output should be in the following structure. First output some free text reasoning on how to reach the answer. Then ou...


Note

Prevents new requests from being misclassified as decodes when only a single token is scheduled.

  • Updates reorder_batch_to_split_decodes_and_prefills to derive is_prefill solely from num_computed_tokens == 0; is_decode/is_extend now apply only to non-prefill items while preserving decode → extend → prefill ordering
  • Adds unit tests covering single-token new requests (single and multiple) to validate correct reordering

Written by Cursor Bugbot for commit 9955fa2. This will update automatically on new commits. Configure here.
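The rule described in the note can be sketched as a standalone function (hypothetical name and simplified per-request signature; the real code in `reorder_batch_to_split_decodes_and_prefills` operates over the whole batch):

```python
def classify_request(num_computed_tokens, num_scheduled_tokens, decode_threshold=1):
    """Sketch of the fixed classification: a request with no computed tokens
    is always a prefill, no matter how few tokens the scheduler granted it."""
    if num_computed_tokens == 0:
        return "prefill"                      # new request: must initialize state
    if num_scheduled_tokens <= decode_threshold:
        return "decode"                       # continuing request, small step
    return "extend"                           # continuing request, chunked catch-up

# New request squeezed to 1 token by the budget: was "decode" before the fix.
assert classify_request(num_computed_tokens=0, num_scheduled_tokens=1) == "prefill"
# Running request generating its next token: still a decode.
assert classify_request(num_computed_tokens=100, num_scheduled_tokens=1) == "decode"
# Running request continuing a chunked prefill: an extend.
assert classify_request(num_computed_tokens=100, num_scheduled_tokens=8) == "extend"
```

The decode → extend → prefill batch ordering is unchanged; only the membership of the prefill bucket moves.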

Signed-off-by: Josephasafg <ajgard7@gmail.com>
@mergify mergify bot added the v1 label Jan 11, 2026

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a critical bug where new Mamba requests with a small number of tokens were incorrectly classified as decode steps instead of prefills, leading to state corruption and incorrect model outputs. The proposed change correctly re-prioritizes the classification logic by first identifying prefill requests based on whether they have any previously computed tokens. This ensures that new requests are always handled as prefills, which is essential for correctly initializing the recurrent state in models like Mamba. The fix is well-targeted, effective, and makes the logic more robust and easier to understand. I approve of this change.

@Josephasafg
Contributor Author

@heheda12345 @tdoublep I'd appreciate your review as well. Thanks

@heheda12345
Collaborator

Can you add a unit test like those in test_reorder_batch_to_split_decodes_and_prefills?

Signed-off-by: Josephasafg <ajgard7@gmail.com>
@Josephasafg
Contributor Author

Can you add a unit test like those in test_reorder_batch_to_split_decodes_and_prefills?

@heheda12345 Sure! Added two tests.

@heheda12345 heheda12345 enabled auto-merge (squash) January 12, 2026 07:40
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 12, 2026
@heheda12345 heheda12345 merged commit 8fb2c13 into vllm-project:main Jan 12, 2026
53 checks passed
TomerBN-Nvidia pushed a commit to TomerBN-Nvidia/vllm that referenced this pull request Jan 13, 2026
…de (vllm-project#32118)

Signed-off-by: Josephasafg <ajgard7@gmail.com>
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 15, 2026
### What this PR does / why we need it?
Upgrade vllm commit to 0113 (11b6af5280d6d6dfb8953af16e67b25f819b3be9)

- Modify import paths due to the refactors
vllm-project/vllm#31916
vllm-project/vllm#32054

- Fix `TypeError: NPUOffloadingSpec.__init__() takes 2 positional
arguments but 3 were given` due to
vllm-project/vllm#24498

- Skip the async-scheduling tests in
`tests/e2e/multicard/4-cards/long_sequence/test_mtp.py`, which are never
verified
vllm-project/vllm#31998

- Skip some pooling tests, which are caused by
vllm-project/vllm#32148
where vLLM also fails
https://buildkite.com/vllm/ci/builds/46705/steps/canvas?jid=019bb329-3834-4685-862b-1613b8e0f5d4

We will reopen those tests when main2main reaches
vllm-project/vllm#32243

- Skip some cases in
`tests/e2e/multicard/4-cards/long_sequence/test_mtp.py`, which are
broken by
vllm-project/vllm#32118

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@2f4e654

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
@pisceskkk
Contributor

pisceskkk commented Jan 15, 2026

IIUC, when all requests belong exclusively to either is_decode or is_prefill in this PR's code (none belong to is_extend), the num_prefills calculated in the following code segment may be inconsistent with the code modified in this PR. Could this lead to potential issues? @Josephasafg

if require_uniform:
    # check if we are in a padded uniform batch; this is used for full-CGs, some
    # requests may have a query length of 0 but since they are padding its fine
    # to treat them as decodes (ensures num_decodes matches the captured size)
    if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
        assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly"
        return num_reqs, 0, num_tokens, 0  # all decodes
    is_prefill = query_lens != query_lens[0]
else:
    is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill):
    return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes
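For readers following along, the non-uniform path of the split above can be exercised with a plain-Python stand-in (hypothetical simplification; the real vLLM code operates on a torch tensor of query lengths and also returns token counts):

```python
def split_decodes_and_prefills(query_lens, decode_threshold):
    """Toy version of the non-uniform split: decodes come first in the
    reordered batch, and the first over-threshold query marks the prefills."""
    is_prefill = [q > decode_threshold for q in query_lens]
    if not any(is_prefill):
        return len(query_lens), 0, sum(query_lens), 0      # all decodes
    first_prefill = is_prefill.index(True)
    assert all(q <= decode_threshold for q in query_lens[:first_prefill])
    num_decodes = first_prefill
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = sum(query_lens[:first_prefill])
    num_prefill_tokens = sum(query_lens[first_prefill:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Batch already reordered decode-first: two 1-token decodes, one 5-token prefill.
assert split_decodes_and_prefills([1, 1, 5], decode_threshold=1) == (2, 1, 2, 5)
```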

@Josephasafg
Contributor Author

IIUC, when all requests belong exclusively to either is_decode or is_prefill in the code of this PR, it may cause the num_prefills calculated in the following code segment to be inconsistent with the code modified in this PR. Could this lead to potential issues? […]
When I tested the case where, say, is_prefill was true and is_decode and is_extend were false, execution doesn't even reach the if require_uniform condition; it returns earlier, on the conditions above it.

Same when is_decode is the only one with true values.

When I tested the case where both is_prefill and is_decode had truthy values, num_prefills had correct values.

Do you have a script to reproduce this issue? I'd be happy to help.

@pisceskkk
Contributor

pisceskkk commented Jan 20, 2026

When I tested where both is_prefill and is_decode had truth-y values, it seemed num_prefills had correct values.

Do you have a script to reproduce such issue? I'd be happy to help

To give an example: suppose we are in the MTP3 scenario (i.e., decode_threshold=4) and require_uniform is False. There is one decode request and one newly arrived prefill request with a prompt of 3 tokens (admittedly, this situation seems quite rare). In this PR's code, one will belong to is_decode and the other to is_prefill. However, the code in vllm/v1/attention/backends/utils.py will categorize both of these requests as decode_reqs.

After raising the above comment, I discussed the P/D request discrimination logic in the MTP scenario with colleagues. In our context, classifying the 3-token request as decode wouldn't cause any issues. However, based on the intent of this PR, I'm not certain whether it would have an impact in this PR's target scenarios, so I'm just pointing it out as a note.
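The divergence described above can be made concrete with toy versions of the two rules (hypothetical helper names; decode_threshold=4 as in the MTP3 example, require_uniform False):

```python
DECODE_THRESHOLD = 4  # MTP3: up to 4 tokens can be scheduled per decode step

def reorder_rule(num_computed_tokens, num_scheduled_tokens):
    """This PR's rule in the batch reordering."""
    if num_computed_tokens == 0:
        return "prefill"
    return "decode" if num_scheduled_tokens <= DECODE_THRESHOLD else "extend"

def split_rule(query_len):
    """Toy version of the separate split in utils.py (non-uniform path)."""
    return "prefill" if query_len > DECODE_THRESHOLD else "decode"

# A newly arrived request with a 3-token prompt: the two rules disagree.
assert reorder_rule(num_computed_tokens=0, num_scheduled_tokens=3) == "prefill"
assert split_rule(query_len=3) == "decode"
```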

@Josephasafg
Contributor Author

Josephasafg commented Jan 20, 2026

To give an example: suppose we are in the MTP3 scenario (i.e., decode_threshold=4) and require_uniform is False. […] So I'm just pointing it out as a note.

@pisceskkk I opened a draft PR with a potential fix for this edge case. We can discuss it there and see if this solution can work. #32716

wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 31, 2026
### What this PR does / why we need it?
Since the PR (vllm-project/vllm#32118) has
modified the criteria for judging Prefill and Decode requests in vLLM,
PCPManager needs to synchronize with this standard. As PCPManager
involves multiple calculations of PD request counts, this PR attempts to
consolidate the related logic and update the PD request count once per
batch.

### How was this patch tested?
```bash
pytest tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
```

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@11b6af5

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>