
Conversation

@MasterJH5574 (Contributor)

This PR consists of the following parts:

  • We reorganized `paged_kv_cache.cc` by moving some of its utilities to `attn_utils.h`.

  • To integrate with the JIT kernel compilation in the latest FlashInfer project, while still supporting attention kernels written in TIR, we introduced `AttnBackendFunc` in `attn_backend.h`, which exposes attention interfaces (e.g., `MHA`, `MLA`) to PagedKVCache. We subclass `AttnBackendFunc` to implement the FlashInfer and TIR backends, respectively (a sketch of this abstraction is included after the list).

  • With `AttnBackendFunc`, we refactored the PagedKVCache constructor. The new constructor is not backward compatible and will break existing compiled model libraries.

  • For both the TIR and FlashInfer attention implementations, we now require an explicit attention softmax scale factor `sm_scale` to be passed in. Previously, the kernels inlined an `sm_scale` of `head_dim ** -0.5`. With recent LLM inference techniques such as MLA weight absorption in DeepSeek models, the head dimension the kernel sees no longer matches the dimension the scale should be derived from, so the inlined `sm_scale` causes confusion and inconvenience. To keep the attention interface standard and clear, we now require `sm_scale` to be passed explicitly (a small example after the list illustrates why).

  • We refactored the existing GPU unit tests of the PagedKVCache, switching the std calculation from NumPy to PyTorch. This significantly reduces the test run time.
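
For reference, here is a minimal sketch of the kind of backend abstraction described in the second bullet. Only the name `AttnBackendFunc` comes from this PR; the method signatures, the placeholder `Tensor` type, and the subclass names `FlashInferAttnBackend` / `TIRAttnBackend` are illustrative assumptions, not the actual declarations in `attn_backend.h`.

```cpp
// Placeholder for the runtime tensor type; the real code works on TVM runtime arrays.
struct Tensor;

// Abstract attention backend that PagedKVCache talks to.
class AttnBackendFunc {
 public:
  virtual ~AttnBackendFunc() = default;
  // Multi-head attention over the paged KV cache; sm_scale is passed explicitly.
  virtual void MHA(const Tensor& q, const Tensor& paged_kv, double sm_scale,
                   Tensor* out) = 0;
  // Multi-head latent attention (MLA), e.g. for DeepSeek-style models.
  virtual void MLA(const Tensor& q, const Tensor& compressed_kv, double sm_scale,
                   Tensor* out) = 0;
};

// Backend that dispatches to FlashInfer's JIT-compiled kernels.
class FlashInferAttnBackend : public AttnBackendFunc {
 public:
  void MHA(const Tensor& q, const Tensor& paged_kv, double sm_scale,
           Tensor* out) override { /* invoke FlashInfer JIT kernel */ }
  void MLA(const Tensor& q, const Tensor& compressed_kv, double sm_scale,
           Tensor* out) override { /* invoke FlashInfer JIT kernel */ }
};

// Backend that dispatches to attention kernels written in TIR.
class TIRAttnBackend : public AttnBackendFunc {
 public:
  void MHA(const Tensor& q, const Tensor& paged_kv, double sm_scale,
           Tensor* out) override { /* invoke compiled TIR kernel */ }
  void MLA(const Tensor& q, const Tensor& compressed_kv, double sm_scale,
           Tensor* out) override { /* invoke compiled TIR kernel */ }
};
```

The point of the indirection is that PagedKVCache depends only on the abstract interface, so choosing between FlashInfer's JIT-compiled kernels and TIR kernels becomes a construction-time decision rather than a fork inside the cache code.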

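To see why the explicit `sm_scale` matters, the sketch below contrasts regular MHA with MLA weight absorption. The DeepSeek-style dimensions (query head dims 128 + 64, KV LoRA rank 512) are used purely for illustration, and the code is not taken from this PR.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // Regular MHA: the scale really is head_dim ** -0.5, so a scale inlined from
  // the kernel's own input shape happens to be correct.
  const int head_dim = 128;
  const double mha_sm_scale = 1.0 / std::sqrt(static_cast<double>(head_dim));

  // MLA with weight absorption (DeepSeek-style sizes, for illustration only):
  // the kernel sees queries/keys of width kv_lora_rank + qk_rope_head_dim,
  // but the correct scale still comes from the un-absorbed head dimension
  // qk_nope_head_dim + qk_rope_head_dim.
  const int qk_nope_head_dim = 128;
  const int qk_rope_head_dim = 64;
  const int kv_lora_rank = 512;
  const int kernel_head_dim = kv_lora_rank + qk_rope_head_dim;        // 576: what the kernel sees
  const int original_head_dim = qk_nope_head_dim + qk_rope_head_dim;  // 192: what the scale needs
  const double mla_sm_scale = 1.0 / std::sqrt(static_cast<double>(original_head_dim));

  // An inlined head_dim ** -0.5 would silently use 576 here, which is why the
  // caller must now pass sm_scale explicitly.
  std::printf("MHA sm_scale = %f\n", mha_sm_scale);
  std::printf("MLA sm_scale = %f (kernel head dim = %d)\n", mla_sm_scale, kernel_head_dim);
  return 0;
}
```
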
MasterJH5574 force-pushed the tvm-dev/2025-02-24-attn-jit branch 4 times, most recently from 1053886 to 8bf03d0 on February 26, 2025 at 14:54.
MasterJH5574 force-pushed the tvm-dev/2025-02-24-attn-jit branch from 8bf03d0 to d7fb5e4 on February 26, 2025 at 16:44.
Hzfengsy merged commit 61f6e7f into apache:main on Feb 27, 2025.
15 checks passed
ShiboXing pushed a commit (…pache#17674) to ShiboXing/tvm that referenced this pull request on Aug 10, 2025.
