Skip to content

Qwen2.5-VL eagle3 infer#8801

Merged
zhyncs merged 23 commits intosgl-project:mainfrom
Lzhang-hub:qwen_vl_eagle3
Sep 8, 2025
Merged

Qwen2.5-VL eagle3 infer#8801
zhyncs merged 23 commits intosgl-project:mainfrom
Lzhang-hub:qwen_vl_eagle3

Conversation

@Lzhang-hub
Copy link
Contributor

@Lzhang-hub Lzhang-hub commented Aug 5, 2025

Motivation

support qwen2.5-vl eagle3 infer

Modifications

  1. add set_eagle3_layers_to_capture in qwen2 and qwen2.5_vl
  2. change raw_bs to raw_num_token for mrope when target_verify
  3. llama_eagle3 support mrope

Draf model train

qwen2.5-vl-7b eagle3 draft model train ref specforge

Benchmark

speed

server: sglang for qwen-2.5-vl eagle3 infer
benchmark scripts: use mmstar benchmark

Note: draft model Rayzl/qwen2.5-vl-7b-eagle3-sgl is only train on 30k vqa datasets, more data is still training.

  • with eagle

server cmd:

python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --speculative-draft Rayzl/qwen2.5-vl-7b-eagle3-sgl --trust-remote-code --chunked-prefill-size -1 --cuda-graph-max-bs 1 --speculative-algo EAGLE3 --speculative-num-steps 4 --speculative-eagle-topk 6 --speculative-num-draft-tokens 24 --tp 1 --mem-fraction-static 0.7 --host 0.0.0.0 --port 8080

benchmark:
python run_mmstar.py --host http://0.0.0.0 --port 8080 --parallel 1 --num-questions 100

result:

Latency: 34.241 s
Output throughput: 181.069 token/s
Accept length: 3.219
  • without eagle

server cmd:

python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --trust-remote-code --chat-template qwen2-vl --chunked-prefill-size -1 --cuda-graph-max-bs 1 --tp 1 --mem-fraction-static 0.7 --host 0.0.0.0 --port 8080

benchmark:
python run_mmstar.py --host http://0.0.0.0 --port 8080 --parallel 1 --num-questions 100

result:

Latency: 54.813 s
Output throughput: 121.230 token/s
Accept length: 1.000

e2e speedup 1.5x

acc

we benchmark on MMstar、MMBench_DEV_EN、COCO_VAL

image

Checklist

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @Lzhang-hub, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily focuses on extending the SGLang framework to support Qwen2.5-VL with the EAGLE3 speculative decoding algorithm and to enable multi-head Rotary Positional Embedding (mROPE) for Llama-Eagle3 models. The changes involve significant updates to model forward passes, multimodal input processing, and the underlying CUDA graph infrastructure to ensure efficient and correct execution of these advanced features.

Highlights

  • EAGLE3 Integration for Qwen2.5-VL: This PR enables the Qwen2.5-VL model to leverage the EAGLE3 speculative decoding algorithm. This is achieved by introducing mechanisms to capture and utilize auxiliary hidden states from specific intermediate layers of the model, which are crucial for EAGLE3's operation. A new set_eagle3_layers_to_capture method has been added to both Qwen2 and Qwen2.5-VL models to configure which layers' hidden states should be captured.
  • mROPE Support for Llama-Eagle3: Support for multi-head Rotary Positional Embedding (mROPE) has been introduced for Llama-Eagle3 models. This includes updating the model to correctly utilize mrope_positions from the forward_batch and a specific compatibility fix for Qwen2.5-VL's rope_scaling configuration, ensuring that if rope_type is 'mrope', it's adjusted to 'default' for proper handling.
  • Multimodal Input Handling Enhancements: The logic for handling multimodal inputs has been refined. Specifically, the conversation generation now correctly prepends image tokens for Qwen2-VL models. Additionally, the embedding of multimodal inputs (embed_mm_inputs) has been made more flexible by allowing an optional placeholder_tokens argument, and multimodal input processing is now bypassed during target_verify forward mode.
  • CUDA Graph Optimization for Positional Embeddings: The CUDA graph runners for both general model execution and EAGLE speculative decoding have been updated to correctly manage mROPE position tensors. This involves changing the sizing and slicing of mrope_positions from being based on batch size (max_bs or raw_bs) to being based on the total number of tokens (max_num_token or raw_num_token), which is more accurate for token-level positional embeddings and improves efficiency.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Qwen2.5-VL with EAGLE3 speculative decoding. The changes are spread across model definitions, CUDA graph runners, and multi-modal utilities.

