Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 63 additions & 7 deletions benchmarks/bench_trtllm_fmha.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import torch

import flashinfer
import flashinfer.decode
from flashinfer.testing.utils import bench_gpu_time, bench_gpu_time_with_cudagraph
from flashinfer.fp4_quantization import nvfp4_quantize_paged_kv_cache

page_size = 16
num_kv_heads = 4
Expand Down Expand Up @@ -112,9 +114,30 @@ def bench_trtllm_fmha_wrapper(
) # Random permutation

kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
kv_cache = torch.randn(size=kv_cache_shape).to(q.dtype)
kv_cache = torch.randn(size=kv_cache_shape, device=device).to(q.dtype)

# Prepare dtype-specific KV cache and scales
kv_block_scales = None
k_scale_val = None
v_scale_val = None
if kv_cache_dtype == "nvfp4":
# NVFP4 KV requires FP8 query β€” auto-convert if needed
if q.dtype != torch.float8_e4m3fn:
q, q_inv_scale = to_float8(q)
k_scale_val = (
q_inv_scale.item()
if isinstance(q_inv_scale, torch.Tensor)
else q_inv_scale
)
else:
k_scale_val = 1.0

if kv_cache_dtype.startswith("fp8") and q_dtype != "fp8":
kv_cache, kv_block_scales, k_gs, v_gs = nvfp4_quantize_paged_kv_cache(
kv_cache[:, 0], kv_cache[:, 1]
)
k_scale_val *= k_gs
v_scale_val = v_gs
elif kv_cache_dtype.startswith("fp8") and q_dtype != "fp8":
kv_cache, _ = to_float8(kv_cache)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
Expand Down Expand Up @@ -155,22 +178,48 @@ def bench_trtllm_fmha_wrapper(
head_dim,
page_size,
pos_encoding_mode="NONE",
data_type=kv_cache.dtype,
q_data_type=q.dtype,
kv_data_type=torch.uint8 if kv_cache_dtype == "nvfp4" else kv_cache.dtype,
window_left=window_left,
)

# add one warmup here
wrapper.run(q, kv_cache, sinks=sinks)
wrapper.run(
q,
kv_cache,
sinks=sinks,
k_scale=k_scale_val,
v_scale=v_scale_val,
kv_block_scales=kv_block_scales,
)
torch.cuda.synchronize()

measurements = bench_gpu_time_with_cudagraph(
lambda: wrapper.run(q, kv_cache, sinks=sinks),
lambda: wrapper.run(
q,
kv_cache,
sinks=sinks,
k_scale=k_scale_val,
v_scale=v_scale_val,
kv_block_scales=kv_block_scales,
),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
ms = np.median(measurements)
io = q.numel() * q.element_size() + kv_cache.numel() * kv_cache.element_size()
if isinstance(kv_cache, tuple):
io = (
q.numel() * q.element_size()
+ kv_cache[0].numel() * kv_cache[0].element_size()
+ kv_cache[1].numel() * kv_cache[1].element_size()
)
else:
io = q.numel() * q.element_size() + kv_cache.numel() * kv_cache.element_size()
if kv_block_scales is not None:
io += (
kv_block_scales[0].numel() * kv_block_scales[0].element_size()
+ kv_block_scales[1].numel() * kv_block_scales[1].element_size()
)
print(
f"batch_size={batch_size}, seq_len={max_seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_size={page_size}"
)
Expand Down Expand Up @@ -198,6 +247,13 @@ def bench_trtllm_fmha_wrapper(
help="Number of query heads per key-value head (group size)",
)
parser.add_argument("--sink", action="store_true", help="Whether to test with sink")
parser.add_argument(
"--kv_cache_dtype",
type=str,
default="auto",
choices=["auto", "fp8", "nvfp4"],
help="KV cache dtype [auto, fp8, nvfp4]",
)
parser.add_argument(
"--batch_sizes",
type=int,
Expand Down Expand Up @@ -226,7 +282,7 @@ def bench_trtllm_fmha_wrapper(
head_dim=args.head_dim,
q_dtype="bf16",
head_grp_size=args.head_grp_size,
kv_cache_dtype="auto",
kv_cache_dtype=args.kv_cache_dtype,
window_left=-1,
bench_with_sink=args.sink,
)
56 changes: 43 additions & 13 deletions benchmarks/routines/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

import flashinfer
import flashinfer.decode

# Try to import cudnn for version checking
CUDNN_AVAILABLE = False
Expand All @@ -20,6 +21,7 @@
is_lib_missing = any(ext in error_msg for ext in [".so", ".dll"])
if not is_lib_missing:
raise
from flashinfer.fp4_quantization import nvfp4_quantize_paged_kv_cache
from flashinfer.testing.utils import (
attention_tb_per_sec_with_actual_seq_lens,
attention_tflops_per_sec_with_actual_seq_lens,
Expand Down Expand Up @@ -314,8 +316,9 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
return res

# Handle different KV cache data types.
is_nvfp4_kv = args.kv_dtype == "nvfp4"
kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype)
if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn]:
if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn, torch.uint8]:
print(f"[ERROR] Unsupported kv_dtype: {args.kv_dtype}")
return res

