
[Qwen3.5 bugfix] Add mm_input_embeds in piecewise cuda graph replay#20448

Open
Chen-0210 wants to merge 5 commits into sgl-project:main from Chen-0210:fix-pcg-mtp-mm-input-embeds

Conversation

@Chen-0210
Contributor

Motivation

There is a bug when piecewise CUDA graph and MTP (multi-token prediction) are enabled together while serving Qwen3.5.

python -m sglang.launch_server --model-path Qwen/Qwen3.5-397B-A17B/ --port 30000 --tp-size 8 --reasoning-parser qwen3 --speculative-algo NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --enforce-piecewise-cuda-graph  --mamba-scheduler-strategy extra_buffer
[2026-03-11 23:25:57 TP0] Compiling a graph for dynamic shape takes 1.63 s
[2026-03-11 23:25:57 TP0] Computation graph saved to /root/.cache/sglang/torch_compile_cache/rank_0_0/backbone/computation_graph_1773296757.2781587.py
[2026-03-11 23:26:00 TP4] Scheduler hit an exception: Traceback (most recent call last):
  File "/upfs/chenjincong/sglang/python/sglang/srt/managers/scheduler.py", line 3372, in run_scheduler_process
    scheduler.run_event_loop()
  File "/upfs/chenjincong/sglang/python/sglang/srt/managers/scheduler.py", line 1235, in run_event_loop
    dispatch_event_loop(self)
  File "/upfs/chenjincong/sglang/python/sglang/srt/managers/scheduler.py", line 3250, in dispatch_event_loop
    scheduler.event_loop_normal()
  File "/root/miniforge3/envs/cjc-qwen3.5/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/upfs/chenjincong/sglang/python/sglang/srt/managers/scheduler.py", line 1254, in event_loop_normal
    result = self.run_batch(batch)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/upfs/chenjincong/sglang/python/sglang/srt/managers/scheduler.py", line 2526, in run_batch
    batch_result = self.model_worker.forward_batch_generation(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/upfs/chenjincong/sglang/python/sglang/srt/speculative/eagle_worker.py", line 298, in forward_batch_generation
    self.forward_draft_extend(
  File "/upfs/chenjincong/sglang/python/sglang/srt/speculative/eagle_worker.py", line 908, in forward_draft_extend
    logits_output = self.draft_model_runner.forward(forward_batch).logits_output
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/upfs/chenjincong/sglang/python/sglang/srt/model_executor/model_runner.py", line 2459, in forward
    output = self._forward_raw(
             ^^^^^^^^^^^^^^^^^^
  File "/upfs/chenjincong/sglang/python/sglang/srt/model_executor/model_runner.py", line 2561, in _forward_raw
    ret, can_run_graph = self.forward_extend(
                         ^^^^^^^^^^^^^^^^^^^^
  File "/upfs/chenjincong/sglang/python/sglang/srt/model_executor/model_runner.py", line 2396, in forward_extend
    self.model.forward(
  File "/root/miniforge3/envs/cjc-qwen3.5/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/upfs/chenjincong/sglang/python/sglang/srt/models/qwen3_5_mtp.py", line 128, in forward
    assert input_embeds is not None

Modifications

Add mm_input_embeds to LogitsProcessorOutput so that the forward_draft_extend path can receive the multimodal embeddings.
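For illustration, the shape of the fix can be sketched as follows. This is a hedged, self-contained sketch with hypothetical names (LogitsProcessorOutput's real fields and the replay signature in sglang differ); it only shows the pattern of propagating an optional multimodal-embedding tensor through the graph-replay output instead of dropping it:

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical stand-in for sglang's LogitsProcessorOutput; field names
# here are illustrative, not the actual sglang API.
@dataclass
class LogitsProcessorOutput:
    next_token_logits: List[float]
    hidden_states: Optional[List[float]] = None
    mm_input_embeds: Optional[List[float]] = None  # field the PR threads through


def replay(captured: LogitsProcessorOutput,
           mm_input_embeds: Optional[List[float]]) -> LogitsProcessorOutput:
    """Sketch of a piecewise-graph replay step.

    Before the fix, the replayed output omitted mm_input_embeds, so the
    draft-extend path later failed `assert input_embeds is not None`.
    """
    return LogitsProcessorOutput(
        next_token_logits=captured.next_token_logits,
        hidden_states=captured.hidden_states,
        mm_input_embeds=mm_input_embeds,  # propagate instead of dropping
    )
```

With this, a consumer of the replayed output sees the same multimodal embeddings that were fed into the graph, mirroring how hidden_states is already handled.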

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@Chen-0210 Chen-0210 requested a review from hebiao064 as a code owner March 12, 2026 11:53
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a specific issue in the Qwen3.5 model's execution pipeline, where a combination of performance optimizations (piecewise CUDA graphs and MTP) caused a failure due to improper handling of multimodal input embeddings. The fix ensures that these embeddings are correctly propagated through the system during graph replay, thereby restoring the stability and functionality of Qwen3.5 under these optimized conditions. This change is crucial for maintaining the model's performance and multimodal capabilities when advanced execution strategies are employed.

Highlights

  • Bug Fix: Resolved a critical bug that occurred when running Qwen3.5 with both piecewise CUDA graph and MTP (multi-token prediction) enabled, which previously led to a scheduler exception.
  • Multimodal Embeddings Handling: Ensured that multimodal input embeddings (mm_input_embeds) are correctly passed and replayed within the piecewise CUDA graph, preventing data loss during the forward_draft_extend path.
  • Code Modification: Modified the piecewise_cuda_graph_runner.py to include mm_input_embeds in the LogitsProcessorOutput during the graph replay process.


Changelog
  • python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
    • Added mm_input_embeds to the LogitsProcessorOutput when replaying the CUDA graph.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug that occurs when using piecewise CUDA graph with multi-token prediction (MTP) for Qwen3.5 models. The issue stems from mm_input_embeds not being propagated correctly during CUDA graph replay, leading to an assertion failure. The fix involves adding mm_input_embeds to the LogitsProcessorOutput in the replay method of PiecewiseCudaGraphRunner, ensuring multimodal embeddings are available in subsequent steps. The change is consistent with how other output tensors like hidden_states are handled and appears to be a correct and necessary fix.

@Chen-0210
Contributor Author

/tag-run-ci-label

@Chen-0210
Contributor Author

/tag-run-ci-label

1 similar comment
@Chen-0210
Contributor Author

/tag-run-ci-label

@Chen-0210
Contributor Author

/rerun-failed-ci

1 similar comment
@Chen-0210
Contributor Author

/rerun-failed-ci

@Chen-0210
Contributor Author

/rerun-failed-ci

@cs-cat
Contributor

cs-cat commented Mar 17, 2026

This patch fixed the crash during warmup, but multimodal inference still triggers runtime recompilation. The recompilation appears to be triggered when multiple images are present across multiple input rounds.

[2026-03-17 08:00:35 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 3269, in run_scheduler_process
    scheduler.run_event_loop()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1162, in run_event_loop
    dispatch_event_loop(self)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 3151, in dispatch_event_loop
    scheduler.event_loop_normal()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1181, in event_loop_normal
    result = self.run_batch(batch)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2426, in run_batch
    batch_result = self.model_worker.forward_batch_generation(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/speculative/eagle_worker.py", line 291, in forward_batch_generation
    logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/speculative/eagle_worker.py", line 373, in forward_target_extend
    batch_result = self.target_worker.forward_batch_generation(model_worker_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 467, in forward_batch_generation
    out = self.model_runner.forward(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 2417, in forward
    output = self._forward_raw(
             ^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 2519, in _forward_raw
    ret, can_run_graph = self.forward_extend(
                         ^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 2346, in forward_extend
    self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 751, in replay
    output = self.model_runner.model.forward(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_vl.py", line 1271, in forward
    hidden_states = general_mm_embed_routine(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/mm_utils.py", line 1255, in general_mm_embed_routine
    hidden_states = language_model(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/compilation/compile.py", line 197, in trampoline
    return compiled_callable(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 832, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py", line 66, in inner
    @functools.wraps(fn)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 414, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 837, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 413, in __call__
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.704", line 921, in forward
    submod_0 = self.submod_0(l_fn_self_modules_layers_modules_0_layer_communicator_input_layernorm_parameters_weight_, l_kwargs_input_embeds_, s97, l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_qkv_parameters_weight_, l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_qkv_parameters_weight_scale_inv_, l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_z_parameters_weight_, l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_z_parameters_weight_scale_inv_, l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_b_parameters_weight_, l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_a_parameters_weight_);  l_fn_self_modules_layers_modules_0_layer_communicator_input_layernorm_parameters_weight_ = l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_qkv_parameters_weight_ = l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_qkv_parameters_weight_scale_inv_ = l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_z_parameters_weight_ = l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_z_parameters_weight_scale_inv_ = l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_b_parameters_weight_ = l_fn_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_a_parameters_weight_ = None
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/compilation/cuda_piecewise_backend.py", line 171, in __call__
    stream is not None
AssertionError: PCG capture stream is not set, please check if runtime recompilation happened
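
For context on the assertion above: the piecewise backend expects every compiled piece to run under a capture stream recorded at warmup, and trips when a runtime recompilation produces a piece with no stream. A minimal sketch of that guard pattern, with hypothetical names that are not the actual sglang implementation:

```python
class PiecewiseBackendSketch:
    """Illustrative model of the guard in cuda_piecewise_backend.py.

    Hypothetical names; only the pattern matters: pieces compiled after
    warmup (i.e. via runtime recompilation) have no capture stream set,
    so invoking them trips the assertion.
    """

    def __init__(self):
        self.capture_stream = None  # populated during warmup capture

    def warmup_capture(self, stream):
        # During warmup, the runner records the stream each piece was
        # captured on so replay can reuse it.
        self.capture_stream = stream

    def __call__(self, *args):
        assert self.capture_stream is not None, (
            "PCG capture stream is not set, please check if runtime "
            "recompilation happened"
        )
        return args  # stand-in for running the captured graph piece
```

Under this model, any input (e.g. new image counts across rounds) that invalidates Dynamo's guards forces a fresh compile at serve time, and the fresh piece fails exactly this check.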

@cs-cat
Contributor

cs-cat commented Mar 17, 2026

It works after applying #16785

@Chen-0210
Contributor Author

/rerun-failed-ci

2 similar comments
@Chen-0210
Contributor Author

/rerun-failed-ci

@Chen-0210
Contributor Author

/rerun-failed-ci

@Chen-0210
Contributor Author

/rerun-failed-ci
