[bug] fix errors related to context length in SD #9388

Merged
hnyls2002 merged 16 commits into main from lsyin/fix-spec-context-length
Aug 21, 2025
Conversation


@hnyls2002 (Collaborator) commented Aug 20, 2025

According to the documentation, SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN should be False by default.

When the draft model's context length is smaller than the target model's, there are numerous illegal memory accesses (IMA).
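Per the docs, the override should therefore be opt-in. A minimal pure-Python sketch of that default behavior (allow_overwrite_longer_context_len is a hypothetical helper for illustration, not SGLang's actual API):

```python
import os

def allow_overwrite_longer_context_len() -> bool:
    # Hypothetical helper: the env var is treated as False when unset,
    # so overwriting to a longer context length is opt-in by default.
    value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", "false")
    return value.strip().lower() in ("1", "true", "yes")
```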

Reproducing the KV cache location IMA:
MODEL=meta-llama/Meta-Llama-3.1-8B-Instruct
SPEC_MODEL=lmsys/sglang-EAGLE-LLaMA3-Instruct-8B
ATTN=${ATTN:-fa3}
NUM=${NUM:-1}
python3 -m sglang.bench_offline_throughput \
    --model-path $MODEL \
    --dataset-name random \
    --random-input-len 256 \
    --random-output-len 1750 \
    --random-range-ratio 1.0 \
    --num-prompts $NUM \
    --speculative-algorithm EAGLE \
    --speculative-draft-model-path $SPEC_MODEL \
    --speculative-num-steps 5 \
    --speculative-eagle-topk 8 \
    --speculative-num-draft-tokens 64 \
    --mem-fraction-static 0.7 \
    --cuda-graph-max-bs 2 \
    --dtype float16 \
    --disable-radix \
    --context-length 2048 \
    --allow-auto-truncate \
    --attention-backend $ATTN \
    --disable-cuda-graph
Error
/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:163: operator(): block: [0,0,0], thread: [62,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "scatter gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:163: operator(): block: [0,0,0], thread: [63,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "scatter gather kernel index out of bounds"` failed.
[2025-08-20 01:17:22] Scheduler hit an exception: Traceback (most recent call last):
  File "/root/sglang/python/sglang/srt/managers/scheduler.py", line 2577, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/root/sglang/python/sglang/srt/managers/scheduler.py", line 783, in event_loop_normal
    result = self.run_batch(batch)
  File "/root/sglang/python/sglang/srt/managers/scheduler.py", line 1753, in run_batch
    ) = self.draft_worker.forward_batch_speculative_generation(batch)
  File "/root/sglang/python/sglang/srt/speculative/eagle_worker.py", line 343, in forward_batch_speculative_generation
    self.verify(batch, spec_info)
  File "/root/sglang/python/sglang/srt/speculative/eagle_worker.py", line 683, in verify
    self.target_worker.forward_batch_generation(
  File "/root/sglang/python/sglang/srt/managers/tp_worker.py", line 238, in forward_batch_generation
    logits_output, can_run_cuda_graph = self.model_runner.forward(
  File "/root/sglang/python/sglang/srt/model_executor/model_runner.py", line 1735, in forward
    output = self._forward_raw(
  File "/root/sglang/python/sglang/srt/model_executor/model_runner.py", line 1780, in _forward_raw
    ret = self.forward_extend(
  File "/root/sglang/python/sglang/srt/model_executor/model_runner.py", line 1671, in forward_extend
    self.attn_backend.init_forward_metadata(forward_batch)
  File "/root/sglang/python/sglang/srt/layers/attention/flashattention_backend.py", line 547, in init_forward_metadata
    .gather(1, cols)
torch.AcceleratorError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
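
The failing gather(1, cols) call indexes KV-cache locations with positions derived from the context length; once speculative positions pass the window, the indices fall outside the tensor and trip the device-side assert above. A pure-Python illustration of the bounds problem and a clamping-style remedy (clamp_positions is a hypothetical sketch, not the actual SGLang change):

```python
def clamp_positions(positions, context_len):
    # Any position >= context_len would produce an out-of-bounds
    # gather/scatter index into the KV cache; clamping to the last
    # valid slot keeps every index inside the window.
    return [min(p, context_len - 1) for p in positions]
```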

In summary, this PR does two things:

  • Fix the default value of the setting that allows overriding to a longer context length.
  • Fix the IMA that occurs when the draft model's speculative positions exceed the context length.
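
The context-length check can be sketched in isolation as follows (reconcile_draft_context_length is a hypothetical standalone version of the guard applied in eagle_worker.py):

```python
import logging

logger = logging.getLogger(__name__)

def reconcile_draft_context_length(draft_context_len: int,
                                   target_context_len: int) -> int:
    # Never let the draft model run with a smaller context window than
    # the target model: speculative positions are generated against the
    # target's context length, so a shorter draft window can push
    # KV-cache indices out of bounds.
    if draft_context_len < target_context_len:
        logger.warning(
            "Context length of the draft model (%d) is smaller than the "
            "target model (%d), overriding the context length to %d.",
            draft_context_len, target_context_len, target_context_len,
        )
        return target_context_len
    return draft_context_len
```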


@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @hnyls2002, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a bug related to context length handling in the speculative decoding process. Previously, the draft model's context length was unconditionally overridden, which could lead to issues if it was smaller than the target model's required context. The changes introduce a safeguard to ensure the draft model's context length is always sufficient for the target model, preventing potential errors and improving stability.

Highlights

  • Context Length Adjustment: Implemented a conditional check to ensure the draft model's context length (server_args.context_length) is not smaller than the target model's context length (target_worker.model_runner.model_config.context_len). If it is, the draft model's context length is increased to match the target model's, and a warning is logged to prevent potential errors.
  • Dependency Cleanup: The import of get_tensor_model_parallel_world_size from sglang.srt.distributed was removed as it is no longer utilized in the eagle_worker.py file.


@gemini-code-assist (bot) left a comment


Code Review

This pull request fixes a bug in the speculative decoding logic where the draft model's context length was always overridden by the target model's. The new implementation correctly overrides the context length only when it is smaller than the target model's, which is a more robust approach. The removal of an unused import is also a good cleanup. The changes are correct and improve the stability of the speculative decoding feature.

Comment on lines +94 to +102
    if (
        server_args.context_length
        < target_worker.model_runner.model_config.context_len
    ):
        target_context_len = target_worker.model_runner.model_config.context_len
        logger.warning(
            f"Context length of the draft model ({server_args.context_length}) is smaller than the target model ({target_context_len}), overriding the context length to {target_context_len}."
        )
        server_args.context_length = target_context_len


Severity: medium

For better readability and to avoid repeating the long attribute access, you could define target_context_len before the if statement. This makes the condition check cleaner and the code slightly more DRY (Don't Repeat Yourself).

        target_context_len = target_worker.model_runner.model_config.context_len
        if server_args.context_length < target_context_len:
            logger.warning(
                f"Context length of the draft model ({server_args.context_length}) is smaller than the target model ({target_context_len}), overriding the context length to {target_context_len}."
            )
            server_args.context_length = target_context_len

@hnyls2002 hnyls2002 marked this pull request as draft August 20, 2025 07:45
@hnyls2002 hnyls2002 marked this pull request as ready for review August 20, 2025 15:38
@hnyls2002 hnyls2002 merged commit eb19cca into main Aug 21, 2025
4 of 41 checks passed
@hnyls2002 hnyls2002 deleted the lsyin/fix-spec-context-length branch August 21, 2025 02:32
@Lzhang-hub Lzhang-hub mentioned this pull request Aug 31, 2025
6 tasks
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
@DevashishLal-CB (Contributor) commented:

Are we supposed to manually set the context length in the server args?

For Llama 3 8B Instruct, the model has a context length of 8k, while the lmsys draft model for it has a context length of 2k. The logs in the docs also show this warning, but I assume they were generated with the env var set to skip this check.

@JustinTong0323 (Collaborator) commented:

> Are we supposed to manually set the context length in the server args?
>
> For Llama 3 8B Instruct, the model has a context length of 8k, while the lmsys draft model for it has a context length of 2k. The logs in the docs also show this warning, but I assume they were generated with the env var set to skip this check.

@DevashishLal-CB Thanks for pointing this out! We fixed this in #10787.

