Commit 018b551
feat: Add fp8-qkv, fp16/bf16 output MHA (#1540)
## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
1 parent 27f9c8f commit 018b551
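
At a glance, this change lets the TRT-LLM attention paths take fp8 (`torch.float8_e4m3fn`) query/key/value while writing fp16 or bf16 output: `out_dtype` now defaults to `query.dtype` and is rejected unless it is the query dtype, `torch.float16`, or `torch.bfloat16`. A minimal standalone sketch of that guard (not the library code itself; `validate_out_dtype` is a hypothetical helper used only for illustration):

```python
from typing import Optional

import torch


def validate_out_dtype(
    query_dtype: torch.dtype, out_dtype: Optional[torch.dtype]
) -> torch.dtype:
    # Mirrors the guard added to trtllm_batch_decode_with_kv_cache and
    # trtllm_batch_context_with_kv_cache: default to the query dtype and only
    # accept fp16/bf16 as alternative output dtypes.
    out_dtype = out_dtype or query_dtype
    if out_dtype not in (query_dtype, torch.float16, torch.bfloat16):
        raise ValueError(f"Unsupported out_dtype: {out_dtype}")
    return out_dtype


# fp8 q/k/v with bf16 output is now an accepted combination:
assert validate_out_dtype(torch.float8_e4m3fn, torch.bfloat16) == torch.bfloat16
```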

7 files changed: +53 −41 lines

flashinfer/decode.py

Lines changed: 11 additions & 9 deletions
@@ -48,7 +48,7 @@
     _check_cached_qkv_data_type,
     _check_kv_layout,
     _check_pos_encoding_mode,
-    _check_shape_dtype_device,
+    check_shape_dtype_device,
     _get_cache_alibi_slopes_buf,
     _get_cache_buf,
     _get_range_buf,
@@ -1229,14 +1229,14 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
 
         if out is None:
             out = torch.empty_like(q)
         else:
-            _check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+            check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
 
         if self.use_tensor_cores:
             run_args = [
@@ -1747,7 +1747,7 @@ def run(
         if out is None:
             out = torch.empty_like(q_nope, device=device)
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
             )
 
@@ -1759,7 +1759,7 @@
                 device=device,
             )
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 lse,
                 (q_nope.size(0), q_nope.size(1)),
                 q_nope.dtype,
@@ -2107,9 +2107,9 @@ def trtllm_batch_decode_with_kv_cache(
         assert isinstance(out, torch.Tensor)
 
         # Use uint8 as the container dtype to compliant with next fp4 gemm.
-        _check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
+        check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
 
-        _check_shape_dtype_device(
+        check_shape_dtype_device(
             out_scale_factor,
             fp4_out_scale_shape,
             torch.float8_e4m3fn,
@@ -2135,7 +2135,9 @@ def trtllm_batch_decode_with_kv_cache(
         o_sf_start_index = 0
         out_dtype = out_dtype or query.dtype
         out = out if out is not None else torch.empty_like(query, dtype=out_dtype)
-        _check_shape_dtype_device(out, query.shape, query.dtype, query.device, "out")
+        if out_dtype not in (query.dtype, torch.float16, torch.bfloat16):
+            raise ValueError(f"Unsupported out_dtype: {out_dtype}")
+        check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out")
     else:
         raise ValueError(f"Invalid out_dtype: {out_dtype}")
 
@@ -2288,7 +2290,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
     else:
         batch_size, _, num_q_heads, _ = query.shape
-        _check_shape_dtype_device(
+        check_shape_dtype_device(
             out,
             [batch_size, num_q_heads, kv_lora_rank],
             torch.bfloat16,

flashinfer/fused_moe/core.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@
 from ..jit.cubin_loader import get_cubin
 from ..jit.cutlass_gemm.generate_kernels import generate_gemm_operations
 from ..utils import (
-    _check_shape_dtype_device,
+    check_shape_dtype_device,
     device_support_pdl,
     get_shuffle_matrix_a_row_indices,
     get_shuffle_matrix_sf_a_row_indices,
@@ -868,7 +868,7 @@ def cutlass_fused_moe(
     if output is None:
         output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
     else:
-        _check_shape_dtype_device(
+        check_shape_dtype_device(
             output, output_shape, output_dtype, input.device, "output"
         )

flashinfer/mla.py

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@
 from .jit import JitSpec
 from .jit import env as jit_env
 from .jit import gen_batch_mla_module, gen_jit_spec, sm100a_nvcc_flags
-from .utils import MaskMode, _check_shape_dtype_device, determine_mla_backend
+from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend
 
 
 def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table):
@@ -394,7 +394,7 @@ def run(
         if out is None:
             out = torch.empty_like(q_nope)
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
             )
         q_nope_pe = torch.cat([q_nope, q_pe], dim=-1)
@@ -426,15 +426,15 @@ def run(
         if out is None:
             out = torch.empty_like(q_nope)
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
            )
 
         if return_lse:
             if lse is None:
                 lse = torch.empty(q_nope.shape[:2], dtype=torch.float32, device=device)
             else:
-                _check_shape_dtype_device(
+                check_shape_dtype_device(
                     lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse"
                 )
         profiler_args = (profiler_buffer,) if self._use_profiler else ()

flashinfer/prefill.py

Lines changed: 10 additions & 8 deletions
@@ -43,7 +43,7 @@
     _check_cached_qkv_data_type,
     _check_kv_layout,
     _check_pos_encoding_mode,
-    _check_shape_dtype_device,
+    check_shape_dtype_device,
     _get_cache_alibi_slopes_buf,
     _get_cache_buf,
     _unpack_paged_kv_cache,
@@ -2032,7 +2032,7 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
 
@@ -2041,7 +2041,7 @@
                 q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device
             )
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out"
             )
 
@@ -2831,15 +2831,15 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
         if out is None:
             out = torch.empty(
                 q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device
             )
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 out, q.shape[:-1] + v.shape[-1:], q.dtype, q.device, "out"
             )
         if self._backend == "cutlass":
@@ -3365,9 +3365,9 @@ def trtllm_batch_context_with_kv_cache(
         assert isinstance(out, torch.Tensor)
 
         # Use uint8 as the container dtype to compliant with next fp4 gemm.
-        _check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
+        check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
 
-        _check_shape_dtype_device(
+        check_shape_dtype_device(
             out_scale_factor,
             fp4_out_scale_shape,
             torch.float8_e4m3fn,
@@ -3392,8 +3392,10 @@ def trtllm_batch_context_with_kv_cache(
         out_scale_factor = None
         o_sf_start_index = 0
         out_dtype = out_dtype or query.dtype
+        if out_dtype not in (query.dtype, torch.float16, torch.bfloat16):
+            raise ValueError(f"Unsupported out_dtype: {out_dtype}")
         out = out if out is not None else torch.empty_like(query, dtype=out_dtype)
-        _check_shape_dtype_device(out, query.shape, query.dtype, query.device, "out")
+        check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out")
     else:
         raise ValueError(f"Invalid out_dtype: {out_dtype}")

flashinfer/sparse.py

Lines changed: 5 additions & 5 deletions
@@ -28,7 +28,7 @@
     PosEncodingMode,
     TensorLayout,
     _check_pos_encoding_mode,
-    _check_shape_dtype_device,
+    check_shape_dtype_device,
     _get_cache_alibi_slopes_buf,
     canonicalize_torch_dtype,
     determine_attention_backend,
@@ -577,14 +577,14 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
 
         if out is None:
             out = torch.empty_like(q, dtype=self._o_dtype)
         else:
-            _check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")
+            check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")
 
         if is_float8(q):
             assert q.dtype == k.dtype == v.dtype
@@ -1157,14 +1157,14 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
         else:
-            _check_shape_dtype_device(
+            check_shape_dtype_device(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
 
         if out is None:
             out = torch.empty_like(q, dtype=self._o_dtype)
         else:
-            _check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")
+            check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")
 
         if self._backend == "fa3":
             if (

flashinfer/utils.py

Lines changed: 7 additions & 7 deletions
@@ -443,22 +443,22 @@ def determine_mla_backend(device: torch.device) -> str:
     return "fa3" if is_sm90a_supported(device) else "fa2"
 
 
-def _check_shape_dtype_device(
+def check_shape_dtype_device(
     x: torch.Tensor,
-    expected_shape: Sequence[int],
-    expected_dtype: torch.dtype,
-    expected_device: torch.device,
+    expected_shape: Optional[Sequence[int]],
+    expected_dtype: Optional[torch.dtype],
+    expected_device: Optional[torch.device],
     name: str,
 ) -> None:
-    if x.shape != torch.Size(expected_shape):
+    if expected_shape and x.shape != torch.Size(expected_shape):
         raise ValueError(
             f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}"
         )
-    if x.dtype != expected_dtype:
+    if expected_dtype and x.dtype != expected_dtype:
         raise ValueError(
             f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}"
         )
-    if x.device != expected_device:
+    if expected_device and x.device != expected_device:
         raise ValueError(
             f"Invalid device of {name}: expected {expected_device}, got {x.device}"
         )
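
The helper is renamed to the public `check_shape_dtype_device`, and each expected value is now optional: passing `None` skips that particular check. A small usage sketch, assuming the helper stays importable from `flashinfer.utils`:

```python
import torch

from flashinfer.utils import check_shape_dtype_device  # public name after this change

out = torch.empty(4, 8, dtype=torch.bfloat16)

# Check shape, dtype, and device against expectations.
check_shape_dtype_device(out, (4, 8), torch.bfloat16, out.device, "out")

# Passing None for a field skips that check, e.g. accept any dtype here.
check_shape_dtype_device(out, (4, 8), None, out.device, "out")
```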

tests/test_trtllm_gen_attention.py

Lines changed: 14 additions & 6 deletions
@@ -8,7 +8,7 @@
 from flashinfer.utils import FP4Tensor, ceil_div, round_up
 
 DTYPE_MAP = {
-    "half": torch.float16,
+    "fp16": torch.float16,
     "bf16": torch.bfloat16,
     "fp8": torch.float8_e4m3fn,
     "nvfp4": "nvfp4",
@@ -237,8 +237,10 @@ def unpack_compare_nvfp4(
 @pytest.mark.parametrize(
     "q_dtype,kv_dtype,o_dtype",
     [
-        ("half", "half", "half"),
         ("bf16", "bf16", "bf16"),
+        ("fp16", "fp16", "fp16"),
+        ("fp8", "fp8", "bf16"),
+        ("fp8", "fp8", "fp16"),
         ("fp8", "fp8", "fp8"),
         ("fp8", "fp8", "nvfp4"),
     ],
@@ -355,8 +357,10 @@ def test_trtllm_batch_prefill(
         )
         assert o_scale == 1.0
         rtol, atol = 4e-1, 1e0
-    elif o_dtype == "fp8":
+    elif q_dtype == "fp8" and o_dtype == "fp8":
         rtol, atol = 5e-2, 7e-2
+    elif q_dtype == "fp8" and o_dtype in ["bf16", "fp16"]:
+        rtol, atol = 4e-2, 6e-2
     else:
         rtol, atol = 1e-2, 1e-2
@@ -399,10 +403,12 @@ def test_trtllm_batch_prefill(
 @pytest.mark.parametrize(
     "q_dtype,kv_dtype,o_dtype",
     [
-        ("half", "half", "half"),
-        ("half", "fp8", "half"),
         ("bf16", "bf16", "bf16"),
+        ("fp16", "fp16", "fp16"),
         ("bf16", "fp8", "bf16"),
+        ("fp16", "fp8", "fp16"),
+        ("fp8", "fp8", "bf16"),
+        ("fp8", "fp8", "fp16"),
         ("fp8", "fp8", "fp8"),
         ("fp8", "fp8", "nvfp4"),
     ],
@@ -512,8 +518,10 @@ def test_trtllm_batch_decode(
         )
         assert o_scale == 1.0
         rtol, atol = 3e-1, 1e0
-    elif o_dtype == "fp8":
+    elif q_dtype == "fp8" and o_dtype == "fp8":
         rtol, atol = 5e-2, 7e-2
+    elif q_dtype == "fp8" and o_dtype in ["bf16", "fp16"]:
+        rtol, atol = 4e-2, 6e-2
     else:
         rtol, atol = 1e-2, 1e-2
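
The tests now key fp16 as "fp16" (instead of "half") in `DTYPE_MAP` and cover fp8 q/kv with fp16/bf16 output, using looser tolerances whenever the inputs are fp8. A simplified restatement of the tolerance selection above (not the full test logic; the nvfp4 branches are omitted and `pick_tolerances` is only an illustrative helper):

```python
import torch

DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}


def pick_tolerances(q_dtype: str, o_dtype: str) -> tuple:
    # Simplified from test_trtllm_batch_prefill / test_trtllm_batch_decode.
    if q_dtype == "fp8" and o_dtype == "fp8":
        return 5e-2, 7e-2  # fp8 in, fp8 out
    if q_dtype == "fp8" and o_dtype in ("bf16", "fp16"):
        return 4e-2, 6e-2  # fp8 in, fp16/bf16 out (the new path)
    return 1e-2, 1e-2  # fp16/bf16 in and out


rtol, atol = pick_tolerances("fp8", "bf16")
```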
