Fix FlashMLA Shared-Memory Overflow in SGLang's Pure-TP Mode with Low-SMEM Fallback Scheduler#2
Merged
Fridge003 merged 2 commits intosgl-project:sglfrom Nov 20, 2025
Conversation
6 tasks
5 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
When running SGLang + DeepSeek V3.2 in pure Tensor Parallel (TP) mode, the NSA backend can produce large expanded sequence lists (
seqlens_k)—especially whentopk=2048and batch sizes are large.Under these workloads, the existing
get_mla_metadata_kernelallocates dynamic shared memory proportional tobatch_size:In pure-TP setups,
batch_sizecan easily exceed 10k–20k expanded rows, causingsmem_sizeto exceed the GPU’s shared memory limit (e.g., on Hopper/Blackwell). As a result:cudaFuncSetAttributefails due to insufficient shared memoryThis prevents SGLang from using FlashMLA efficiently in DeepSeek V3.2 pure-TP mode — a major limitation for users running large-scale inference.
Modifications
1. Added a low-shared-memory fallback kernel
Introduced:
Key characteristics:
Uses no dynamic shared memory
Only
threadIdx.x == 0performs the scheduling computationRecomputes
num_blocks,first/last_block_idx, and writes:tile_scheduler_metadata_ptrnum_splits_ptrScheduling logic is semantically identical to the original high-sMem kernel
Removes all
O(batch_size)shared-memory allocationsThis ensures correctness while avoiding sMem overflow entirely.
2. Kernel selection controlled at runtime
Updated
run_get_mla_metadata_kernel:Compute required dynamic shared memory:
smem_size = sizeof(int) * (batch_size * 5 + 1);Query device limits via
cudaDevAttrMaxSharedMemoryPerBlockOptinDecision:
smem_size <= max_smem:→ Use original high-performance shared-memory kernel
→ Fallback to
get_mla_metadata_kernel_low_smem(zero dynamic sMem)
Both kernels preserve existing APIs and metadata formats.
Impact