Expand All @@ -334,7 +337,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
num_qo_heads = args.num_qo_heads
num_kv_heads = args.num_kv_heads
head_dim_qk = args.head_dim_qk
head_dim_vo = args.head_dim_vo
head_dim_vo = args.head_dim_vo if args.head_dim_vo is not None else head_dim_qk
is_cuda_graph_compatible = not args.no_cuda_graph
# return_lse = not args.no_lse # TO-DO: Add support for this
run_refcheck = args.refcheck
Expand Down Expand Up @@ -391,10 +394,6 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
print("[INFO] auto backend is disabled for speculative decode. Skipping.")
backends.remove("auto")

if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return res

# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
Expand Down Expand Up @@ -579,11 +578,21 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
else:
resolved_backends[backend] = backend

## If FP8, prepare
## Prepare dtype-specific data
k_scale, v_scale = None, None
kv_block_scales = None
kv_cache_nvfp4 = None
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
q = q.to(q_dtype)
if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if is_nvfp4_kv:
# NVFP4 KV requires FP8 query
if q_dtype != torch.float8_e4m3fn:
print("[ERROR] NVFP4 KV cache requires --q_dtype fp8_e4m3.")
return res
kv_cache_nvfp4, kv_block_scales, k_scale, v_scale = (
nvfp4_quantize_paged_kv_cache(kv_cache[:, 0], kv_cache[:, 1])
)
elif kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
k_data, v_data = torch.chunk(kv_cache, 2, dim=1)
k_scale = k_data.amax().item() / 256
v_scale = v_data.amax().item() / 256
Expand All @@ -609,8 +618,14 @@ def run_backend_wrapper(
speculative_mask,
):
if backend in ["fa2", "fa2_tc", "auto", "trtllm-gen"]:
wrapper_kv = kv_cache_nvfp4 if is_nvfp4_kv else kv_cache
return backend_wrappers[backend].run(
q, kv_cache, k_scale=k_scale, v_scale=v_scale, q_len_per_req=s_qo
q,
wrapper_kv,
k_scale=k_scale,
v_scale=v_scale,
q_len_per_req=s_qo,
kv_block_scales=kv_block_scales,
)
elif backend == "cudnn":
return flashinfer.decode.cudnn_batch_decode_with_kv_cache(
Expand All @@ -627,9 +642,10 @@ def run_backend_wrapper(
batch_offsets_o=ragged_q,
)
elif backend == "trtllm-native":
native_kv = kv_cache_nvfp4 if is_nvfp4_kv else kv_cache
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=q.contiguous(),
kv_cache=kv_cache,
kv_cache=native_kv,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=actual_seq_lens_kv,
Expand All @@ -640,6 +656,7 @@ def run_backend_wrapper(
backend="auto",
q_len_per_req=s_qo,
mask=speculative_mask,
kv_block_scales=kv_block_scales,
)
else:
print(f"[ERROR] Backend {backend} not supported")
Expand Down Expand Up @@ -829,8 +846,9 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
print(f"[ERROR] Unsupported q_dtype: {args.q_dtype}")
return res

is_nvfp4_kv = args.kv_dtype == "nvfp4"
kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype)
if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn]:
if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn, torch.uint8]:
print(f"[ERROR] Unsupported kv_dtype: {args.kv_dtype}")
return res

Expand Down Expand Up @@ -1108,6 +1126,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
# Compute scales and convert to FP8 if needed (before creating wrappers)
q_scale, k_scale, v_scale = None, None, None
q_scale_tensor, k_scale_tensor, v_scale_tensor = None, None, None
kv_block_scales = None
o_data_type = q_dtype # Default output dtype
# Separate K/V caches for cuDNN (which requires separate tensors, not combined kv_cache)
k_cache_cudnn, v_cache_cudnn = k_cache, v_cache
Expand All @@ -1118,7 +1137,12 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
q_scale_tensor = q_scale_t.reshape(1, 1, 1, 1)
# o_data_type stays as q_dtype (FP8 output)

if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if is_nvfp4_kv:
kv_cache_nvfp4, kv_block_scales, k_scale, v_scale = (
nvfp4_quantize_paged_kv_cache(kv_cache[:, 0], kv_cache[:, 1])
)
kv_cache = kv_cache_nvfp4
elif kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
# Convert k_cache and v_cache to quantized dtype for cuDNN
k_cache_cudnn, k_scale_t = to_float8(k_cache, kv_dtype)
v_cache_cudnn, v_scale_t = to_float8(v_cache, kv_dtype)
Expand Down Expand Up @@ -1217,7 +1241,12 @@ def run_backend_wrapper(
):
if backend in ["fa2", "fa3", "auto", "trtllm-gen"]:
return backend_wrappers[backend].run(
q, kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale
q,
kv_cache,
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
kv_block_scales=kv_block_scales,
)
elif backend == "cudnn":
# cuDNN uses wrapper API with tensor scales for FP8
Expand Down Expand Up @@ -1249,6 +1278,7 @@ def run_backend_wrapper(
batch_size=batch_size,
cum_seq_lens_q=qo_indptr,
cum_seq_lens_kv=kv_indptr,
kv_block_scales=kv_block_scales,
)
elif backend == "cudnn-native":
# Direct cudnn_batch_prefill_with_kv_cache call (similar to trtllm-native)
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ def dtype_str_to_torch_dtype(dtype_str):
return torch.float8_e4m3fn
elif dtype_str == "fp8_e5m2":
return torch.float8_e5m2
elif dtype_str == "nvfp4":
return torch.uint8
else:
raise ValueError(f"Unsupported dtype: {dtype_str}")

Expand Down
Loading
Loading