[KVCache] PagedKVCache refactor, FlashInfer JIT and MLA integration #17674
This PR consists of the following parts:
- We reorganized `paged_kv_cache.cc` by moving some of its utilities to `attn_utils.h`.
- To integrate with the JIT kernel compilation in the latest FlashInfer project, while still supporting attention kernels written in TIR, we introduced `AttnBackendFunc` in `attn_backend.h`, which exposes attention interfaces (e.g., `MHA`, `MLA`) to the PagedKVCache. We subclass `AttnBackendFunc` to implement the FlashInfer and TIR backends respectively (a sketch of this abstraction follows the list).
- With `AttnBackendFunc`, we refactored the PagedKVCache constructor. The new constructor is not backward compatible and will break existing compiled model libraries.
- For both the TIR and FlashInfer attention implementations, we now require an explicit attention softmax scale factor `sm_scale` to be passed in. Previously, `sm_scale` was inlined as `head_dim ** -0.5`. With recent LLM inference techniques such as MLA weight absorption in DeepSeek models, the inlined `sm_scale` causes confusion and inconvenience, so to keep the attention interface standard and clear, we now require `sm_scale` to be passed explicitly (see the `sm_scale` sketch below).
- We refactored the existing GPU unit tests of the PagedKVCache, updating the reference ("std") calculation from NumPy to PyTorch. This significantly reduces the test run time (see the last sketch below).
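A minimal Python sketch of the backend abstraction described above. The actual implementation is in C++ (`AttnBackendFunc` in `attn_backend.h`); the method names and signatures below are illustrative assumptions, not the real interface.

```python
# Illustrative sketch only: the real AttnBackendFunc is a C++ class in
# attn_backend.h, and these method signatures are assumptions.
from abc import ABC, abstractmethod


class AttnBackendFunc(ABC):
    """Uniform attention interface that PagedKVCache dispatches to."""

    @abstractmethod
    def mha(self, q, k, v, sm_scale, out):
        """Multi-head attention over the paged KV cache."""

    @abstractmethod
    def mla(self, q, compressed_kv, sm_scale, out):
        """Multi-head latent attention (MLA) over the compressed KV cache."""


class FlashInferAttnBackend(AttnBackendFunc):
    """Backend that invokes kernels JIT-compiled by FlashInfer."""

    def mha(self, q, k, v, sm_scale, out):
        ...  # call the FlashInfer-compiled MHA kernel

    def mla(self, q, compressed_kv, sm_scale, out):
        ...  # call the FlashInfer-compiled MLA kernel


class TIRAttnBackend(AttnBackendFunc):
    """Backend that invokes attention kernels written in TIR."""

    def mha(self, q, k, v, sm_scale, out):
        ...  # call the TIR MHA kernel

    def mla(self, q, compressed_kv, sm_scale, out):
        ...  # call the TIR MLA kernel
```

Under this design, the refactored constructor can accept backend objects rather than individual kernel functions; that interface change is what makes the new constructor incompatible with previously compiled model libraries.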
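To see why the inlined scale becomes a problem, consider the following small illustration. All dimensions are example values in the style of DeepSeek's MLA configuration; they are not taken from this PR.

```python
# Standard MHA: the softmax scale is derived from the query/key head dim,
# so inlining head_dim ** -0.5 inside the kernel works fine.
head_dim = 128
sm_scale = head_dim ** -0.5  # 1 / sqrt(128)

# MLA with weight absorption (example dimensions): the kernel operates on
# absorbed queries/keys of width kv_lora_rank + qk_rope_head_dim, but the
# mathematically correct scale still comes from the original, unabsorbed
# qk_nope_head_dim + qk_rope_head_dim.
kv_lora_rank = 512       # width the kernel actually sees after absorption
qk_rope_head_dim = 64
qk_nope_head_dim = 128

wrong = (kv_lora_rank + qk_rope_head_dim) ** -0.5         # inlined scale: incorrect
sm_scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5  # caller must pass this
```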
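Finally, a hedged sketch of the test-side change: computing the reference attention output with PyTorch on GPU instead of NumPy on CPU. The function and shapes below are illustrative, not the actual test code.

```python
# Sketch of a PyTorch-based reference ("std") attention computation.
import torch


def ref_attention(q, k, v, sm_scale):
    # q: [num_heads, qo_len, head_dim]; k, v: [num_heads, kv_len, head_dim]
    logits = torch.einsum("hqd,hkd->hqk", q.float(), k.float()) * sm_scale
    return torch.einsum("hqk,hkd->hqd", torch.softmax(logits, dim=-1), v.float())


# Running the reference on the same GPU as the kernel under test avoids
# slow NumPy loops and host<->device copies, which shortens test time.
q = torch.randn(32, 16, 128, device="cuda", dtype=torch.float16)
k = torch.randn(32, 256, 128, device="cuda", dtype=torch.float16)
v = torch.randn(32, 256, 128, device="cuda", dtype=torch.float16)
out = ref_attention(q, k, v, sm_scale=128 ** -0.5)
```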