Refactor MLA kv cache #21835
Conversation
This reverts commit 46dc55e.
Code Review
This pull request refactors the MLA KV cache management by introducing the MLAKVCacheLayout enum and kv_cache_size attribute, replacing the previous boolean-based and dimension-based tracking. These changes are propagated through the memory pools, model runner, and attention backends to provide a more robust way of handling different quantization layouts (FP4, BF16, FP8). A critical regression was identified in the nsa_backend.py where the refactored flashmla_sparse path incorrectly raises a ValueError for BF16 layouts. This occurs because BF16 defaults to the PAGED transform method, while the new logic strictly requires RAGGED, leading to a runtime crash for those models.
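To make the shape of the refactor concrete, here is a minimal sketch of an enum-plus-size-attribute design replacing boolean/dimension flags. The member names and the size computation are illustrative assumptions, not the actual sglang definitions.

```python
from enum import Enum, auto

# Illustrative sketch only: the real MLAKVCacheLayout members and the
# kv_cache_size computation live in sglang; names here are assumptions.
class MLAKVCacheLayout(Enum):
    BF16 = auto()
    FP8 = auto()
    FP4 = auto()

def kv_cache_size_bytes(layout: MLAKVCacheLayout, num_tokens: int, head_dim: int) -> int:
    """Bytes needed for the cache depend on the quantization layout."""
    bytes_per_elem = {
        MLAKVCacheLayout.BF16: 2,
        MLAKVCacheLayout.FP8: 1,
        MLAKVCacheLayout.FP4: 0.5,
    }[layout]
    return int(num_tokens * head_dim * bytes_per_elem)
```

Centralizing the layout in one enum lets the memory pools, model runner, and attention backends dispatch on the same value instead of each re-deriving it from booleans and tensor dimensions.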
```python
if topk_transform_method != TopkTransformMethod.RAGGED:
    raise ValueError(
        "Internal error: Unexpected topk transform method for NSA backend flashmla_sparse."
    )

if any(forward_batch.extend_prefix_lens_cpu):
    page_table_1_flattened = self.forward_metadata.page_table_1_flattened
    assert page_table_1_flattened is not None
    kv_cache = dequantize_k_cache_paged(kv_cache, page_table_1_flattened)
else:
    kv_cache = _cat([k, k_rope], dim=-1)
    page_table_1 = topk_indices
```
The refactoring of the flashmla_sparse path in forward_extend introduces a regression for models using the BF16 layout.
Currently, get_topk_transform_method returns TopkTransformMethod.PAGED for the BF16 layout (line 2153). However, the new code in forward_extend explicitly raises a ValueError if the method is not RAGGED (line 1417). This means any BF16 model using the NSA backend with flashmla_sparse (which is the default for BF16 in set_nsa_impl, line 2141) will crash at runtime.
By restoring the conditional check for TopkTransformMethod.RAGGED, the PAGED method (used by BF16) can correctly proceed using the physical indices in page_table_1 (computed at line 1388) and the global kv_cache buffer, while the RAGGED specific logic (including dequantization) remains restricted to the appropriate layouts.
Suggested change:

```diff
-if topk_transform_method != TopkTransformMethod.RAGGED:
-    raise ValueError(
-        "Internal error: Unexpected topk transform method for NSA backend flashmla_sparse."
-    )
-if any(forward_batch.extend_prefix_lens_cpu):
-    page_table_1_flattened = self.forward_metadata.page_table_1_flattened
-    assert page_table_1_flattened is not None
-    kv_cache = dequantize_k_cache_paged(kv_cache, page_table_1_flattened)
-else:
-    kv_cache = _cat([k, k_rope], dim=-1)
-    page_table_1 = topk_indices
+if topk_transform_method == TopkTransformMethod.RAGGED:
+    if any(forward_batch.extend_prefix_lens_cpu):
+        page_table_1_flattened = self.forward_metadata.page_table_1_flattened
+        assert page_table_1_flattened is not None
+        kv_cache = dequantize_k_cache_paged(kv_cache, page_table_1_flattened)
+    else:
+        kv_cache = _cat([k, k_rope], dim=-1)
+        page_table_1 = topk_indices
```
```python
if forward_mode is None or forward_mode.is_decode_or_idle():
    return TopkTransformMethod.PAGED
elif (
    self.kv_cache_layout == MLAKVCacheLayout.FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE
```
This is a faithful refactor of the original code, but it got flagged as a bug by Gemini. I don't believe flashmla_sparse accepts PAGED topk, though, so this seems like a bug in the original code?
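The selection logic being discussed can be sketched as follows. This is a self-contained approximation: the layout string, the `ForwardMode` stub, and which layouts select RAGGED are assumptions for illustration.

```python
from enum import Enum, auto

class TopkTransformMethod(Enum):
    PAGED = auto()
    RAGGED = auto()

class ForwardMode(Enum):
    EXTEND = auto()
    DECODE = auto()

    def is_decode_or_idle(self) -> bool:
        return self is ForwardMode.DECODE

def get_topk_transform_method(kv_cache_layout: str, forward_mode=None) -> TopkTransformMethod:
    # Decode/idle batches always use the paged transform.
    if forward_mode is None or forward_mode.is_decode_or_idle():
        return TopkTransformMethod.PAGED
    # Only certain quantized layouts take the ragged prefill path;
    # the exact set is an assumption here.
    if kv_cache_layout == "FP8_NOPE_WITH_BLOCK_SCALE_BF16_ROPE":
        return TopkTransformMethod.RAGGED
    # BF16 (and other layouts) default to PAGED -- the case that trips
    # the unconditional ValueError discussed in the review thread.
    return TopkTransformMethod.PAGED
```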
Motivation
See #21011. For easier management, this PR splits out the flashmla part, which will be submitted later.
This PR also fixes several merge issues with hisparse.
Modifications
Accuracy Tests
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ci