Support mixing NSA flashmla prefill and trtllm decode kernels #21011
nvjullin wants to merge 12 commits into
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the flexibility and robustness of the attention mechanisms by enabling the seamless mixing of different KV-cache prefill and decode backends.
Code Review
This pull request is a significant and well-executed refactoring of the NSA/MLA KV cache handling. It resolves confusing logic and bugs by making the KV cache layout explicit with the MLAKVCacheLayout enum. The addition of the requantize_fp8_to_block_scale kernel is a key feature that enables mixing flashmla prefill with trtllm decode, which is a great new capability. The refactoring in nsa_backend.py and model_runner_kv_cache_mixin.py makes the code much clearer and more maintainable. I've identified a potential bug in memory_pool.py that could lead to a NameError and have suggested a fix. Overall, this is a high-quality contribution.
@nvjullin could you provide GLM-5 or DSV3.2 accuracy numbers with this change? Thanks!
Added gsm8k results, which look correct after the latest commit.
```python
):
    layer_id = layer.layer_id

    if self.nsa_kv_cache_store_fp8:
```
This branch is unreachable because nsa_kv_cache_store_fp8 actually means FP8_NOPE_BLOCK_SCALES_FP8_ROPE.
@rainj-me @Fridge003 could you review this? Thanks!
This reverts commit 15e74d5.
I think this mismatches the goal of #20163. We are trying to use flashmla_sparse as the prefill backend, since flashmla_kv doesn't perform well under prefill workloads.
```python
# offset of each request in the ragged kv cache
# topk indices needs to transform from local indices to global indices by
# global_topk_indices = local_topk_indices + topk_indices_offset
# offset is repeated once for each token in the request in a flattened array.
```
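For readers of the thread, a tiny illustrative sketch of the index transformation the quoted comment describes (all names, shapes, and values below are hypothetical, not the actual sglang code):

```python
import torch

# two requests with 3 and 2 tokens, flattened into one batch;
# the second request's kv cache starts at offset 7
topk_indices_offset = torch.tensor([0, 0, 0, 7, 7])
local_topk_indices = torch.tensor([[1, 4], [0, 2], [3, 5], [1, 2], [0, 4]])

# local -> global indices into the ragged kv cache
global_topk_indices = local_topk_indices + topk_indices_offset[:, None]
```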
@Fridge003 For future reference, what is the preferred comment style in sglang?
What is the main concern against verbose comments?
I was deleting some repeated verbose comments, which seemed auto-generated.
But this one is good, so we can add it back.
Most of this PR is scaffolding to detect and dispatch to the correct code path for mixed backends, and is large enough as-is.
Also, can you please split this PR into two parts:
Sure I can, but supporting flashmla_kv prefill + trtllm decode is dependent on the kv-cache refactoring; they're not independent. The first will still be blocked on the second. For example,

```python
if self.kv_cache_layout == MLAKVCacheLayout.FP8_NOPE_FP8_ROPE:
    # Mixed scenario: transform trtllm-native → block-scale for flashmla_kv
    kv_cache = requantize_fp8_to_block_scale(kv_cache)
elif (
    self.kv_cache_layout != MLAKVCacheLayout.FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE
):
    # BF16: inefficiently quantize the whole cache
    kv_cache = quantize_k_cache(kv_cache)
```

used to be

```python
if not self.nsa_kv_cache_store_fp8:
    kv_cache = quantize_k_cache(kv_cache)
```

This needs to differentiate between FP8_NOPE_FP8_ROPE and FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE.

Do you still want me to break the PR up in this case?
After offline discussion, let's break this up into kv-cache refactor + mixed backend PRs. |
Motivation
The clean up is composed of cleaning up `topk_transform_method` and cleaning up the kv-cache layout.

Clean up `topk_transform_method`

`--nsa-prefill-backend flashmla_sparse` will crash with an unhelpful error, while `--nsa-prefill-backend flashmla_auto` runs correctly, even when `flashmla_auto` resolves to `flashmla_sparse`. `flashmla_auto` running is a miracle of two bugs coincidentally cancelling each other out.

sglang/python/sglang/srt/layers/attention/nsa_backend.py
Lines 403 to 404 in 8b46f1f
When running decode,
sglang/python/sglang/srt/layers/attention/nsa_backend.py
Lines 2063 to 2076 in 8b46f1f
`set_nsa_prefill_impl` will set `nsa_prefill_impl` to `flashmla_kv` because `forward_mode` is not extend. This is wrong; decode forward should not be setting `nsa_prefill_impl`.

Then,
sglang/python/sglang/srt/layers/attention/nsa_backend.py
Lines 2081 to 2094 in 8b46f1f
will accidentally select the correct `topk_transform_method` as paged. This is also wrong; decode should not be consulting `nsa_prefill_impl`.

While using `--nsa-prefill-backend flashmla_sparse`, `set_nsa_prefill_impl` does nothing and `topk_transform_method` selects ragged, which causes the crash.
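As a rough illustration of the rule this clean up aims for (a hedged sketch with made-up function and argument names, not the actual sglang code): decode derives its `topk_transform_method` from the forward mode alone and never consults `nsa_prefill_impl`.

```python
def select_topk_transform_method(forward_mode: str, nsa_prefill_impl: str) -> str:
    """Sketch only: decode always takes the paged path, independent of the prefill backend."""
    if forward_mode == "decode":
        return "paged"
    # Extend/prefill consults the prefill backend; whether flashmla_sparse maps to
    # ragged and flashmla_kv to paged is an assumption for illustration.
    return "ragged" if nsa_prefill_impl == "flashmla_sparse" else "paged"
```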
Clean up kv-cache layout

Previously, there were several MLA kv-cache layouts that were implicitly assumed throughout the code base, never enforced/asserted on/commented on. In particular, there were some bugs around `nsa_kv_cache_store_fp8` actually meaning "fp8 kv-cache with flashmla layout" but being used as "fp8 kv-cache". Now the layout is an explicit enum.
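A minimal sketch of what such an explicit layout enum could look like; only the two FP8 member names appear in this PR, while the BF16 member and the comments are assumptions for illustration.

```python
from enum import Enum, auto

class MLAKVCacheLayout(Enum):
    # assumed member: unquantized bf16 cache
    BF16_NOPE_BF16_ROPE = auto()
    # trtllm-native fp8 layout (no block scales)
    FP8_NOPE_FP8_ROPE = auto()
    # flashmla layout: fp8 nope with per-block scales, bf16 rope
    FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE = auto()
```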
Add backend mixing

After the clean up, it's now possible to detect a mixed kv-cache layout. For the initial implementation, it's only possible to run flashmla_kv prefill + trtllm decode with a mixed kv-cache layout; flashmla_auto will resolve to flashmla_kv. Support for flashmla_sparse can be added in a later PR.
The implementation stores the kv-cache in trtllm layout, then uses a triton kernel to convert the trtllm layout to the flashmla layout. It simply fills the block scales uniformly with 1 and dequantizes k_rope. The conversion converts the entire kv-cache, so it's better to convert once early (in prefill) than multiple times later (in decode).
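For illustration, a pure-torch sketch of what the conversion does; the real implementation is a triton kernel, and the tensor shapes, output packing, and function/argument names here are assumptions.

```python
import torch

def requantize_fp8_to_block_scale_ref(
    kv_fp8: torch.Tensor,  # assumed: [num_pages, page_size, kv_lora_rank + rope_dim], fp8
    kv_lora_rank: int,
    rope_dim: int,
):
    """Reference sketch: convert a trtllm-layout cache to a flashmla-style layout."""
    assert kv_fp8.shape[-1] == kv_lora_rank + rope_dim
    num_pages, page_size, _ = kv_fp8.shape

    # k_nope stays fp8; its block scales are filled uniformly with 1.0
    # (the block-scale granularity assumed here is one scale per token)
    k_nope = kv_fp8[..., :kv_lora_rank]
    block_scales = torch.ones(num_pages, page_size, dtype=torch.float32, device=kv_fp8.device)

    # k_rope is dequantized back to bf16 (a plain cast, assuming a unit scale)
    k_rope = kv_fp8[..., kv_lora_rank:].to(torch.bfloat16)
    return k_nope, block_scales, k_rope
```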
Accuracy Tests
Manual runs show that all supported combinations of flashmla and trtllm respond with proper sentences. The triton kernel has a unit test verifying that it behaves as expected.
Running trtllm decode backend
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci