diff --git a/flashinfer/utils.py b/flashinfer/utils.py index e2e820f221..44e8f1b762 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -408,6 +408,12 @@ def is_fa3_backend_supported( return False if use_fp16_qk_reductions: return False + # FA3 FP8 KV cache currently requires FP8 query. + if dtype_kv in {torch.float8_e4m3fn, torch.float8_e5m2} and dtype_q not in { + torch.float8_e4m3fn, + torch.float8_e5m2, + }: + return False return True