
Fix FlashMLA Shared-Memory Overflow in SGLang's Pure-TP Mode with Low-SMEM Fallback Scheduler#2

Merged
Fridge003 merged 2 commits into sgl-project:sgl from YAMY1234:metadata_fallback
Nov 20, 2025
Conversation


@YAMY1234 commented Nov 20, 2025

Motivation

When running SGLang + DeepSeek V3.2 in pure Tensor Parallel (TP) mode, the NSA backend can produce large expanded sequence lists (seqlens_k)—especially when topk=2048 and batch sizes are large.

Under these workloads, the existing get_mla_metadata_kernel allocates dynamic shared memory proportional to batch_size:

smem_size = sizeof(int) * (batch_size * 5 + 1);

In pure-TP setups, batch_size can easily exceed 10k–20k expanded rows, pushing smem_size past the per-block shared-memory limit even on Hopper/Blackwell GPUs. As a result:

  • cudaFuncSetAttribute fails due to insufficient shared memory
  • The metadata kernel cannot launch
  • FlashMLA KV decoding path crashes, even though the downstream compute kernels can run correctly

This prevents SGLang from using FlashMLA efficiently in DeepSeek V3.2 pure-TP mode — a major limitation for users running large-scale inference.


Modifications

1. Added a low-shared-memory fallback kernel

Introduced:

__global__ void get_mla_metadata_kernel_low_smem(const GetDecodingMetadataParams params);

Key characteristics:

  • Uses no dynamic shared memory

  • Only threadIdx.x == 0 performs the scheduling computation

  • Recomputes num_blocks, first/last_block_idx, and writes:

    • tile_scheduler_metadata_ptr
    • num_splits_ptr
  • Scheduling logic is semantically identical to the original high-sMem kernel

  • Removes all O(batch_size) shared-memory allocations

This ensures correctness while avoiding sMem overflow entirely.
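The core idea can be sketched on the host side. This is an illustrative simplification with hypothetical names and logic, not the actual FlashMLA scheduler: a single logical thread walks the batch sequentially, keeping only scalar accumulators instead of the O(batch_size) shared-memory arrays:

```cpp
#include <algorithm>
#include <vector>

// Hypothetical sketch of the low-smem scheduling idea: derive per-sequence
// split counts with O(1) scratch state, the way a single thread
// (threadIdx.x == 0) could inside get_mla_metadata_kernel_low_smem.
std::vector<int> schedule_low_smem(const std::vector<int>& seqlens_k,
                                   int block_size, int num_sm_parts) {
    // Total KV blocks across the batch, via a scalar accumulator only.
    long total_blocks = 0;
    for (int len : seqlens_k)
        total_blocks += (len + block_size - 1) / block_size;

    // Target blocks per SM partition (ceil division), again O(1) state.
    long payload = std::max(1L, (total_blocks + num_sm_parts - 1) / num_sm_parts);

    // num_splits[i]: how many partitions sequence i is split across.
    // (The real kernel also fills tile_scheduler_metadata_ptr; omitted here.)
    std::vector<int> num_splits(seqlens_k.size());
    for (std::size_t i = 0; i < seqlens_k.size(); ++i) {
        long blocks = (seqlens_k[i] + block_size - 1) / block_size;
        num_splits[i] = static_cast<int>((blocks + payload - 1) / payload);
    }
    return num_splits;
}
```

The serialization costs some parallelism, but the scheduling pass is tiny relative to the attention kernels it feeds, which is why the PR reports negligible overhead.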


2. Kernel selection controlled at runtime

Updated run_get_mla_metadata_kernel:

  • Compute required dynamic shared memory:
    smem_size = sizeof(int) * (batch_size * 5 + 1);

  • Query device limits via cudaDevAttrMaxSharedMemoryPerBlockOptin

  • Decision:

    • If smem_size <= max_smem:
      → Use original high-performance shared-memory kernel
    • Else (typical in DeepSeek V3.2 pure-TP mode):
      → Fallback to get_mla_metadata_kernel_low_smem
      (zero dynamic sMem)

Both kernels preserve existing APIs and metadata formats.
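The dispatch logic above can be sketched as follows. The function name and enum are illustrative, not SGLang's actual code; on the real path max_smem would come from cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device):

```cpp
#include <cstddef>

enum class MetadataKernel { HighSmem, LowSmem };

// Illustrative sketch of the runtime decision in run_get_mla_metadata_kernel.
// max_smem is passed in so the logic can be exercised without a GPU.
MetadataKernel select_metadata_kernel(int batch_size, int max_smem) {
    std::size_t smem_size = sizeof(int) * (static_cast<std::size_t>(batch_size) * 5 + 1);
    return smem_size <= static_cast<std::size_t>(max_smem)
               ? MetadataKernel::HighSmem   // fast path: original kernel
               : MetadataKernel::LowSmem;   // fallback: zero dynamic sMem
}
```

Because the check keys off the device's actual opt-in limit rather than a hard-coded threshold, normal-sized batches keep the fast path on every architecture.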


Impact

  • Pure-TP workloads for SGLang + DeepSeek V3.2 now run stably, without crashing due to shared-memory overflow in the scheduler.
  • Performance for normal-sized batches is unchanged (fast-path unchanged).
  • For extremely large batches, the fallback kernel performs slightly more sequential work, but its runtime remains negligible compared to FlashMLA compute kernels.
  • Enables robust FlashMLA KV decoding for DeepSeek V3.2 models under TP-only deployments — previously impossible without reducing batch size or disabling FlashMLA.
