Fix Illegal Instruction/IMA errors when using DP attention -- num_tokens_for_logprob calculation#12115
Conversation
Summary of Changes: Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request resolves a critical out-of-bounds memory access error that manifested as an Illegal Instruction/IMA failure when using DP attention.
Code Review
This pull request effectively addresses a critical out-of-bounds memory access error that occurs when using DP attention without returning logprobs. The root cause—an incorrect calculation of num_tokens_for_logprob—is well-understood, and the proposed fix correctly adjusts the logic based on the return_logprob flag. The change is clear, targeted, and should resolve the reported issue. I've included one minor suggestion to improve code style and memory efficiency.
wonderful job!
Motivation
Finding the root fix for #12052
Issue: When running with --enable-dp-attention, the DP gather operation in the logits stage uses incorrect metadata, causing the Triton kernel to access out-of-bounds memory.
Root cause: The scheduler calculates num_tokens_for_logprob incorrectly by always assuming that all tokens need logits computation. However, when return_logprob=False, LogitsProcessor constructs pruned_states with only the last token per request (batch_size tokens). This mismatch causes the DP gather to expect 661 tokens while receiving only 3, leading to out-of-bounds memory access.
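As a toy illustration of the mismatch (a hypothetical, self-contained sketch, not the actual SGLang scheduler code; the helper name tokens_needing_logits and the concrete sizes are chosen to match the 661-vs-3 numbers reported above):

```python
# Hypothetical sketch of the num_tokens_for_logprob mismatch.
def tokens_needing_logits(extend_lens, logprob_start_lens, return_logprob):
    if return_logprob:
        # Original logic: every token past logprob_start_len needs logits.
        return sum(max(e - s, 1) for e, s in zip(extend_lens, logprob_start_lens))
    # return_logprob=False: LogitsProcessor keeps only the last token per
    # request, so pruned_states has only batch_size rows.
    return len(extend_lens)

# Three requests in the batch, 661 input tokens in total.
extend_lens = [300, 200, 161]
starts = [0, 0, 0]

announced = tokens_needing_logits(extend_lens, starts, return_logprob=True)
actual = tokens_needing_logits(extend_lens, starts, return_logprob=False)
print(announced, actual)  # 661 3

# Before the fix, the scheduler always announced the first count, so a DP
# gather sized for 661 rows indexed a 3-row pruned_states buffer:
pruned_states = [[0.0] * 8 for _ in range(actual)]
try:
    _ = [pruned_states[i] for i in range(announced)]
except IndexError:
    print("out-of-bounds access (the Triton kernel instead reads invalid memory)")
```

In plain Python the bad index raises IndexError; in the Triton kernel the same mismatch silently reads memory past the buffer, which surfaces as the Illegal Instruction/IMA crash.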
Modifications
Key debug log before crash:
File: sglang/python/sglang/srt/managers/scheduler.py

In prepare_mlp_sync_batch_raw(), distinguish two cases when calculating num_tokens_for_logprob for extend mode:
- When return_logprob=True: keep the original logic, the sum of max(extend_len - logprob_start_len, 1) across all requests (logits are needed for input logprob computation).
- When return_logprob=False: use batch_size(), since only the last token per request is needed for sampling.

This ensures the token count synchronized across DP ranks matches the actual pruned_states size in LogitsProcessor, preventing the memory access violation.
Accuracy Tests
Benchmarking and Profiling
Checklist