integrate flash_mla_sparse_fwd #25418
Warning: You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!
@yuan-luo Could you help test on H20 with flag |
if compressed_slice is not None:
    dequantize_k_cache_paged(
        extra_k_cache,
        flat_token_ids,
        page_size=extra_page_size,
        out=compressed_slice,
    )
Will this dequant all the c4 cache, or only the selected c4 cache?
It looks like it dequantizes all of the c4 cache: the selected c4 cache changes between layers, but the sparse prefill cache here is computed only once, at the first layer (so the value of flat_token_ids doesn't change). Need @zcnrex to confirm this.
c4_sparse: means "compressed by 4" but only attend to top-512 tokens.
all related length will be clipped to 512.
"""
_LARGE_INDEXER_QUERY_THRESHOLD = 11673
Why hardcode this value? Can we avoid hardcoding it?
if envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get():
use_jit_indexer = (
    envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get()
    or self.c4_seq_lens.numel() > _LARGE_INDEXER_QUERY_THRESHOLD
Why is it numel here? I think the numel of c4_seq_lens should be a small value like batch size?
fp8_vals = tl.load(buf_fp8_ptr + fp8_off).to(tl.float32)

scale_u8 = tl.load(buf_uint8_ptr + token_scale_base + tile_id).to(tl.int32)
scale_pow2 = tl.exp2((scale_u8 - 127).to(tl.float32))
Will tl.int32 and tl.float32 cause overflow here?
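One way to reason about the overflow question is to enumerate the value range by hand. The sketch below is plain Python, not Triton, and it assumes the scale byte is a biased power-of-two exponent (which the `- 127` in the snippet suggests but the thread does not confirm). `check_scale_byte` is a hypothetical helper written for this note. Under that assumption, the int32 subtraction stays in [-127, 128] and cannot overflow, while only the byte value 255 yields a power of two beyond the finite float32 range.

```python
import math

FLOAT32_MAX = (2.0 - 2.0**-23) * 2.0**127   # largest finite float32
FLOAT32_MIN_SUBNORMAL = 2.0**-149           # smallest positive float32


def check_scale_byte(scale_u8: int) -> tuple[int, bool]:
    """Return (exponent, fits_in_float32) for one biased scale byte.

    Hypothetical helper: assumes the byte encodes 2**(byte - 127).
    """
    assert 0 <= scale_u8 <= 255
    exponent = scale_u8 - 127          # after the int32 cast: always in [-127, 128]
    value = math.ldexp(1.0, exponent)  # exact 2**exponent in double precision
    return exponent, FLOAT32_MIN_SUBNORMAL <= value <= FLOAT32_MAX


results = {b: check_scale_byte(b) for b in (0, 127, 254, 255)}
for byte, (exponent, fits) in results.items():
    print(f"scale_u8={byte:3d} exponent={exponent:4d} finite_in_float32={fits}")
```

Note that 2**-127 (byte value 0) is a float32 subnormal, so it is representable but may be flushed to zero depending on how the exp2 implementation treats subnormals.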
@@ -0,0 +1,584 @@
"""Per-query sparse-index combiner for the FlashMLA sparse prefill path.

Adapts vllm's ``combine_topk_swa_indices`` (vllm/v1/attention/ops/
Please paste a link to the reference file.
Adapts vllm's ``combine_topk_swa_indices`` (vllm/v1/attention/ops/
deepseek_v4_ops/cache_utils.py) to sglang's flat-workspace layout. For each
query token in a prefill chunk, emits one row of combined indices into the
Can we change all the index tensors to tl.int64 in this file? int32 is prone to IMA (illegal memory access).
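To illustrate the concern, here is a minimal plain-Python sketch of how a flat offset computed in 32-bit arithmetic can wrap to a negative value once it passes 2**31 - 1. `wrap_int32`, `flat_offset`, the token index, and the row stride of 576 are all hypothetical values chosen for this note and are not taken from the PR.

```python
INT32_MAX = 2**31 - 1


def wrap_int32(x: int) -> int:
    """Emulate two's-complement wraparound of a signed 32-bit integer."""
    return (x + 2**31) % 2**32 - 2**31


def flat_offset(token_id: int, row_stride: int, as_int32: bool) -> int:
    """Element offset of a token row in a flat buffer (hypothetical layout)."""
    offset = token_id * row_stride
    return wrap_int32(offset) if as_int32 else offset


token_id, row_stride = 4_000_000, 576             # illustrative numbers only
off64 = flat_offset(token_id, row_stride, as_int32=False)
off32 = flat_offset(token_id, row_stride, as_int32=True)
print(off64, off64 > INT32_MAX)   # true offset no longer fits in int32
print(off32)                      # wrapped value is negative: out-of-bounds access
```

A negative or otherwise wrapped offset added to a base pointer is exactly the kind of address that shows up as an illegal memory access, which is why widening index arithmetic to 64 bits before the multiply is the usual remedy.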
@@ -0,0 +1,136 @@
from typing import Optional
Can we add a torch reference implementation here, and add an output comparison test between the reference and the applied kernel (maybe under __main__)?
Motivation
flash_mla_with_kvcache is slow because of complicated loading logic. Switching to flash_mla_sparse_fwd gives a 1.35x speedup over the flash_mla_with_kvcache kernel.
Modifications
Use flash_mla_sparse_fwd for prefill instead of flash_mla_with_kvcache. This is faster because it directly uses TMA loads and avoids complicated loading logic.
... flash_mla_sparse_fwd work.
... flash_mla_sparse_fwd kernel to avoid a bug: [Bug] DeepSeek-V4-Pro on 8 * H20-3e: DeepGEMM paged_mqa_logits_metadata kernel exceeds shared memory limit during JIT compilation #25484.
Accuracy Tests
Chunk prefill size 8192
Chunk prefill size 32768
Speed Tests and Profiling
Prefill 128k, 4 x B200, DSv4 Flash, --chunked-prefill-size 8192