-
Notifications
You must be signed in to change notification settings - Fork 580
feat: add xqa backend and completes NHD/HND coverage for trtllm-gen/xqa backend #2001
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
289d526
a9f8bc8
c2a0cad
21de9af
595ee1b
9c08d33
81a1afc
5186e5d
f4e1073
08d088a
5c6b9d9
869c0c1
5dc1a28
4950b67
e535e80
39e36dc
e7cca24
ed46ea9
8abb7ca
e040826
43bf624
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -21,6 +21,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .xqa import xqa, xqa_mla | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .jit import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gen_batch_decode_mla_module, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2222,29 +2223,64 @@ def trtllm_batch_decode_with_kv_cache( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_scale_factor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| query.view( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| query.size(0) // q_len_per_req, q_len_per_req, query.size(1), query.size(2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k_cache, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| v_cache, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace_buffer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_tables, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seq_lens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_seq_len, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bmm1_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bmm2_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o_sf_scale or -1.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o_sf_vec_size or -1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o_sf_start_index, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| window_left, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sm_count, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enable_pdl, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace_buffer.numel() * workspace_buffer.element_size(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sinks, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # To decide if using xqa to decode | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_compute_capability(torch.device(device="cuda"))[0] in [9, 12] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_compute_capability(torch.device(device="cuda"))[0] in [9, 12] | |
| get_compute_capability(torch.device(device="cuda"))[0] in [9, 10, 11, 12] |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check for compute capability is missing support for SM100 (compute capability 10). The xqa implementation in flashinfer/xqa.py indicates support for SM90, SM100, and SM120, which correspond to compute capabilities 9, 10, and 12. This condition should be updated to include 10 to enable the XQA path on SM100 GPUs.
| if ( | |
| get_compute_capability(torch.device(device="cuda"))[0] in [9, 12] | |
| and out_dtype != "nvfp4" | |
| and query.dtype in [torch.float16, torch.bfloat16] | |
| ): | |
| if ( | |
| get_compute_capability(torch.device(device="cuda"))[0] in [9, 10, 12] | |
| and out_dtype != "nvfp4" | |
| and query.dtype in [torch.float16, torch.bfloat16] | |
| ): |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The xqa function expects the semaphores argument to be a torch.uint32 tensor, but workspace_1 has a torch.uint8 dtype because it's a chunk of workspace_buffer. This type mismatch will cause issues in the xqa kernel. You should reinterpret the tensor view to the correct dtype before passing it.
| xqa( | |
| query, | |
| k_cache.reshape(-1, head_dim), | |
| v_cache.reshape(-1, head_dim), | |
| block_tables, | |
| seq_lens, | |
| out, | |
| workspace_0, | |
| workspace_1, | |
| num_kv_heads, | |
| page_size, | |
| sinks=sinks, | |
| q_scale=q_scale_value, | |
| kv_scale=torch.tensor( | |
| [kv_scale_value], dtype=torch.float32, device=query.device | |
| ), | |
| sliding_win_size=window_left + 1 if window_left >= 0 else 0, | |
| sm_count=sm_count, | |
| ) | |
| xqa( | |
| query, | |
| k_cache.reshape(-1, head_dim), | |
| v_cache.reshape(-1, head_dim), | |
| block_tables, | |
| seq_lens, | |
| out, | |
| workspace_0, | |
| workspace_1.view(torch.uint32), | |
| num_kv_heads, | |
| page_size, | |
| sinks=sinks, | |
| q_scale=q_scale_value, | |
| kv_scale=torch.tensor( | |
| [kv_scale_value], dtype=torch.float32, device=query.device | |
| ), | |
| sliding_win_size=window_left + 1 if window_left >= 0 else 0, | |
| sm_count=sm_count, | |
| ) |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The semaphores parameter of the xqa function expects a torch.uint32 tensor, but it's being passed workspace_1, which is a chunk of a torch.uint8 tensor. This type mismatch will likely cause a runtime error or incorrect behavior. You should view the tensor as torch.uint32 before passing it to the function.
| xqa( | |
| query, | |
| k_cache.reshape(-1, head_dim), | |
| v_cache.reshape(-1, head_dim), | |
| block_tables, | |
| seq_lens, | |
| out, | |
| workspace_0, | |
| workspace_1, | |
| num_kv_heads, | |
| page_size, | |
| sinks=sinks, | |
| q_scale=q_scale_value, | |
| kv_scale=torch.tensor( | |
| [kv_scale_value], dtype=torch.float32, device=query.device | |
| ), | |
| sliding_win_size=window_left + 1 if window_left >= 0 else 0, | |
| sm_count=sm_count, | |
| ) | |
| xqa( | |
| query, | |
| k_cache.reshape(-1, head_dim), | |
| v_cache.reshape(-1, head_dim), | |
| block_tables, | |
| seq_lens, | |
| out, | |
| workspace_0, | |
| workspace_1.view(torch.uint32), | |
| num_kv_heads, | |
| page_size, | |
| sinks=sinks, | |
| q_scale=q_scale_value, | |
| kv_scale=torch.tensor( | |
| [kv_scale_value], dtype=torch.float32, device=query.device | |
| ), | |
| sliding_win_size=window_left + 1 if window_left >= 0 else 0, | |
| sm_count=sm_count, | |
| ) |
coderabbitai[bot] marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The scale calculation for q_scale_value uses a hardcoded value (128 + 64). This corresponds to qk_nope_head_dim + qk_rope_head_dim. Using the function arguments qk_nope_head_dim and qk_rope_head_dim will improve code readability and maintainability.
| if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None: | |
| kv_scale_value = bmm2_scale_tensor.item() | |
| q_scale_value = ( | |
| bmm1_scale_log2_tensor.item() | |
| / kv_scale_value | |
| * ((128 + 64) ** 0.5) | |
| * math.log2(math.e) | |
| ) | |
| else: | |
| kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0 | |
| q_scale_value = ( | |
| bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5) | |
| if bmm1_scale is not None | |
| else 1.0 | |
| ) | |
| if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None: | |
| kv_scale_value = bmm2_scale_tensor.item() | |
| q_scale_value = ( | |
| bmm1_scale_log2_tensor.item() | |
| / kv_scale_value | |
| * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5) | |
| * math.log2(math.e) | |
| ) | |
| else: | |
| kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0 | |
| q_scale_value = ( | |
| bmm1_scale / kv_scale_value * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5) | |
| if bmm1_scale is not None | |
| else 1.0 | |
| ) |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The head dimension is hardcoded as (128 + 64). This makes the code less maintainable. You should use the function parameters qk_nope_head_dim and qk_rope_head_dim instead, which are available in this function's scope.
| if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None: | |
| kv_scale_value = bmm2_scale_tensor.item() | |
| q_scale_value = ( | |
| bmm1_scale_log2_tensor.item() | |
| / kv_scale_value | |
| * ((128 + 64) ** 0.5) | |
| * math.log2(math.e) | |
| ) | |
| else: | |
| kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0 | |
| q_scale_value = ( | |
| bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5) | |
| if bmm1_scale is not None | |
| else 1.0 | |
| ) | |
| if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None: | |
| kv_scale_value = bmm2_scale_tensor.item() | |
| q_scale_value = ( | |
| bmm1_scale_log2_tensor.item() | |
| / kv_scale_value | |
| * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5) | |
| * math.log2(math.e) | |
| ) | |
| else: | |
| kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0 | |
| q_scale_value = ( | |
| bmm1_scale / kv_scale_value * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5) | |
| if bmm1_scale is not None | |
| else 1.0 | |
| ) |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The xqa_mla function expects the semaphores argument to be a torch.uint32 tensor, but workspace_1 has a torch.uint8 dtype because it's a chunk of workspace_buffer. This type mismatch will cause issues in the xqa_mla kernel. You should reinterpret the tensor view to the correct dtype before passing it.
| xqa_mla( | |
| query, | |
| kv_cache, | |
| kv_cache, | |
| block_tables, | |
| seq_lens, | |
| out, | |
| workspace_0, | |
| workspace_1, | |
| block_size, | |
| q_scale=q_scale_value, | |
| kv_scale=torch.tensor( | |
| [kv_scale_value], dtype=torch.float32, device=query.device | |
| ), | |
| sm_count=sm_count, | |
| ) | |
| xqa_mla( | |
| query, | |
| kv_cache, | |
| kv_cache, | |
| block_tables, | |
| seq_lens, | |
| out, | |
| workspace_0, | |
| workspace_1.view(torch.uint32), | |
| block_size, | |
| q_scale=q_scale_value, | |
| kv_scale=torch.tensor( | |
| [kv_scale_value], dtype=torch.float32, device=query.device | |
| ), | |
| sm_count=sm_count, | |
| ) |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The semaphores parameter of the xqa_mla function expects a torch.uint32 tensor, but it's being passed workspace_1, which is a chunk of a torch.uint8 tensor. This type mismatch will likely cause a runtime error or incorrect behavior. You should view the tensor as torch.uint32 before passing it to the function.
| xqa_mla( | |
| query, | |
| kv_cache, | |
| kv_cache, | |
| block_tables, | |
| seq_lens, | |
| out, | |
| workspace_0, | |
| workspace_1, | |
| block_size, | |
| q_scale=q_scale_value, | |
| kv_scale=torch.tensor( | |
| [kv_scale_value], dtype=torch.float32, device=query.device | |
| ), | |
| sm_count=sm_count, | |
| ) | |
| xqa_mla( | |
| query, | |
| kv_cache, | |
| kv_cache, | |
| block_tables, | |
| seq_lens, | |
| out, | |
| workspace_0, | |
| workspace_1.view(torch.uint32), | |
| block_size, | |
| q_scale=q_scale_value, | |
| kv_scale=torch.tensor( | |
| [kv_scale_value], dtype=torch.float32, device=query.device | |
| ), | |
| sm_count=sm_count, | |
| ) |
Uh oh!
There was an error while loading. Please reload this page.