Commit 1053886
[KVCache] PagedKVCache refactor, FlashInfer JIT and MLA integration
This PR consists of the following parts:

* We reorganized `paged_kv_cache.cc` by moving some of its utilities into `attn_utils.h`.
* To integrate with the JIT kernel compilation in the latest FlashInfer project, while still supporting attention kernels written in TIR, we introduced `AttnBackendFunc` in `attn_backend.h`, which exposes attention interfaces (e.g., `MHA`, `MLA`) to PagedKVCache. We subclass `AttnBackendFunc` to implement the FlashInfer and TIR backends respectively.
* Building on `AttnBackendFunc`, we refactored the PagedKVCache constructor. The new constructor is not backward compatible and will break existing compiled model libraries.
* Both the TIR and FlashInfer attention implementations now require an explicit attention softmax scale factor `sm_scale` to be passed in. Previously, `sm_scale` was inlined as `head_dim ** -0.5`. With recent LLM inference techniques such as MLA weight absorption in DeepSeek models, the inlined `sm_scale` causes confusion and inconvenience, so we now require `sm_scale` to be passed explicitly to keep the attention interface standard and clear. A small sketch of the distinction follows this list.
* We refactored the existing GPU unit tests of the PagedKVCache, switching the standard (reference) result computation from NumPy to PyTorch. This significantly reduces the test run time.
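To make the `sm_scale` point concrete, here is a minimal, self-contained sketch (not taken from this PR) of why the scale can no longer be inferred from the cached head dimension alone. The MLA dimensions below are illustrative placeholders only.

```python
import math

# Standard MHA: the softmax scale is derived directly from the per-head
# dimension of the queries/keys fed to attention.
head_dim = 128
sm_scale_mha = head_dim ** -0.5  # == 1 / sqrt(128)

# MLA with weight absorption (illustrative numbers only): the kernel operates
# on queries/keys in a compressed latent space, while the scale should come
# from the original (pre-absorption) head dimension. An inlined
# `head_dim ** -0.5` inside the KV cache would silently use the wrong
# dimension, which is why the scale must now be passed explicitly.
compressed_qk_head_dim = 576   # dimension the kernel sees (example)
original_qk_head_dim = 192     # dimension the scale is derived from (example)
sm_scale_mla = original_qk_head_dim ** -0.5

print(f"MHA sm_scale: {sm_scale_mha:.6f}")
print(f"MLA (weight-absorbed) sm_scale: {sm_scale_mla:.6f}")
```

This mirrors the tutorial change below, where the MHA call site now passes `sm_scale=self.head_dim**-0.5` explicitly instead of relying on a scale inlined inside the KV cache.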
1 parent 432ccfa commit 1053886

18 files changed: +4181 / -2909 lines

docs/how_to/tutorials/optimize_llm.py

Lines changed: 8 additions & 2 deletions
@@ -191,7 +191,9 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
         qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))
         # Attention
         output = op.reshape(
-            paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads),
+            paged_kv_cache.attention_with_fused_qkv(
+                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5
+            ),
             (b, s, h_q * d),
         )
         # Output Projection
@@ -285,6 +287,7 @@ def create_tir_paged_kv_cache(
         page_size: tir.Var,
     ) -> PagedKVCache:
         return TIRPagedKVCache(
+            attn_kind="mha",
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,
@@ -294,7 +297,10 @@
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
-           head_dim=self.head_dim,
+           qk_head_dim=self.head_dim,
+           v_head_dim=self.head_dim,
+           mla_original_qk_head_dim=0,
+           mla_original_v_head_dim=0,
            rope_mode=RopeMode.NORMAL,
            rope_scale=1,
            rope_theta=self.rope_theta,

python/tvm/relax/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@
 )
 
 # pipeline
+from .pipeline import get_default_pipeline
 from .pipeline import get_pipeline
 from .pipeline import register_pipeline

python/tvm/relax/backend/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """The Relax CUDA backend compilation pipeline and other passes."""
+from . import flashinfer
 from .pipeline import (
     finalize_passes,
     get_default_pipeline,
