Merged
165 changes: 121 additions & 44 deletions flashinfer/decode.py
@@ -21,6 +21,7 @@

import torch

from .xqa import xqa, xqa_mla
from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
from .jit import (
gen_batch_decode_mla_module,
@@ -2222,29 +2223,64 @@ def trtllm_batch_decode_with_kv_cache(
bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale
)

run_func(
out,
out_scale_factor,
query.view(
query.size(0) // q_len_per_req, q_len_per_req, query.size(1), query.size(2)
),
k_cache,
v_cache,
workspace_buffer,
block_tables,
seq_lens,
max_seq_len,
bmm1_scale,
bmm2_scale,
o_sf_scale or -1.0,
o_sf_vec_size or -1,
o_sf_start_index,
window_left,
sm_count,
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
)
# Decide whether to use the xqa backend for decoding
if (
get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
Contributor (high):

The compute capability check for using the xqa backend is missing support for SM100 (compute capability 10) and SM110 (compute capability 11). The xqa kernel supports SM90, SM100, SM110, and SM120. The condition should be updated to include 10 and 11 to enable the backend on all supported architectures.

Suggested change:
-        get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
+        get_compute_capability(torch.device(device="cuda"))[0] in [9, 10, 11, 12]

and out_dtype != "nvfp4"
and query.dtype in [torch.float16, torch.bfloat16]
):
Contributor (high):

The check for compute capability is missing support for SM100 (compute capability 10). The xqa implementation in flashinfer/xqa.py indicates support for SM90, SM100, and SM120, which correspond to compute capabilities 9, 10, and 12. This condition should be updated to include 10 to enable the XQA path on SM100 GPUs.

Suggested change:
     if (
-        get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
+        get_compute_capability(torch.device(device="cuda"))[0] in [9, 10, 12]
         and out_dtype != "nvfp4"
         and query.dtype in [torch.float16, torch.bfloat16]
     ):
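For orientation, the sketch below (not part of the PR) shows how the compute-capability tuple that this condition keys on can be queried with torch.cuda.get_device_capability; the [9, 10, 12] list mirrors one of the suggestions above and is an assumption, not a confirmed support matrix.

import torch

# Minimal sketch: query the (major, minor) compute capability the dispatch keys on.
# SM90 reports (9, 0), SM100 reports (10, 0), SM120 reports (12, 0).
if torch.cuda.is_available():
    major, _minor = torch.cuda.get_device_capability(torch.device("cuda"))
    # Mirrors the reviewer-suggested list; adjust to the architectures actually supported.
    would_take_xqa_path = major in [9, 10, 12]
    print(f"compute capability major={major}, xqa-eligible={would_take_xqa_path}")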

num_kv_heads = k_cache.shape[1]
page_size = k_cache.shape[2]
head_dim = k_cache.shape[3]
workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
kv_scale_value = bmm2_scale
q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
xqa(
query,
k_cache.reshape(-1, head_dim),
v_cache.reshape(-1, head_dim),
block_tables,
seq_lens,
out,
workspace_0,
workspace_1,
num_kv_heads,
page_size,
sinks=sinks,
q_scale=q_scale_value,
kv_scale=torch.tensor(
[kv_scale_value], dtype=torch.float32, device=query.device
),
sliding_win_size=window_left + 1 if window_left >= 0 else 0,
sm_count=sm_count,
)
Contributor (critical):

The xqa function expects the semaphores argument to be a torch.uint32 tensor, but workspace_1 has a torch.uint8 dtype because it's a chunk of workspace_buffer. This type mismatch will cause issues in the xqa kernel. You should reinterpret the tensor view to the correct dtype before passing it.

Suggested change:
         xqa(
             query,
             k_cache.reshape(-1, head_dim),
             v_cache.reshape(-1, head_dim),
             block_tables,
             seq_lens,
             out,
             workspace_0,
-            workspace_1,
+            workspace_1.view(torch.uint32),
             num_kv_heads,
             page_size,
             sinks=sinks,
             q_scale=q_scale_value,
             kv_scale=torch.tensor(
                 [kv_scale_value], dtype=torch.float32, device=query.device
             ),
             sliding_win_size=window_left + 1 if window_left >= 0 else 0,
             sm_count=sm_count,
         )

Contributor (critical):

The semaphores parameter of the xqa function expects a torch.uint32 tensor, but it's being passed workspace_1, which is a chunk of a torch.uint8 tensor. This type mismatch will likely cause a runtime error or incorrect behavior. You should view the tensor as torch.uint32 before passing it to the function.

Suggested change:
         xqa(
             query,
             k_cache.reshape(-1, head_dim),
             v_cache.reshape(-1, head_dim),
             block_tables,
             seq_lens,
             out,
             workspace_0,
-            workspace_1,
+            workspace_1.view(torch.uint32),
             num_kv_heads,
             page_size,
             sinks=sinks,
             q_scale=q_scale_value,
             kv_scale=torch.tensor(
                 [kv_scale_value], dtype=torch.float32, device=query.device
             ),
             sliding_win_size=window_left + 1 if window_left >= 0 else 0,
             sm_count=sm_count,
         )
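To make the requested reinterpretation concrete, here is a small self-contained sketch (CPU tensors rather than the PR's GPU buffers, and assuming a PyTorch build that exposes torch.uint32, i.e. 2.3 or newer) of viewing a uint8 workspace chunk as uint32 without copying.

import torch

# Stand-in for the byte-typed workspace buffer; in the PR it lives on the GPU.
workspace_buffer = torch.zeros(1024, dtype=torch.uint8)
workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)

# Tensor.view(dtype) reinterprets the existing bytes in place; the last dimension
# (and the storage offset) must be divisible by the element-size ratio, 4 here.
semaphores = workspace_1.view(torch.uint32)
assert semaphores.numel() == workspace_1.numel() // 4
assert semaphores.data_ptr() == workspace_1.data_ptr()  # same storage, zero copy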

else:
run_func(
out,
out_scale_factor,
query.view(
query.size(0) // q_len_per_req,
q_len_per_req,
query.size(1),
query.size(2),
),
k_cache,
v_cache,
workspace_buffer,
block_tables,
seq_lens,
max_seq_len,
bmm1_scale,
bmm2_scale,
o_sf_scale or -1.0,
o_sf_vec_size or -1,
o_sf_start_index,
window_left,
sm_count,
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
)

return (
out
@@ -2389,27 +2425,68 @@ def trtllm_batch_decode_with_kv_cache_mla(
"Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation"
)

run_func(
out,
None, # fp4 output not supported in wrapper api yet.
query,
kv_cache,
kv_cache,
workspace_buffer,
block_tables,
seq_lens,
max_seq_len,
bmm1_scale,
bmm2_scale,
-1, # o_sf_scale
-1, # o_sf_vec_size
0, # o_sf_start_index
-1, # window_left
sm_count,
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
)
# Decide whether to use the xqa_mla backend for decoding
if (
get_compute_capability(torch.device(device="cuda"))[0] == 12
and query.dtype == torch.float8_e4m3fn
and kv_cache.dtype == torch.float8_e4m3fn
and sinks is None
):
workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
kv_scale_value = bmm2_scale_tensor.item()
q_scale_value = (
bmm1_scale_log2_tensor.item()
/ kv_scale_value
* ((128 + 64) ** 0.5)
* math.log2(math.e)
)
else:
kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
q_scale_value = (
bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5)
if bmm1_scale is not None
else 1.0
)
Contributor (medium):

The scale calculation for q_scale_value uses a hardcoded value (128 + 64). This corresponds to qk_nope_head_dim + qk_rope_head_dim. Using the function arguments qk_nope_head_dim and qk_rope_head_dim will improve code readability and maintainability.

Suggested change:
         if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
             kv_scale_value = bmm2_scale_tensor.item()
             q_scale_value = (
                 bmm1_scale_log2_tensor.item()
                 / kv_scale_value
-                * ((128 + 64) ** 0.5)
+                * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
                 * math.log2(math.e)
             )
         else:
             kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
             q_scale_value = (
-                bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5)
+                bmm1_scale / kv_scale_value * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
                 if bmm1_scale is not None
                 else 1.0
             )

Contributor (medium):

The head dimension is hardcoded as (128 + 64). This makes the code less maintainable. You should use the function parameters qk_nope_head_dim and qk_rope_head_dim instead, which are available in this function's scope.

Suggested change:
         if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
             kv_scale_value = bmm2_scale_tensor.item()
             q_scale_value = (
                 bmm1_scale_log2_tensor.item()
                 / kv_scale_value
-                * ((128 + 64) ** 0.5)
+                * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
                 * math.log2(math.e)
             )
         else:
             kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
             q_scale_value = (
-                bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5)
+                bmm1_scale / kv_scale_value * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
                 if bmm1_scale is not None
                 else 1.0
             )
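As a sanity check on the arithmetic both comments refer to, the sketch below spells out the q_scale computation with the head dimensions named; the concrete values (128, 64, a 1/sqrt(d) softmax scale, a unit KV scale) are illustrative assumptions rather than values taken from the PR.

import math

qk_nope_head_dim = 128  # illustrative, matches the hardcoded 128
qk_rope_head_dim = 64   # illustrative, matches the hardcoded 64
head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192

bmm1_scale = 1.0 / head_dim**0.5  # example softmax scale
kv_scale_value = 1.0              # example KV dequantization scale

# Static-scale path: fold the softmax scale and sqrt(head_dim) into q_scale.
q_scale_value = bmm1_scale / kv_scale_value * head_dim**0.5
assert abs(q_scale_value - 1.0) < 1e-6  # with these example inputs it collapses to 1.0

# The tensor path in the diff applies an additional math.log2(math.e) factor
# on top of the same expression.
q_scale_log2_path = q_scale_value * math.log2(math.e)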


xqa_mla(
query,
kv_cache,
kv_cache,
block_tables,
seq_lens,
out,
workspace_0,
workspace_1,
block_size,
q_scale=q_scale_value,
kv_scale=torch.tensor(
[kv_scale_value], dtype=torch.float32, device=query.device
),
sm_count=sm_count,
)
Contributor (critical):

The xqa_mla function expects the semaphores argument to be a torch.uint32 tensor, but workspace_1 has a torch.uint8 dtype because it's a chunk of workspace_buffer. This type mismatch will cause issues in the xqa_mla kernel. You should reinterpret the tensor view to the correct dtype before passing it.

Suggested change:
         xqa_mla(
             query,
             kv_cache,
             kv_cache,
             block_tables,
             seq_lens,
             out,
             workspace_0,
-            workspace_1,
+            workspace_1.view(torch.uint32),
             block_size,
             q_scale=q_scale_value,
             kv_scale=torch.tensor(
                 [kv_scale_value], dtype=torch.float32, device=query.device
             ),
             sm_count=sm_count,
         )

Contributor (critical):

The semaphores parameter of the xqa_mla function expects a torch.uint32 tensor, but it's being passed workspace_1, which is a chunk of a torch.uint8 tensor. This type mismatch will likely cause a runtime error or incorrect behavior. You should view the tensor as torch.uint32 before passing it to the function.

Suggested change:
         xqa_mla(
             query,
             kv_cache,
             kv_cache,
             block_tables,
             seq_lens,
             out,
             workspace_0,
-            workspace_1,
+            workspace_1.view(torch.uint32),
             block_size,
             q_scale=q_scale_value,
             kv_scale=torch.tensor(
                 [kv_scale_value], dtype=torch.float32, device=query.device
             ),
             sm_count=sm_count,
         )

else:
run_func(
out,
None, # fp4 output not supported in wrapper api yet.
query,
kv_cache,
kv_cache,
workspace_buffer,
block_tables,
seq_lens,
max_seq_len,
bmm1_scale,
bmm2_scale,
-1, # o_sf_scale
-1, # o_sf_vec_size
0, # o_sf_start_index
-1, # window_left
sm_count,
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
)
return out

