Support mixing NSA flashmla prefill and trtllm decode kernels #21011

Closed

nvjullin wants to merge 12 commits into sgl-project:main from nvjullin:add-flashmla-kv-cache-layout

Conversation

@nvjullin
Contributor

@nvjullin nvjullin commented Mar 20, 2026

Motivation

  1. Clean up very confusing code around NSA/MLA kv-cache
  2. Add ability to mix flashmla_kv prefill with trtllm decode backend, solving [Feature] Support fp8 kv cache + trtllm decode DSA attention + sparse prefill DSA attention #20163

The cleanup has two parts: cleaning up topk_transform_method and cleaning up the kv-cache layout.

Clean up topk_transform_method

--nsa-prefill-backend flashmla_sparse will crash with an unhelpful error

RuntimeError: topk_indices_offset must be a CUDA tensor

while --nsa-prefill-backend flashmla_auto runs correctly, even when flashmla_auto resolves to flashmla_sparse.

flashmla_auto only runs because two bugs coincidentally cancel each other out.

    self.set_nsa_prefill_impl(forward_batch)
    topk_transform_method = self.get_topk_transform_method()

When running decode,

    if not self.use_mha and self.enable_auto_select_prefill_impl:
        if self.nsa_kv_cache_store_fp8:
            if (
                is_blackwell()
                and forward_batch is not None
                and forward_batch.forward_mode == ForwardMode.EXTEND
            ):
                total_kv_tokens = forward_batch.seq_lens_sum
                total_q_tokens = forward_batch.extend_num_tokens
                # Heuristic based on benchmarking flashmla_kv vs flashmla_sparse + dequantize_k_cache_paged
                if total_kv_tokens < total_q_tokens * 512:
                    self.nsa_prefill_impl = "flashmla_sparse"
                    return
            self.nsa_prefill_impl = "flashmla_kv"

set_nsa_prefill_impl will set nsa_prefill_impl to flashmla_kv because forward_mode is not extend. This is wrong: a decode forward should not be setting nsa_prefill_impl at all.
Then,
    def get_topk_transform_method(self) -> TopkTransformMethod:
        """
        SGLANG_NSA_FUSE_TOPK controls whether to fuse the topk transform into the topk kernel.
        This method is used to select the topk transform method which can be fused or unfused.
        """
        if (
            # disable for MTP
            self.nsa_kv_cache_store_fp8
            and self.nsa_prefill_impl == "flashmla_sparse"
        ):
            topk_transform_method = TopkTransformMethod.RAGGED
        else:
            topk_transform_method = TopkTransformMethod.PAGED
        return topk_transform_method

will accidentally select the correct topk_transform_method, PAGED. This is also wrong: decode should not be consulting nsa_prefill_impl.

With --nsa-prefill-backend flashmla_sparse, by contrast, set_nsa_prefill_impl does nothing, get_topk_transform_method selects RAGGED during decode, and the crash above follows.
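For reference, a minimal sketch of the direction of the fix, keying the selection off the forward mode instead of nsa_prefill_impl (an illustration under assumed APIs, not the PR's exact code; forward_batch.forward_mode.is_decode() is an assumed helper):

    def get_topk_transform_method(self, forward_batch) -> TopkTransformMethod:
        # Decode never runs the ragged prefill path, so it must not consult
        # nsa_prefill_impl; it always uses the paged transform.
        if forward_batch.forward_mode.is_decode():
            return TopkTransformMethod.PAGED
        # Prefill: the ragged transform applies only to the fp8 kv-cache with
        # the flashmla_sparse prefill implementation.
        if self.nsa_kv_cache_store_fp8 and self.nsa_prefill_impl == "flashmla_sparse":
            return TopkTransformMethod.RAGGED
        return TopkTransformMethod.PAGED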

Clean up kv-cache layout

Previously, several MLA kv-cache layouts were implicitly assumed throughout the code base, never enforced, asserted on, or commented on. In particular, there were bugs where nsa_kv_cache_store_fp8 actually means "fp8 kv-cache with flashmla layout" but was used as if it meant "fp8 kv-cache". The layout is now an explicit enum.
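As an illustration, a minimal sketch of such an explicit layout enum (member names are taken from this PR's diff and discussion; the actual definition may differ):

    from enum import Enum, auto

    class MLAKVCacheLayout(Enum):
        # Plain BF16 cache, no quantization metadata.
        BF16 = auto()
        # trtllm-native: NoPE and RoPE parts both stored in FP8, no block scales.
        FP8_NOPE_FP8_ROPE = auto()
        # flashmla: FP8 NoPE with per-block scales, RoPE kept in BF16.
        FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE = auto()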

Add backend mixing

After the cleanup, it's now possible to detect a mixed kv-cache layout. For the initial implementation, only flashmla_kv prefill + trtllm decode is supported for the mixed layout; flashmla_auto will resolve to flashmla_kv. Support for flashmla_sparse can be added in a later PR.

The implementation stores the kv-cache in trtllm layout, then uses a triton kernel to convert trtllm layout to flashmla layout. The kernel simply fills the block scales uniformly with 1 and dequantizes k_rope. The conversion covers the entire kv-cache, so it is better to convert once early (in prefill) than multiple times later (in decode).
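A simplified PyTorch reference of the conversion's semantics (not the actual triton kernel; the tensor shapes and the 128-element scale-block granularity are assumptions for illustration):

    import torch

    def trtllm_to_flashmla_layout_ref(
        nope_fp8: torch.Tensor,  # [num_tokens, 512], torch.float8_e4m3fn (assumed shape)
        rope_fp8: torch.Tensor,  # [num_tokens, 64],  torch.float8_e4m3fn (assumed shape)
    ):
        # trtllm stores NoPE and RoPE both in FP8 with no block scales; flashmla
        # expects FP8 NoPE + per-block scales + BF16 RoPE. Since the trtllm FP8
        # bytes are reused verbatim, every block scale is uniformly 1.0, and only
        # the RoPE part needs dequantizing to BF16.
        nope_out = nope_fp8  # NoPE values carried over as-is
        block_scales = torch.ones(
            nope_fp8.shape[0], 512 // 128, dtype=torch.float32, device=nope_fp8.device
        )
        rope_out = rope_fp8.to(torch.bfloat16)  # dequantize k_rope
        return nope_out, block_scales, rope_out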

Accuracy Tests

Manual runs show that all supported combinations of flashmla and trtllm respond with coherent sentences. The triton kernel has a unit test verifying that it behaves as expected.

Running trtllm decode backend

GSM8K Accuracy

| parallelism | kv  | quant | prefill-backend | score |   std |
| ----------- | --- | ----- | --------------- | ----: | ----: |
| dep8        | fp8 | fp4   | flashmla_auto   | 0.965 | 0.184 |
| dep8        | fp8 | fp4   | trtllm          | 0.965 | 0.184 |
| dep8        | fp8 | fp8   | flashmla_auto   | 0.945 | 0.228 |
| dep8        | fp8 | fp8   | trtllm          | 0.945 | 0.228 |

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility and robustness of the attention mechanisms by enabling the seamless mixing of different KV-cache prefill and decode backends, specifically allowing flashmla_kv prefill with trtllm decode. This is achieved through a comprehensive cleanup of existing KV-cache management code, the introduction of an explicit KV-cache layout enumeration, and the implementation of a new Triton kernel for on-the-fly layout conversion. These changes ensure greater compatibility and correct operation in diverse mixed-backend environments.

Highlights

  • KV-Cache Layout Refactoring: Cleaned up confusing code around NSA/MLA KV-cache and introduced an explicit MLAKVCacheLayout enum to manage different KV-cache layouts, improving clarity and correctness.
  • Mixed Backend Support: Enabled the ability to mix flashmla_kv prefill with trtllm decode backends, addressing a long-standing issue and enhancing flexibility.
  • Triton Kernel for Layout Conversion: Added a new Triton kernel (requantize_fp8_to_block_scale) to convert KV-cache from trtllm-native FP8_NOPE_FP8_ROPE format to flashmla-compatible FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE format on-the-fly.
  • TopK Transform Method Logic: Refactored the topk_transform_method selection logic to resolve previous bugs and ensure correct behavior across different attention backends and forward modes.
  • Unified Backend Dispatch: Renamed and refactored set_nsa_prefill_impl to set_nsa_impl, centralizing and streamlining the attention dispatch strategies for both prefill and decode operations.


@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request is a significant and well-executed refactoring of the NSA/MLA KV cache handling. It resolves confusing logic and bugs by making the KV cache layout explicit with the MLAKVCacheLayout enum. The addition of the requantize_fp8_to_block_scale kernel is a key feature that enables mixing flashmla prefill with trtllm decode, which is a great new capability. The refactoring in nsa_backend.py and model_runner_kv_cache_mixin.py makes the code much clearer and more maintainable. I've identified a potential bug in memory_pool.py that could lead to a NameError and have suggested a fix. Overall, this is a high-quality contribution.

Comment thread python/sglang/srt/mem_cache/memory_pool.py
Comment thread python/sglang/srt/mem_cache/memory_pool_host.py
@nvpohanh
Collaborator

@nvjullin could you provide GLM-5 or DSV3.2 accuracy numbers with this change? thanks

@nvjullin
Contributor Author

Added gsm8k results, which look correct after the latest commit.

    ):
        layer_id = layer.layer_id

        if self.nsa_kv_cache_store_fp8:
Contributor Author

This branch is unreachable because nsa_kv_cache_store_fp8 actually means FP8_NOPE_BLOCK_SCALES_FP8_ROPE.

@nvpohanh
Collaborator

@rainj-me @Fridge003 could you review this? Thanks!

@Fridge003 Fridge003 requested a review from 1am9trash as a code owner March 30, 2026 23:13
Comment thread python/sglang/srt/layers/attention/nsa_backend.py
Comment thread python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
@Fridge003
Collaborator

I think this mismatches the goal of #20163. We are trying to use flashmla_sparse as the prefill backend, since flashmla_kv doesn't perform well under prefill workloads.

    # offset of each request in the ragged kv cache
    # topk indices need to be transformed from local indices to global indices:
    # global_topk_indices = local_topk_indices + topk_indices_offset
    # the offset is repeated once for each token in the request in a flattened array.
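For concreteness, a tiny hypothetical example of the transform this comment describes (values invented for illustration):

    import torch

    # Two requests with kv lengths 5 and 3 stored back-to-back in a ragged
    # cache: request 0 starts at offset 0, request 1 at offset 5. With one
    # query token per request, the offset array has one entry per token.
    topk_indices_offset = torch.tensor([[0], [5]])
    local_topk_indices = torch.tensor([[1, 4], [0, 2]])  # per-request top-k
    global_topk_indices = local_topk_indices + topk_indices_offset
    # tensor([[1, 4], [5, 7]]) -> positions in the flattened ragged kv cache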
Contributor Author

@Fridge003 For future reference, what is the preferred comment style in sglang?
What is the main concern against verbose comments?

Collaborator

I was deleting some repeated verbose comments, which seemed auto-generated.
But this one is good, so we can add it back.

@nvjullin
Contributor Author

> I think this mismatches the goal of #20163. We are trying to use flashmla_sparse as the prefill backend, since flashmla_kv doesn't perform well under prefill workloads.

Most of this PR is scaffolding to detect and dispatch to the correct code path for mixed backends, and is large enough as-is.
flashmla_auto can also dispatch to flashmla_kv, so both need support. I picked one that seems easier to implement first. flashmla_sparse can be added in a later PR.

@Fridge003
Collaborator

Also can you please split this PR into two parts:

  • The first part for supporting flashmla_kv prefill + trtllm decode, which can be merged quickly
  • The second part for refactoring kv cache. This part requires careful testing, to avoid conflicts with hi-sparse

@nvjullin
Contributor Author

nvjullin commented Apr 1, 2026

> Also can you please split this PR into two parts:

Sure I can, but supporting flashmla_kv prefill + trtllm decode depends on the kv-cache refactoring; they're not independent. The first would still be blocked on the second. For example,

        if self.kv_cache_layout == MLAKVCacheLayout.FP8_NOPE_FP8_ROPE:
            # Mixed scenario: transform trtllm-native → block-scale for flashmla_kv
            kv_cache = requantize_fp8_to_block_scale(kv_cache)
        elif (
            self.kv_cache_layout != MLAKVCacheLayout.FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE
        ):
            # BF16: inefficiently quantize the whole cache
            kv_cache = quantize_k_cache(kv_cache)

used to be

        if not self.nsa_kv_cache_store_fp8:
            kv_cache = quantize_k_cache(kv_cache)

This needs to differentiate between FP8_NOPE_FP8_ROPE, FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE, and BF16. Solving nsa_kv_cache_store_fp8 without ad-hoc hacking (such as dispatching on kv-cache dim=576, i.e. 512 NoPE + 64 RoPE dims) is in effect what the kv-cache refactoring did.

Do you still want me to break the PR up in this case?

@nvjullin
Contributor Author

nvjullin commented Apr 1, 2026

After offline discussion, let's break this up into kv-cache refactor + mixed backend PRs.
The latter will depend on the former.

@nvjullin nvjullin closed this Apr 1, 2026
@nvjullin nvjullin mentioned this pull request Apr 1, 2026

Labels

high priority, quant (LLM Quantization)