
[main][bugfix] Fixed the problem that eagle3 will crash in FULL_DECODE_ONLY#7290

Merged

MengqingCao merged 1 commit into vllm-project:main from drslark:qwen3-eagle3 on Mar 16, 2026

Conversation

@drslark
Contributor

@drslark drslark commented Mar 16, 2026

What this PR does / why we need it?

This PR fixes two problems.

Both occur in FULL_DECODE_ONLY mode, where num_tokens must be padded to one of the values in cudagraph_capture_sizes.

  1. The length of seq_lens_list in the drafter's attn_metadata is 1 shorter than expected, which raises a kernel exception and crashes vLLM.
    e.g., num_reqs = 3, cudagraph_capture_sizes = [20], actual_seq_lengths_q is correctly padded to [4, 8, 12, 20], but seq_lens_list = [5742, 4700, 7996] is not padded.

  2. Although the length of seq_lens_list in the target's attn_metadata is as expected in FULL_DECODE_ONLY, the data at the end of the list is corrupted.
    e.g., num_reqs = 3, cudagraph_capture_sizes = [20], actual_seq_lengths_q is correctly padded to [4, 8, 12, 20], but seq_lens_list = [5742, 4700, 7996, 5738] carries a corrupted value in its last slot.
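The padding invariant that both fixes enforce can be sketched as follows (a minimal illustration; the function and its name are hypothetical, not the actual vllm-ascend code):

```python
def pad_seq_lens(seq_lens, padded_num_reqs):
    """Keep the real entries and zero the padded tail, so a replayed
    graph never reads stale values past num_reqs."""
    return list(seq_lens) + [0] * (padded_num_reqs - len(seq_lens))

# The buggy drafter metadata above carried only the 3 real entries;
# the expected, padded form has a zeroed fourth slot:
print(pad_seq_lens([5742, 4700, 7996], 4))  # [5742, 4700, 7996, 0]
```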

Does this PR introduce any user-facing change?

N/A

How was this patch tested?

Code to reproduce:

from vllm import LLM, SamplingParams

if __name__ == '__main__':
    prompts = [
        "2.Who are you?" * 1100,
        "Who are you?" * 1100,
        "2.Who are you?1" * 1100,
        "Who are you?2" * 1100,
        "2.Who are you?3" * 1100,
        "Who are you?4" * 1100,
        "2.Who are you?5" * 1100,
        "Who are you?6" * 1100,
    ]

    sampling_params = SamplingParams(temperature=0.0, top_p=0.95, top_k=40, max_tokens=300)
    llm = LLM(
        model="/home/some-model/Qwen3-30B-A3B",
        max_num_seqs=16,
        max_num_batched_tokens=10240,
        tensor_parallel_size=4,
        distributed_executor_backend="mp",
        gpu_memory_utilization=0.9,
        async_scheduling=True,
        disable_log_stats=False,
        speculative_config={
            "model": "/home/some-model/Qwen3-a3B_eagle3",
            "disable_padded_drafter_batch": False,
            "method": "eagle3",
            "num_speculative_tokens": 3,
        },
        compilation_config={
            "cudagraph_mode": "FULL_DECODE_ONLY",
            "cudagraph_num_of_warmups": 1,
        },
        max_model_len=10240,
        enable_prefix_caching=False,
    )

    outputs = llm.generate(prompts, sampling_params)

The program crashes with the exception shown below:

(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880] WorkerProc hit an exception.
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880] Traceback (most recent call last):
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]   File "/vllm-workspace/vllm/vllm/v1/executor/multiproc_executor.py", line 875, in worker_busy_loop
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]     output = func(*args, **kwargs)
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]              ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]   File "/vllm-workspace/vllm/vllm/v1/worker/worker_base.py", line 365, in execute_model
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]     return self.worker.execute_model(scheduler_output)
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]   File "/vllm-workspace/vllm-ascend/vllm_ascend/worker/worker.py", line 406, in execute_model
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]     output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]   File "/usr/local/python3.11.10/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]     return func(*args, **kwargs)
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]   File "/vllm-workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py", line 1368, in execute_model
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]     hidden_states = self._model_forward(
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]                     ^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]   File "/vllm-workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py", line 1794, in _model_forward
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]     hidden_states = self.model(
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]                     ^^^^^^^^^^^
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]   File "/vllm-workspace/vllm-ascend/vllm_ascend/compilation/acl_graph.py", line 201, in __call__
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]     torch.npu.current_stream().synchronize()
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]   File "/usr/local/python3.11.10/lib/python3.11/site-packages/torch_npu/npu/streams.py", line 85, in synchronize
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]     super(Stream, self).synchronize()
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880] RuntimeError: synchronize:build/CMakeFiles/torch_npu.dir/compiler_depend.ts:361 NPU function error: c10_npu::acl::AclrtSynchronizeStreamWithTimeout(stream()), error code is 507015
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880] [ERROR] 2026-03-16-11:36:26 (PID:94519, Device:2, RankID:-1) ERR00100 PTA call acl api failed
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880] [Error]: The aicore execution is abnormal. 
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880]         Rectify the fault based on the error information in the ascend log.
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880] EE9999: Inner Error!
(EngineCore_DP0 pid=94497) (Worker pid=94519) (Worker_TP2 pid=94519) ERROR 03-16 11:36:26 [multiproc_executor.py:880] EE9999[PID: 94519] 2026-03-16-11:36:26.650.466 (EE9999):  rtStreamSynchronizeWithTimeout execution failed, reason=aicore exception[FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:61]

After the changes in this PR, the program runs to completion and reports:

--------------------------------------------------
total_num_output_tokens: 2400
num_drafts: 675
num_draft_tokens: 2025
num_accepted_tokens: 1722
mean acceptance length: 3.55
--------------------------------------------------
acceptance at token 0: 0.85
acceptance at token 1: 0.85
acceptance at token 2: 0.85
acceptance at token 3: 0.00
acceptance at token 4: 0.00
acceptance at token 5: 0.00

Unfortunately, the CI test data is too short to reproduce the crash.

Even with the CI data lengthened 1000x, the crash reproduces only with very low probability.

@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug causing the eagle3 component to crash during FULL_DECODE_ONLY operations. The fix involves correctly passing sequence length data to the attention metadata and reordering GPU memory initialization for sequence lengths, ensuring data integrity and preventing runtime errors in specific decoding scenarios.

Highlights

  • Bug Fix for Eagle3 Crash: Resolved an issue where the eagle3 component would crash when operating in FULL_DECODE_ONLY mode, improving stability for this specific decoding configuration.
  • Sequence Lengths Propagation: Ensured that sequence length information (seq_lens and seq_lens_cpu) is correctly propagated to the common_attn_metadata object, which is crucial for proper attention mechanism operation.
  • GPU Memory Initialization Order: Adjusted the order of operations for initializing seq_lens on the GPU, moving the fill_(0) call before the copy_to_gpu() operation to prevent potential data inconsistencies.


Changelog
  • vllm_ascend/spec_decode/eagle_proposer.py
    • Added seq_lens and seq_lens_cpu from self.runner to common_attn_metadata.
  • vllm_ascend/worker/model_runner_v1.py
    • Moved the line self.seq_lens.gpu[num_reqs:].fill_(0) to occur before self.seq_lens.copy_to_gpu().
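The first changelog entry can be sketched roughly as follows (class and field names follow the changelog above, but the surrounding types are simplified stand-ins, not the actual vllm-ascend classes):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Runner:
    """Simplified stand-in for the model runner's per-step state."""
    seq_lens: List[int]
    seq_lens_cpu: List[int]

@dataclass
class CommonAttnMetadata:
    """Simplified stand-in for the drafter's attention metadata."""
    seq_lens: List[int] = field(default_factory=list)
    seq_lens_cpu: List[int] = field(default_factory=list)

def build_common_attn_metadata(runner: Runner) -> CommonAttnMetadata:
    meta = CommonAttnMetadata()
    # The fix: propagate the runner's sequence lengths so the drafter's
    # metadata is not one entry short in FULL_DECODE_ONLY.
    meta.seq_lens = runner.seq_lens
    meta.seq_lens_cpu = runner.seq_lens_cpu
    return meta
```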

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a crash in eagle3 when using FULL_DECODE_ONLY. The changes in vllm_ascend/spec_decode/eagle_proposer.py correctly add missing seq_lens metadata for the CUDA graph path. However, the modification in vllm_ascend/worker/model_runner_v1.py introduces a critical bug by reordering operations, which causes a GPU tensor fill to be immediately overwritten. I have provided a comment with a suggested fix for this issue.

Comment thread vllm_ascend/worker/model_runner_v1.py Outdated
Comment on lines 748 to 749
self.seq_lens.gpu[num_reqs:].fill_(0)
self.seq_lens.copy_to_gpu()

critical

The fill_(0) operation on the GPU tensor self.seq_lens.gpu is immediately overwritten by self.seq_lens.copy_to_gpu(), which copies the entire CPU buffer to the GPU. This makes the fill_ operation ineffective.

To correctly zero out the tail of the sequence lengths tensor, you should modify the CPU buffer before copying it to the GPU. This is also more efficient as it avoids an extra GPU kernel launch.

Suggested change:

- self.seq_lens.gpu[num_reqs:].fill_(0)
+ self.seq_lens.np[num_reqs:] = 0
  self.seq_lens.copy_to_gpu()
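The ordering issue the review describes can be modeled with plain lists standing in for the CPU and GPU buffers (a simplified sketch; copy_to_gpu() is modeled as a full-buffer copy, and the concrete values echo the example from the PR description):

```python
num_reqs = 3
cpu = [5742, 4700, 7996, 5738]  # a stale value is left in the tail slot

# Order under review: fill the GPU-side tail, then copy the whole
# CPU buffer over it.
gpu = cpu.copy()                              # earlier copy_to_gpu()
gpu[num_reqs:] = [0] * (len(gpu) - num_reqs)  # fill_(0) on the GPU tensor
gpu = cpu.copy()                              # copy_to_gpu() clobbers the fill
assert gpu[-1] == 5738                        # stale tail survives: corruption

# Suggested order: zero the CPU-side buffer first, then copy once.
cpu[num_reqs:] = [0] * (len(cpu) - num_reqs)  # self.seq_lens.np[num_reqs:] = 0
gpu = cpu.copy()                              # copy_to_gpu()
assert gpu[-1] == 0                           # tail is correctly zeroed
```

Zeroing on the CPU side also avoids the extra GPU kernel launch that fill_ would require.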

@zhenwenqi2024 zhenwenqi2024 added ready-for-test start test by label for PR ready read for review and removed ready-for-test start test by label for PR labels Mar 16, 2026
…E_ONLY

Signed-off-by: drslark <slarksblood@qq.com>
@MengqingCao MengqingCao merged commit a6f6e91 into vllm-project:main Mar 16, 2026
38 checks passed
Nagisa125 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Mar 17, 2026
ichaoren pushed a commit to ichaoren/vllm-ascend that referenced this pull request Mar 17, 2026
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Mar 18, 2026
MengqingCao added a commit that referenced this pull request Mar 18, 2026
### What this PR does / why we need it?
Documented an issue in the 2-node PD mixed deployment scenario where
inference may hang when concurrency exceeds 8 (GLM5).

Noted that the issue has been fixed in PRs:
- #7235
- #7290
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Mar 25, 2026
lihaokun-2026 pushed a commit to lihaokun-2026/vllm-ascend that referenced this pull request Mar 29, 2026
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Apr 1, 2026
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Apr 1, 2026

Labels

ready (read for review), ready-for-test (start test by label for PR)
