
fix: draft model IMA by overriding max_position_embeddings #10787

Merged
zhyncs merged 3 commits into sgl-project:main from JustinTong0323:fix-draft-model-context on Sep 23, 2025

Conversation

JustinTong0323 (Collaborator) commented Sep 23, 2025

Motivation

Fixes #10713
Root cause: the RoPE cos/sin cache is sized by the model's max_position_embeddings, which defaults to 2048. If the batched input grows longer than this, the cache lookup goes out of bounds and an illegal memory access (IMA) occurs.

def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Compute the cos and sin cache."""
    inv_freq = self._compute_inv_freq(self.base)
    t = torch.arange(self.max_position_embeddings, dtype=torch.float)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos()
    sin = freqs.sin()
    cache = torch.cat((cos, sin), dim=-1)
    return cache
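To make the failure mode concrete, here is a minimal pure-Python sketch of the same cache construction (names and signature are illustrative, not the actual sglang API). The cache only holds rows for positions 0..max_position_embeddings-1, so any position beyond that has no entry; in Python that is an IndexError, while the equivalent unchecked GPU lookup surfaces as an illegal memory access.

```python
import math

def compute_cos_sin_cache(max_position_embeddings: int, rotary_dim: int,
                          base: float = 10000.0):
    """Build a [max_position_embeddings, rotary_dim] cos/sin table,
    mirroring the torch code above in pure Python (illustrative only)."""
    inv_freq = [1.0 / (base ** (i / rotary_dim)) for i in range(0, rotary_dim, 2)]
    cache = []
    for t in range(max_position_embeddings):
        freqs = [t * f for f in inv_freq]
        # concatenate cos and sin halves, like torch.cat((cos, sin), dim=-1)
        cache.append([math.cos(x) for x in freqs] + [math.sin(x) for x in freqs])
    return cache

cache = compute_cos_sin_cache(max_position_embeddings=2048, rotary_dim=64)
assert len(cache) == 2048  # only positions 0..2047 are cached
# cache[4096] raises IndexError here; the GPU kernel's unchecked
# equivalent of that lookup is the reported IMA.
```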

Modifications

When a draft model's context_length is overridden to a longer value, also update its max_position_embeddings to match, and log a warning when this override happens.

Accuracy Tests

Benchmarking and Profiling

Checklist

Co-authored-by: Qiaolin Yu <qy254@cornell.edu>
gemini-code-assist (Contributor) commented

Summary of Changes

Hello @JustinTong0323, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an issue (Fix #10713) by implementing a targeted fix for draft models. The core change ensures that when a draft model's context length is adjusted, its max_position_embeddings attribute is also correctly updated to reflect this change. This prevents potential discrepancies in how the model handles positional information, thereby improving the robustness and correctness of draft model configurations.

Highlights

  • Draft Model Configuration Fix: When a draft model's context length is explicitly overridden, its max_position_embeddings will now also be updated to match the new context_length. This ensures consistency in the model's configuration and prevents potential issues related to positional embeddings.
  • Logging for Overrides: A warning message is now logged whenever a draft model's max_position_embeddings is overridden, providing clearer visibility into configuration adjustments.
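The override behavior described in the highlights can be sketched as follows. This is a hedged sketch under assumed names: apply_context_override and the SimpleNamespace stand-in for the HF config are illustrative, not the actual sglang ModelConfig implementation.

```python
import logging
from types import SimpleNamespace

logging.basicConfig()
logger = logging.getLogger("draft_config_sketch")

def apply_context_override(hf_config, context_length, is_draft_model):
    """Sketch: if a longer context_length is requested, adopt it, and for
    draft models also sync max_position_embeddings so the RoPE cache is
    sized for the new context (illustrative, not the sglang code)."""
    context_len = getattr(hf_config, "max_position_embeddings", 2048)
    if context_length is not None and context_length > context_len:
        logger.warning("Overriding context_len %d -> %d", context_len, context_length)
        context_len = context_length
        if is_draft_model:
            # Without this, positions beyond the old limit read past the
            # cos/sin cache and trigger the illegal memory access.
            hf_config.max_position_embeddings = context_length
            logger.warning(
                "Overriding draft model max_position_embeddings to %d",
                context_length,
            )
    return context_len

cfg = SimpleNamespace(max_position_embeddings=2048)
new_len = apply_context_override(cfg, context_length=8192, is_draft_model=True)
assert new_len == 8192 and cfg.max_position_embeddings == 8192
```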

zhyncs (Collaborator) commented Sep 23, 2025

@JustinTong0323 can you add a unit test for this?

gemini-code-assist bot left a comment

Code Review

This pull request introduces a fix for handling draft models with extended context lengths. When a context_length is provided that is longer than the model's default, and overrides are enabled, the change now also updates the max_position_embeddings in the draft model's configuration. This ensures that the draft model's positional embeddings are correctly sized for the extended context, which is crucial for speculative decoding scenarios. The change is well-placed and includes a helpful warning log. The implementation looks correct and follows the existing patterns in the codebase.

JustinTong0323 (Collaborator, Author) commented

Maybe merge with #10788

):
    logger.warning(msg)
    self.context_len = context_length
    if is_draft_model:
A reviewer (Collaborator) commented on this line:

get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"): get the environment variable from sglang.srt.environ instead.
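For context on the suggestion above, a boolean env-var helper of this shape typically normalizes common truthy spellings. This is a generic sketch; the real helper lives in sglang (the reviewer suggests sourcing it from sglang.srt.environ), and this definition is not its actual implementation.

```python
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Treat common truthy spellings as True; everything else is False.
    # Generic sketch, not the sglang implementation.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on")

os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
assert get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") is True
```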

@zhyncs zhyncs merged commit aab35bc into sgl-project:main Sep 23, 2025
68 of 77 checks passed
HanHan009527 pushed a commit to HanHan009527/sglang that referenced this pull request Oct 9, 2025
@JustinTong0323 JustinTong0323 deleted the fix-draft-model-context branch October 20, 2025 18:55

Successfully merging this pull request may close these issues.

[Bug] illegal memory of BatchQKApplyRotaryPosIdsCosSinCache when spec decoding

4 participants