Refactor MLA kv cache #21835

Open

nvjullin wants to merge 11 commits into sgl-project:main from nvjullin:refactor-mla-kv-cache

Conversation

nvjullin (Contributor) commented Apr 1, 2026

Motivation

See #21011. This PR splits out the flashmla part, which will be submitted later, for easier management. It also fixes several merge issues with hisparse.

Modifications

Accuracy Tests

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@nvjullin nvjullin changed the title Refactor mla kv cache Refactor MLA kv cache Apr 1, 2026
@Fridge003 Fridge003 self-assigned this Apr 1, 2026
gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request refactors the MLA KV cache management by introducing the MLAKVCacheLayout enum and kv_cache_size attribute, replacing the previous boolean-based and dimension-based tracking. These changes are propagated through the memory pools, model runner, and attention backends to provide a more robust way of handling different quantization layouts (FP4, BF16, FP8). A critical regression was identified in the nsa_backend.py where the refactored flashmla_sparse path incorrectly raises a ValueError for BF16 layouts. This occurs because BF16 defaults to the PAGED transform method, while the new logic strictly requires RAGGED, leading to a runtime crash for those models.
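For orientation, here is a minimal sketch of the shape this summary implies. Only the MLAKVCacheLayout name, the kv_cache_size attribute, and the FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE member (which appears verbatim in the diff below) come from this page; the other enum members and the pool class are illustrative assumptions, not the PR's actual code:

from enum import Enum, auto


class MLAKVCacheLayout(Enum):
    # FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE appears in the diff below; the
    # remaining members are hypothetical stand-ins for the FP4/BF16 layouts
    # the review summary mentions.
    FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE = auto()
    BF16 = auto()
    FP4 = auto()


class MLATokenToKVPoolSketch:
    """Hypothetical pool: a single enum plus an explicit size replace the
    previous boolean-based and dimension-based tracking."""

    def __init__(self, kv_cache_layout: MLAKVCacheLayout, kv_cache_size: int):
        self.kv_cache_layout = kv_cache_layout
        self.kv_cache_size = kv_cache_size  # replaces dimension-based size tracking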

Comment on lines +1417 to +1428
if topk_transform_method != TopkTransformMethod.RAGGED:
    raise ValueError(
        "Internal error: Unexpected topk transform method for NSA backend flashmla_sparse."
    )

if any(forward_batch.extend_prefix_lens_cpu):
    page_table_1_flattened = self.forward_metadata.page_table_1_flattened
    assert page_table_1_flattened is not None
    kv_cache = dequantize_k_cache_paged(kv_cache, page_table_1_flattened)
else:
    kv_cache = _cat([k, k_rope], dim=-1)
    page_table_1 = topk_indices

Severity: high

The refactoring of the flashmla_sparse path in forward_extend introduces a regression for models using the BF16 layout.

Currently, get_topk_transform_method returns TopkTransformMethod.PAGED for the BF16 layout (line 2153). However, the new code in forward_extend explicitly raises a ValueError if the method is not RAGGED (line 1417). This means any BF16 model using the NSA backend with flashmla_sparse (which is the default for BF16 in set_nsa_impl, line 2141) will crash at runtime.

By restoring the conditional check for TopkTransformMethod.RAGGED, the PAGED method (used by BF16) can correctly proceed using the physical indices in page_table_1 (computed at line 1388) and the global kv_cache buffer, while the RAGGED-specific logic (including dequantization) remains restricted to the appropriate layouts.

Suggested change

Remove:

    if topk_transform_method != TopkTransformMethod.RAGGED:
        raise ValueError(
            "Internal error: Unexpected topk transform method for NSA backend flashmla_sparse."
        )
    if any(forward_batch.extend_prefix_lens_cpu):
        page_table_1_flattened = self.forward_metadata.page_table_1_flattened
        assert page_table_1_flattened is not None
        kv_cache = dequantize_k_cache_paged(kv_cache, page_table_1_flattened)
    else:
        kv_cache = _cat([k, k_rope], dim=-1)
        page_table_1 = topk_indices

Replace with:

    if topk_transform_method == TopkTransformMethod.RAGGED:
        if any(forward_batch.extend_prefix_lens_cpu):
            page_table_1_flattened = self.forward_metadata.page_table_1_flattened
            assert page_table_1_flattened is not None
            kv_cache = dequantize_k_cache_paged(kv_cache, page_table_1_flattened)
        else:
            kv_cache = _cat([k, k_rope], dim=-1)
            page_table_1 = topk_indices
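To make the reported failure concrete, here is a hedged, self-contained repro sketch of the crash path the reviewer describes. The class and the string constants are simplifications of nsa_backend.py (not reproduced on this page); only the BF16-to-PAGED default and the unconditional ValueError come from the comments above:

# Simplified, hypothetical model of the dispatch described in the review.
class NSABackendSketch:
    def __init__(self, kv_cache_layout):
        self.kv_cache_layout = kv_cache_layout

    def get_topk_transform_method(self):
        # Per the review: the BF16 layout defaults to PAGED (nsa_backend.py:2153).
        if self.kv_cache_layout == "BF16":
            return "PAGED"
        return "RAGGED"

    def forward_extend(self):
        # The refactored flashmla_sparse path rejects everything but RAGGED,
        # so the BF16 default crashes here instead of taking the PAGED path.
        if self.get_topk_transform_method() != "RAGGED":
            raise ValueError(
                "Internal error: Unexpected topk transform method for NSA "
                "backend flashmla_sparse."
            )


try:
    NSABackendSketch("BF16").forward_extend()
except ValueError as e:
    print("repro:", e)  # BF16 + flashmla_sparse crashes at runtime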

(quoting get_topk_transform_method)

if forward_mode is None or forward_mode.is_decode_or_idle():
    return TopkTransformMethod.PAGED
elif (
    self.kv_cache_layout == MLAKVCacheLayout.FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE
nvjullin (Contributor, Author) replied

This is a faithful refactor of the original code, but it got flagged as a bug by Gemini. I don't believe flashmla_sparse accepts PAGED topk, though, so this seems like a bug in the original code?
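A hedged reconstruction of where the two readings diverge, assembled from the quoted snippet and the line numbers cited above. This is not the PR's verbatim code: the FP8 branch's return value is an assumption (the quote truncates mid-condition), and the method is flattened into a standalone function with stub enums so the sketch is self-contained:

from enum import Enum, auto


class TopkTransformMethod(Enum):  # stub for the sketch
    PAGED = auto()
    RAGGED = auto()


class MLAKVCacheLayout(Enum):  # stub; BF16 member is hypothetical
    FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE = auto()
    BF16 = auto()


def get_topk_transform_method(kv_cache_layout, forward_mode=None):
    # Reconstructed from the quoted snippet above; not verbatim PR code.
    if forward_mode is None or forward_mode.is_decode_or_idle():
        return TopkTransformMethod.PAGED  # decode/idle path, quoted above
    elif kv_cache_layout == MLAKVCacheLayout.FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE:
        return TopkTransformMethod.RAGGED  # assumed: the quote truncates here
    # Gemini's reading: BF16 extend falls through to PAGED (line 2153), so the
    # unconditional ValueError in forward_extend is a regression.
    # The author's reading: flashmla_sparse never accepted a PAGED topk on the
    # extend path, so the old fallthrough was itself a latent bug.
    return TopkTransformMethod.PAGED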

