Commit 1053886
[KVCache] PagedKVCache refactor, FlashInfer JIT and MLA integration
This PR consists of the following parts:
* We reorganized `paged_kv_cache.cc` by moving some of the utilities
to `attn_utils.h`.
* To integrate with the JIT kernel compilation in the latest FlashInfer
project, while still being able to support attention kernels written
with TIR, we introduced `AttnBackendFunc` in `attn_backend.h`, which
exposes attention interfaces (e.g., `MHA`, `MLA`) to PagedKVCache.
We subclass `AttnBackendFunc` and implement FlashInfer backends and
TIR backends respectively.
* With `AttnBackendFunc`, we refactored the PagedKVCache constructor.
The new constructor is not backward compatible, and will break the
existing compiled model libraries.
* Both the TIR and FlashInfer attention implementations now require
an explicit attention softmax scale factor `sm_scale` to be passed
in. Previously, they inlined an `sm_scale` of `head_dim ** -0.5`.
With recent LLM inference techniques such as MLA weight absorption
in DeepSeek models, the inlined `sm_scale` causes confusion and
inconvenience. To keep the attention interface standard and clear,
`sm_scale` must now be passed explicitly.
* We refactored the existing GPU unit tests of the PagedKVCache,
switching the std calculation from numpy to PyTorch. This
significantly reduces the test case run time.

1 parent 432ccfa commit 1053886
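
To illustrate why an inlined `sm_scale` of `head_dim ** -0.5` can be wrong, here is a minimal pure-Python sketch of single-query attention that takes the scale explicitly. It is not the PagedKVCache, TIR, or FlashInfer code; the function names and the tiny dimensions (`original_head_dim = 4`, `absorbed_dim = 8`) are assumptions chosen for illustration. The point is only that when the dimension the kernel sees differs from the dimension the model's scale is defined on, as can happen with weight absorption, deriving the scale from the kernel-side dimension changes the result.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, keys, values, sm_scale):
    """Single-query attention with an explicitly passed softmax scale."""
    scores = [sm_scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    weights = softmax(scores)
    out_dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(out_dim)]


# Illustrative sizes only (assumptions, not taken from the commit):
original_head_dim = 4  # dimension the model's softmax scale is defined on
absorbed_dim = 8       # dimension the kernel sees after a hypothetical absorption

q = [1.0] * absorbed_dim
keys = [[0.5] * absorbed_dim, [0.1] * absorbed_dim]
values = [[1.0, 0.0], [0.0, 1.0]]

# Scale the model intends, passed in explicitly by the caller.
out_explicit = attention(q, keys, values, sm_scale=original_head_dim ** -0.5)

# Scale a kernel would compute if it inlined head_dim ** -0.5 from its own input.
out_inlined = attention(q, keys, values, sm_scale=len(q) ** -0.5)

print(out_explicit, out_inlined)
```

The two outputs differ, so in a setting like this the caller is the only party that knows the right scale, which is consistent with the commit's choice to make `sm_scale` a required argument.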
File tree
18 files changed: +4181 -2909 lines changed
- docs/how_to/tutorials
- python/tvm/relax
  - backend/cuda
  - frontend/nn/llm
- src/runtime/relax_vm
- tests/python/relax
  - nvshmem