Overall, the changes look good and align with the PR's objectives. I've identified a few critical correctness issues that will cause runtime errors and should be addressed before merging. I've also left some comments on maintainability and potential improvements.

Here's a summary of the key feedback points:

  • Critical issues: There are a couple of bugs in mm_utils.py and cuda_graph_runner.py that will lead to crashes due to incorrect tensor creation and size mismatches.
  • Maintainability: There are opportunities to improve code clarity by cleaning up temporary solutions, removing redundant code, and avoiding in-place config modifications.
  • Consistency: There's a minor inconsistency in handling layer indices for EAGLE3 configuration that should be clarified.

Please review the detailed comments on the specific files.

@LugerW-A
Copy link

When starting the sglang server with eagle3, the service generally starts up normally, but an error occurs during the Prefill batch.
python3 -m sglang.launch_server
--model-path /models/Qwen2.5-VL-7B-Instruct
--tp 1
--host "0.0.0.0"
--port 8000
--chat-template qwen2-vl
--mem-fraction-static 0.7
--cuda-graph-max-bs 23
--speculative-algo EAGLE3
--speculative-draft /outputs/Qwen2.5-VL-7B-eagle3/epoch_6
--speculative-num-steps 4
--speculative-eagle-topk 6
--speculative-num-draft-tokens 24
--disable-radix-cache
[2025-08-13 03:39:39] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-13 03:39:39] Scheduler hit an exception: Traceback (most recent call last):
File "/spec_decoding_vlm/sglang/python/sglang/srt/managers/scheduler.py", line 2559, in run_scheduler_process
scheduler.event_loop_normal()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/spec_decoding_vlm/sglang/python/sglang/srt/managers/scheduler.py", line 763, in event_loop_normal
result = self.run_batch(batch)
^^^^^^^^^^^^^^^^^^^^^
File "/spec_decoding_vlm/sglang/python/sglang/srt/managers/scheduler.py", line 1723, in run_batch
) = self.draft_worker.forward_batch_speculative_generation(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/spec_decoding_vlm/sglang/python/sglang/srt/speculative/eagle_worker.py", line 326, in forward_batch_speculative_generation
spec_info = self.draft(batch)
^^^^^^^^^^^^^^^^^
File "/spec_decoding_vlm/sglang/python/sglang/srt/speculative/eagle_worker.py", line 525, in draft
forward_batch = ForwardBatch.init_new(
^^^^^^^^^^^^^^^^^^^^^^
File "/spec_decoding_vlm/sglang/python/sglang/srt/model_executor/forward_batch_info.py", line 444, in init_new
ret._compute_spec_mrope_positions(model_runner, batch)
File "/spec_decoding_vlm/sglang/python/sglang/srt/model_executor/forward_batch_info.py", line 535, in _compute_spec_mrope_positions
batch.multimodal_inputs[i].mrope_position_delta.squeeze(0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'mrope_position_delta'

@Lzhang-hub
Copy link
Contributor Author

Lzhang-hub commented Aug 13, 2025

@LugerW-A Fixed

@yutongteng
Copy link

[2025-08-13 00:39:39] Scheduler hit an exception: Traceback (most recent call last):
File "/data2/tbb/tyt/sglang/python/sglang/srt/managers/scheduler.py", line 2581, in run_scheduler_process
scheduler.event_loop_normal()
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/data2/tbb/tyt/sglang/python/sglang/srt/managers/scheduler.py", line 775, in event_loop_normal
result = self.run_batch(batch)
File "/data2/tbb/tyt/sglang/python/sglang/srt/managers/scheduler.py", line 1745, in run_batch
) = self.draft_worker.forward_batch_speculative_generation(batch)
File "/data2/tbb/tyt/sglang/python/sglang/srt/speculative/eagle_worker.py", line 342, in forward_batch_speculative_generation
spec_info = self.draft(batch)
File "/data2/tbb/tyt/sglang/python/sglang/srt/speculative/eagle_worker.py", line 541, in draft
forward_batch = ForwardBatch.init_new(
File "/data2/tbb/tyt/sglang/python/sglang/srt/model_executor/forward_batch_info.py", line 444, in init_new
ret._compute_spec_mrope_positions(model_runner, batch)
File "/data2/tbb/tyt/sglang/python/sglang/srt/model_executor/forward_batch_info.py", line 530, in _compute_spec_mrope_positions
mrope_delta_tensor = torch.stack([
File "/data2/tbb/tyt/sglang/python/sglang/srt/model_executor/forward_batch_info.py", line 531, in
batch.multimodal_inputs[i].mrope_position_delta.squeeze(0)
AttributeError: 'NoneType' object has no attribute 'mrope_position_delta'

the recent commit cause this bug after server started up and do Prefill batch

@Lzhang-hub
Copy link
Contributor Author

@yutongteng I fixed with commit

@ChiikawaSama
Copy link

I compare the result produced by enable or disable cuda graph, found that when I disable cuda graph, the acceptance rate could be 3+, but enable cuda graph it can only be 2.x

I have no idea yet, I try to debug it in this weekend,

I haven't observed this effect. I tried to disable cuda graph, the accept length increase 0.1 on mmstar. but the speed is very slow..

I found that, when the training epoch increases, the acceptance rate actually increase when disable cuda graph, but it remains the same when enable cuda graph

@JustinTong0323 JustinTong0323 self-assigned this Aug 29, 2025
@JustinTong0323
Copy link
Collaborator

JustinTong0323 commented Aug 29, 2025

2025-08-29 06:18:09] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2612, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 339, in __init__
    self.draft_worker = EAGLEWorker(
                        ^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/speculative/eagle_worker.py", line 125, in __init__
    super().__init__(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 74, in __init__
    self.model_config = ModelConfig.from_server_args(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/configs/model_config.py", line 294, in from_server_args
    return ModelConfig(
           ^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/configs/model_config.py", line 182, in __init__
    raise ValueError(
ValueError: Warning: Target model's context_length (128000) is greater than the derived context_length (8192). This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1

I encountered this error, is it expected? (FYI: I use sglang:blackwell image)

@JustinTong0323
Copy link
Collaborator

python -m sglang.bench_serving --model Qwen/Qwen2-VL-7B-Instruct --backend sglang-oai-chat --dataset-name mmmu --num-prompts 50 --port 8080 --max-concurrency=1 --request-rate=1
With Spec:

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    1.0
Max request concurrency:                 1
Successful requests:                     50
Benchmark duration (s):                  109.29
Total input tokens:                      2955
Total generated tokens:                  51200
Total generated tokens (retokenized):    22634
Request throughput (req/s):              0.46
Input token throughput (tok/s):          27.04
Output token throughput (tok/s):         468.50
Total token throughput (tok/s):          495.54
Concurrency:                             1.00
Accept length:                           1.85
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   2185.33
Median E2E Latency (ms):                 1767.31
---------------Time to First Token----------------
Mean TTFT (ms):                          133.45
Median TTFT (ms):                        72.74
P99 TTFT (ms):                           1327.94
---------------Inter-Token Latency----------------
Mean ITL (ms):                           8.41
Median ITL (ms):                         8.37
P95 ITL (ms):                            8.83
P99 ITL (ms):                            9.13
Max ITL (ms):                            12.98
==================================================

Without Spec:

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    1.0
Max request concurrency:                 1
Successful requests:                     50
Benchmark duration (s):                  102.15
Total input tokens:                      2955
Total generated tokens:                  51200
Total generated tokens (retokenized):    21991
Request throughput (req/s):              0.49
Input token throughput (tok/s):          28.93
Output token throughput (tok/s):         501.21
Total token throughput (tok/s):          530.14
Concurrency:                             1.00
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   2042.63
Median E2E Latency (ms):                 1650.07
---------------Time to First Token----------------
Mean TTFT (ms):                          94.26
Median TTFT (ms):                        60.71
P99 TTFT (ms):                           800.46
---------------Inter-Token Latency----------------
Mean ITL (ms):                           4.43
Median ITL (ms):                         4.43
P95 ITL (ms):                            4.54
P99 ITL (ms):                            4.63
Max ITL (ms):                            6.63
==================================================

I cannot see improvement in mmmu, but it's okay.

@mmdbhs
Copy link
Contributor

mmdbhs commented Aug 29, 2025

I benchmark the model's performance by using evalscope with kontext_bench, i got this error
but when run model without eagle3, it is fine.

/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [14,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [15,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [16,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [17,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [18,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [19,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [20,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [21,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [22,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [23,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [24,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [25,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [26,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [27,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [28,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [29,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [30,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [31,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
[2025-08-29 07:28:13] Scheduler hit an exception: Traceback (most recent call last):
  File "/kefu-nas/xyb/sglang/python/sglang/srt/managers/scheduler.py", line 2577, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/managers/scheduler.py", line 783, in event_loop_normal
    result = self.run_batch(batch)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/managers/scheduler.py", line 1753, in run_batch
    ) = self.draft_worker.forward_batch_speculative_generation(batch)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/speculative/eagle_worker.py", line 355, in forward_batch_speculative_generation
    self.forward_draft_extend_after_decode(batch)
  File "/kefu-nas/xyb/sglang/python/sglang/srt/speculative/eagle_worker.py", line 925, in forward_draft_extend_after_decode
    logits_output, _ = self.draft_model_runner.forward(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/model_executor/model_runner.py", line 1741, in forward
    output = self._forward_raw(
             ^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/model_executor/model_runner.py", line 1786, in _forward_raw
    ret = self.forward_extend(
          ^^^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/model_executor/model_runner.py", line 1686, in forward_extend
    return self.model.forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/models/llama.py", line 465, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/models/llama_eagle3.py", line 165, in forward
    hidden_states, residual = self.midlayer(
                              ^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/models/llama_eagle3.py", line 89, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/models/llama.py", line 196, in forward
    q, k = self.rotary_emb(positions, q, k)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kefu-nas/xyb/sglang/python/sglang/srt/layers/rotary_embedding.py", line 1050, in forward
    cos_sin = self.cos_sin_cache[positions]
              ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
torch.AcceleratorError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

@mmdbhs
Copy link
Contributor

mmdbhs commented Aug 29, 2025

/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [33,0,0], thread: [31,0,0] Assertion ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds" failed.
[2025-08-29 07:28:13] Scheduler hit an exception: Traceback (most recent call last):
File "/kefu-nas/xyb/sglang/python/sglang/srt/managers/scheduler.py", line 2577, in run_scheduler_process
scheduler.event_loop_normal()
File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/managers/scheduler.py", line 783, in event_loop_normal
result = self.run_batch(batch)
^^^^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/managers/scheduler.py", line 1753, in run_batch
) = self.draft_worker.forward_batch_speculative_generation(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/speculative/eagle_worker.py", line 355, in forward_batch_speculative_generation
self.forward_draft_extend_after_decode(batch)
File "/kefu-nas/xyb/sglang/python/sglang/srt/speculative/eagle_worker.py", line 925, in forward_draft_extend_after_decode
logits_output, _ = self.draft_model_runner.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/model_executor/model_runner.py", line 1741, in forward
output = self._forward_raw(
^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/model_executor/model_runner.py", line 1786, in _forward_raw
ret = self.forward_extend(
^^^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/model_executor/model_runner.py", line 1686, in forward_extend
return self.model.forward(
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/models/llama.py", line 465, in forward
hidden_states = self.model(
^^^^^^^^^^^
File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/models/llama_eagle3.py", line 165, in forward
hidden_states, residual = self.midlayer(
^^^^^^^^^^^^^^
File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/models/llama_eagle3.py", line 89, in forward
hidden_states = self.self_attn(
^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/models/llama.py", line 196, in forward
q, k = self.rotary_emb(positions, q, k)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/sglang_e/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kefu-nas/xyb/sglang/python/sglang/srt/layers/rotary_embedding.py", line 1050, in forward
cos_sin = self.cos_sin_cache[positions]
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
torch.AcceleratorError: CUDA error: device-side assert triggered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Increasing the max_position_embeddings parameter in the draft model's configuration can resolve this error.

@JustinTong0323
Copy link
Collaborator

JustinTong0323 commented Aug 29, 2025

Increasing the max_position_embeddings parameter in the draft model's configuration can resolve this error.

@Lzhang-hub We should capture this error and display " Increasing the max_position_embeddings parameter " to the user. Not the CUDA error.

@Lzhang-hub
Copy link
Contributor Author

ValueError: Warning: Target model's context_length (128000) is greater than the derived context_length (8192). This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1

@JustinTong0323 @mmdbhs
This error happen after pr 9388, before it SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN default value is True.

Add env SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 when launch server may be a better solution, the draft model context_len will be overwriteed According to the code

@zhyncs zhyncs merged commit 37d83c6 into sgl-project:main Sep 8, 2025
146 of 163 checks passed
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
lifuhuang pushed a commit that referenced this pull request Sep 10, 2025
@C3236455482
Copy link

hi @Lzhang-hub I'm using sglang for evaluation, and my commands are as follows:

python -m sglang.launch_server \
    --model-path Qwen/Qwen2.5-VL-7B-Instruct \
    --chat-template qwen2-vl \
    --speculative-draft-model-path Rayzl/qwen2.5-vl-7b-eagle3-sgl \
    --speculative-algorithm EAGLE3 \
    --speculative-num-steps 4 \
    --speculative-eagle-topk 6 \
    --speculative-num-draft-tokens 24 \
    --trust-remote-code \
    --chunked-prefill-size -1 \
    --cuda-graph-max-bs 1 \
    --tp 1 \
    --mem-fraction-static 0.7 \
    --host 0.0.0.0 \
    --port 8080
python run_mmstar.py --host http://0.0.0.0 --port 8080 --parallel 1 --num-questions 100

The results show that the token acceptance length is very poor. (with eagle3)

Average Latency: 165.544 s
Average Output throughput: 39.542 token/s
Average Accept length: 1.128

(without eagle3)

Average Latency: 110.203 s
Average Output throughput: 58.765 token/s
Average Accept length: 1.000

Was there something wrong with my operation?

@330205812
Copy link

Here are the test results of the draft model obtained by distilling our fine-tuned Qwen2.5-VL-3B-Instruct using SpecForge v0.1.0 (with SGLang v0.5.4). Both the fine-tuning of the teacher model and this distillation process used the same dataset of approximately 48,000 images.

With eagle3(sglang v0.5.6 inference)

===============================
        VLM Benchmark Results Summary    
===============================
Total Questions Processed: 160
Average Latency (Initial Runs): 144.514 s
Average Output throughput (Initial Runs): 141.703 token/s
Average Accept length (SpecDec): 2.394
-----------------------------------------

Baseline(Qwen2.5-VL-3B-Instruct with sglang v0.5.6 inference)

===============================
        VLM Benchmark Results Summary    
===============================
Total Questions Processed: 160
Average Latency (Initial Runs): 160.572 s
Average Output throughput (Initial Runs): 127.531 token/s
Average Accept length (SpecDec): 1.000
-----------------------------------------

Hi,@C3236455482 ,perhaps you could try upgrading the sglang version to 0.5.6.

Dear @Lzhang-hub, I'm seeing surprisingly modest speedup from Eagle3 on the Qwen2.5VL model—far below the 4–6× reported in the paper for LLMs (which we do observe on regular language models). Could this point to a training issue on my side? I've pasted my training script below—any pointers would be appreciated.

#run_qwen2_5_vl_online.sh
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1
export LD_LIBRARY_PATH="/root/software/anaconda/lib/python3.11/site-packages/torch/lib:${LD_LIBRARY_PATH}"
export FLASHINFER_DISABLE_VERSION_CHECK=1
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)

# support tp1 train eagle3 for qwen2.5-vl-7b-instruct
NUM_GPUS=2

torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3_online.py \
    --target-model-path /mnt/models_cp/V2e-5_L1e-6_P1e-6_S \
    --draft-model-config $ROOT_DIR/configs/qwen2-5-vl-eagle3.json \
    --train-data-path /root/project/SpecForge/data/qwen2.5vl_detect_448x840_detection_convert.json \
    --output-dir /mnt/output/Qwen2.5-VL-eagle3 \
    --num-epochs 3 \
    --batch-size 8 \
    --learning-rate 1e-4 \
    --max-length 8192 \
    --dist-timeout 360 \
    --chat-template qwen2-vl \
    --cache-dir $ROOT_DIR/cache \
    --embedding-key model.embed_tokens.weight \
    --tp-size 1 \
    --is-vlm \
    --min-pixels 50176 \
    --max-pixels 802816 \
    --target-model-backend sglang \
    --save-interval 250 \
    --build-dataset-num-proc 4 \
    --draft-accumulation-steps 2

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.