feat: add xqa backend and complete NHD/HND coverage for trtllm-gen/xqa backend #2001
@@ -21,6 +21,7 @@
import torch

from .xqa import xqa
from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
from .jit import (
    gen_batch_decode_mla_module,

@@ -2253,6 +2254,133 @@ def trtllm_batch_decode_with_kv_cache(
    )

# xqa uses NHD layout
def xqa_batch_decode_with_kv_cache(
    query: torch.Tensor,
    kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    workspace_buffer: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    max_seq_len: int,
    bmm1_scale: float,
    bmm2_scale: float,
    window_left: int = -1,
    out: Optional[torch.Tensor] = None,
    sinks: Optional[torch.Tensor] = None,
    enable_pdl: bool = None,
    q_len_per_req: Optional[int] = 1,
) -> torch.Tensor:
| """ | ||
| Parameters | ||
| ---------- | ||
| query : torch.Tensor | ||
| query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request | ||
|
|
||
| kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] | ||
| If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, page_size, num_kv_heads, head_dim] | ||
| If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, page_size, num_kv_heads, head_dim] | ||
|
|
||
| workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use. | ||
| workspace | ||
|
|
||
| block_tables : torch.Tensor | ||
| page_table of kv cache, [batch_size, num_pages] | ||
|
|
||
| seq_lens : torch.Tensor | ||
| A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]`` | ||
|
|
||
| max_seq_len : int | ||
| max sequence length for kv_cache | ||
|
|
||
| bmm1_scale : float | ||
| fused scale for bmm1 input. | ||
|
|
||
| bmm2_scale : float | ||
| fused scale for bmm2 input. | ||
|
|
||
| window_left : int = -1 | ||
| The left (inclusive) window size for the attention window, when set to ``-1``, the window | ||
| size will be set to the full length of the sequence. Defaults to ``-1``. | ||
|
|
||
| out : Optional[torch.Tensor] = None | ||
| output tensor, if not provided, will be allocated with ``query.dtype``. | ||
|
|
||
| sinks : Optional[torch.Tensor] = None | ||
| additional value per head in the denominator of the softmax. | ||
|
|
||
| enable_pdl : bool | ||
| Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization | ||
| Only supported for >= sm90, and currently only for FA2, CUDA core, and trtllm-gen decode. | ||
|
|
||
| Returns | ||
| ------- | ||
| out : torch.Tensor | ||
| output torch.Tensor. | ||
| """ | ||
    enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

    assert q_len_per_req == 1, "xqa does not support speculative decoding yet"

    if isinstance(kv_cache, tuple):
        k_cache, v_cache = kv_cache
    else:
        if kv_cache.shape[1] == 1:
            k_cache, v_cache = kv_cache, kv_cache
        else:
            assert kv_cache.shape[1] == 2, (
                "When kv_cache is a single tensor, the second dimension must be 1 or 2"
            )
            # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
Comment on lines 2400 to 2409 (Contributor)

Squeeze the singleton KV axis before inferring shapes. When kv_cache arrives as a single 5-D tensor with a singleton second dimension, aliasing it unchanged leaves k_cache/v_cache 5-D, so the page_size, num_kv_heads and head_dim lookups further down read the wrong axes. Squeeze the singleton axis first:

     if isinstance(kv_cache, tuple):
         k_cache, v_cache = kv_cache
     else:
         if kv_cache.shape[1] == 1:
-            k_cache, v_cache = kv_cache, kv_cache
+            k_cache = kv_cache.squeeze(1)
+            v_cache = k_cache
         else:
             assert kv_cache.shape[1] == 2, (
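To see the reviewer's point concretely, here is a small self-contained sketch (not part of the PR; the shapes are made up) showing how aliasing the 5-D tensor skews the shape lookups performed later in this function, and how squeezing the singleton axis fixes them:

    import torch

    num_pages, page_size, num_kv_heads, head_dim = 16, 32, 4, 128
    kv_cache = torch.zeros(num_pages, 1, page_size, num_kv_heads, head_dim)

    # Aliasing the 5-D tensor: the later shape[1]/shape[2]/shape[3] lookups pick up
    # the singleton axis, so page_size, num_kv_heads and head_dim all come out wrong.
    k_cache = kv_cache
    print(k_cache.shape[1], k_cache.shape[2], k_cache.shape[3])  # 1 32 4

    # Squeezing the singleton axis restores the expected 4-D NHD layout.
    k_cache = kv_cache.squeeze(1)
    print(k_cache.shape[1], k_cache.shape[2], k_cache.shape[3])  # 32 4 128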
            # it doesn't change underlying storage
            k_cache, v_cache = kv_cache.unbind(dim=1)

    sm_count = get_device_sm_count(query.device)

    bmm1_scale = (
        bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale
    )
    bmm2_scale = (
        bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale
    )

    num_kv_heads = k_cache.shape[2]
    page_size = k_cache.shape[1]
    head_dim = k_cache.shape[3]
    workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
    kv_scale_value = bmm2_scale
    q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
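    # NOTE (added for clarity, not in the PR diff): workspace_0/workspace_1 are simply
    # the two halves of the caller-provided workspace buffer. The fused scales are
    # unpacked because xqa takes separate q_scale/kv_scale arguments: kv_scale comes
    # from bmm2_scale, and q_scale is chosen so that
    # q_scale * kv_scale / sqrt(head_dim) reproduces bmm1_scale, assuming the kernel
    # applies the 1/sqrt(head_dim) softmax factor internally.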
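
    # NOTE (added for clarity, not in the PR diff): the reshapes below flatten the
    # NHD pages into 2-D [num_pages * page_size * num_kv_heads, head_dim] views and
    # add a singleton axis to query and seq_lens, which appears to be the layout
    # the xqa entry point expects.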
    k_cache_new = k_cache.reshape(-1, head_dim).contiguous()
    v_cache_new = v_cache.reshape(-1, head_dim).contiguous()
    query_new = query.unsqueeze(1).contiguous()
    seq_lens_new = seq_lens.unsqueeze(1).contiguous()
    sinks_new = (
        sinks.reshape(num_kv_heads, -1).contiguous() if sinks is not None else None
    )

    xqa(
        query_new,
        k_cache_new,
        v_cache_new,
        block_tables,
        seq_lens_new,
        out,
        workspace_0,
        workspace_1,
        num_kv_heads,
        page_size,
        sinks=sinks_new,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )

    return out
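A minimal usage sketch for the new wrapper, assuming the import path flashinfer.decode, fp16 inputs, int32 block tables, and a 128 MB workspace; none of these specifics are mandated by the PR, so check the xqa kernel's actual dtype and size constraints before relying on them. The workspace is zero-initialized as the docstring requires, and out is allocated explicitly because the wrapper forwards it straight to the kernel:

    import torch
    import flashinfer

    batch_size, num_qo_heads, num_kv_heads, head_dim = 4, 32, 8, 128
    page_size, pages_per_seq = 16, 8
    num_pages = batch_size * pages_per_seq
    device = "cuda"

    # One decode token per request (q_len_per_req == 1).
    query = torch.randn(
        batch_size, num_qo_heads, head_dim, dtype=torch.float16, device=device
    )

    # Paged KV cache in NHD layout: [num_pages, page_size, num_kv_heads, head_dim].
    k_cache = torch.randn(
        num_pages, page_size, num_kv_heads, head_dim, dtype=torch.float16, device=device
    )
    v_cache = torch.randn_like(k_cache)

    block_tables = torch.arange(num_pages, dtype=torch.int32, device=device).reshape(
        batch_size, pages_per_seq
    )
    seq_lens = torch.full(
        (batch_size,), page_size * pages_per_seq, dtype=torch.uint32, device=device
    )

    # Workspace must be zero-initialized before its first use.
    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    out = torch.empty_like(query)

    flashinfer.decode.xqa_batch_decode_with_kv_cache(
        query,
        (k_cache, v_cache),
        workspace_buffer,
        block_tables,
        seq_lens,
        max_seq_len=page_size * pages_per_seq,
        bmm1_scale=1.0 / head_dim**0.5,
        bmm2_scale=1.0,
        out=out,
    )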

def _check_trtllm_gen_mla_shape(
    query,
    kv_cache,

@@ -2410,6 +2538,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
        workspace_buffer.numel() * workspace_buffer.element_size(),
        sinks,
    )

    return out

Out-of-bounds read in loadGmemColWiseVecWithDup for attention sinks

gmemVec points to a buffer of size headGrpSize (see finalizeAndWriteOut_sync passing attentionSinksVec[0]), but this code multiplies the index by GmmaAccCoreMat::cols and reads baseOffset + j, which can exceed headGrpSize. We should load a single sink value per head and duplicate it across columns, without advancing memory by cols.

Apply this fix:

Committable suggestion
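The committable suggestion itself is not reproduced above, so here is a small illustrative model of the addressing problem the reviewer describes, written in Python with made-up names (head_grp_size and cols stand in for headGrpSize and GmmaAccCoreMat::cols; this is not the kernel code): the buggy path strides the sink buffer by the column count, while the fix reads one value per head and duplicates it across columns.

    # Buffer of headGrpSize sink values, one per head in the group.
    head_grp_size, cols = 8, 4
    sinks = [float(i) for i in range(head_grp_size)]

    # Buggy addressing: the base offset advances by `cols` per row, so the highest
    # index is (head_grp_size - 1) * cols + (cols - 1), far past the buffer end.
    def load_with_dup_buggy(row: int, col: int) -> float:
        return sinks[row * cols + col]  # fails once row * cols + col >= len(sinks)

    # Fixed addressing: one sink value per head, duplicated across all columns.
    def load_with_dup_fixed(row: int, col: int) -> float:
        return sinks[row]

    print(load_with_dup_fixed(7, 3))  # fine: 7.0
    try:
        load_with_dup_buggy(7, 3)
    except IndexError as exc:
        print("buggy load:", exc)  # index 31 is out of range for a buffer of 8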