Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ The output CSV will contain detailed metrics including:
| `--verbose`, `-v` | Print additional information (can be used multiple times for more verbosity, e.g. `-vv`) |
| `--case_tag` | Optional tag for the test case, useful for annotating or filtering results in the output CSV. |
| `--generate_repro_command`| If set, prints a reproducer command for the test case and stores it in the output CSV. |
| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cudnn-native, cutlass, trtllm, trtllm-gen, trtllm-native, cublas|
| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, auto, cudnn, cudnn-native, cutlass, trtllm, trtllm-gen, trtllm-native, cublas. (`auto` is currently supported only for `BatchDecodeWithPagedKVCacheWrapper` and `BatchPrefillWithPagedKVCacheWrapper`.) |

### Attention Flags
| Flag | Description |
Expand Down
50 changes: 40 additions & 10 deletions benchmarks/routines/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,15 @@ def parse_attention_args(line, parser):
"fa2",
"fa2_tc",
"fa3",
"auto",
"cudnn",
"cudnn-native",
"cutlass",
"trtllm-gen",
"trtllm-native",
"trtllm-gen-native", # Deprecated, will be removed in future
],
help="Kernel backends to test. Default: fa2",
help="Kernel backends to test. Default: fa2. backend=auto is only supported for BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper.",
)
parser.add_argument(
"--page_size",
Expand Down Expand Up @@ -196,7 +197,6 @@ def parse_attention_args(line, parser):

# Normalize backend names (handle deprecated names)
args.backends = normalize_backends(args.backends)

if args.verbose >= 1:
print(f"[INFO] {args = }")
return args
Expand Down Expand Up @@ -231,7 +231,7 @@ def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len
def testBatchDecodeWithPagedKVCacheWrapper(args):
"""
Test BatchDecodeWithPagedKVCacheWrapper API and equivalent cuDNN API.
Supports fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native backends.
Supports fa2, fa2_tc, auto, cudnn, trtllm-gen, trtllm-native backends.

This test:
1. Creates paged KV cache and query tensors
Expand Down Expand Up @@ -468,8 +468,9 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):

# Prepare wrappers
backend_wrappers = {}
resolved_backends = {}
for backend in backends:
if backend in ["fa2", "fa2_tc", "trtllm-gen"]:
if backend in ["fa2", "fa2_tc", "auto", "trtllm-gen"]:
plan_kv_indptr = (
kv_indptr.clone().detach() if backend == "trtllm-gen" else kv_indptr
)
Expand Down Expand Up @@ -498,6 +499,9 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
data_type=kv_dtype,
block_tables=block_tables,
)
resolved_backends[backend] = backend_wrappers[backend]._backend
else:
resolved_backends[backend] = backend

## If FP8, prepare
k_scale, v_scale = None, None
Expand Down Expand Up @@ -527,7 +531,7 @@ def run_backend_wrapper(
actual_seq_lens_kv,
ragged_q,
):
if backend in ["fa2", "fa2_tc", "trtllm-gen"]:
if backend in ["fa2", "fa2_tc", "auto", "trtllm-gen"]:
return backend_wrappers[backend].run(
q, kv_cache, k_scale=k_scale, v_scale=v_scale
)
Expand Down Expand Up @@ -661,7 +665,20 @@ def run_backend_wrapper(
kv_dtype=kv_dtype,
o_dtype=q_dtype,
)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
resolved_backend = resolved_backends.get(backend, backend)
wrapper = backend_wrappers.get(backend)
if (
wrapper is not None
and resolved_backend == "fa2"
and wrapper.use_tensor_cores
):
resolved_backend = "fa2_tc"
display_backend = (
f"auto({resolved_backend})" if backend == "auto" else resolved_backend
)
print_perf_metrics(
display_backend, median_time, std_time, tflops, tb_per_sec
)

if args.output_path is not None:
cur_res = defaultdict(str)
Expand All @@ -671,6 +688,7 @@ def run_backend_wrapper(
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["backend"] = backend
cur_res["resolved_backend"] = resolved_backend
cur_res["page_size"] = page_size
cur_res["batch_size"] = batch_size
cur_res["s_qo"] = s_qo
Expand All @@ -692,7 +710,7 @@ def run_backend_wrapper(
def testBatchPrefillWithPagedKVCacheWrapper(args):
"""
Test BatchPrefillWithPagedKVCacheWrapper API and equivalent cuDNN API.
Supports fa2, fa3, trtllm-gen, trtllm-native, and cudnn backends.
Supports fa2, fa3, auto, trtllm-gen, trtllm-native, and cudnn backends.

This test:
1. Creates paged KV cache and query tensors for prefill
Expand Down Expand Up @@ -1029,8 +1047,9 @@ def to_float8(x, dtype=torch.float8_e4m3fn):

# Prepare wrappers (after FP8 conversion so we have correct dtypes)
backend_wrappers = {}
resolved_backends = {}
for backend in backends:
if backend in ["fa2", "fa3", "trtllm-gen"]:
if backend in ["fa2", "fa3", "auto", "trtllm-gen"]:
backend_wrappers[backend] = (
flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer,
Expand Down Expand Up @@ -1060,6 +1079,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
kv_data_type=kv_dtype,
block_tables=block_tables,
)
resolved_backends[backend] = backend_wrappers[backend]._backend
elif backend == "cudnn":
# cuDNN uses NHD layout and the wrapper API
backend_wrappers[backend] = (
Expand Down Expand Up @@ -1089,6 +1109,9 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
max_sequence_kv=s_kv,
block_tables=block_tables,
)
resolved_backends[backend] = backend_wrappers[backend]._backend
else:
resolved_backends[backend] = backend

def run_backend_wrapper(
backend,
Expand All @@ -1104,7 +1127,7 @@ def run_backend_wrapper(
qo_indptr,
kv_indptr,
):
if backend in ["fa2", "fa3", "trtllm-gen"]:
if backend in ["fa2", "fa3", "auto", "trtllm-gen"]:
return backend_wrappers[backend].run(
q, kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale
)
Expand Down Expand Up @@ -1291,7 +1314,13 @@ def run_backend_wrapper(
kv_dtype=kv_dtype,
o_dtype=q_dtype,
)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
resolved_backend = resolved_backends.get(backend, backend)
display_backend = (
f"auto({resolved_backend})" if backend == "auto" else backend
)
print_perf_metrics(
display_backend, median_time, std_time, tflops, tb_per_sec
)

if args.output_path is not None:
cur_res = defaultdict(str)
Expand All @@ -1301,6 +1330,7 @@ def run_backend_wrapper(
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["backend"] = backend
cur_res["resolved_backend"] = resolved_backend
cur_res["page_size"] = page_size
cur_res["batch_size"] = batch_size
cur_res["s_qo"] = s_qo
Expand Down
31 changes: 16 additions & 15 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"tflops",
"tb_per_sec",
"backend",
"resolved_backend",
],
"attention": [
"s_qo",
Expand Down Expand Up @@ -257,26 +258,26 @@ def dtype_str_to_torch_dtype(dtype_str):
# ATTENTION
"BatchDecodeWithPagedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache
"7.5": ["fa2"],
"8.0": ["fa2", "fa2_tc", "cudnn"],
"8.6": ["fa2", "fa2_tc", "cudnn"],
"8.9": ["fa2", "fa2_tc", "cudnn"],
"9.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"],
"10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"],
"10.3": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"],
"12.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"],
"7.5": ["fa2", "auto"],
"8.0": ["fa2", "fa2_tc", "auto", "cudnn"],
"8.6": ["fa2", "fa2_tc", "auto", "cudnn"],
"8.9": ["fa2", "fa2_tc", "auto", "cudnn"],
"9.0": ["fa2", "fa2_tc", "auto", "cudnn", "trtllm-native"],
"10.0": ["fa2", "fa2_tc", "auto", "cudnn", "trtllm-gen", "trtllm-native"],
"10.3": ["fa2", "fa2_tc", "auto", "cudnn", "trtllm-gen", "trtllm-native"],
"12.0": ["fa2", "fa2_tc", "auto", "cudnn", "trtllm-native"],
},
"BatchPrefillWithPagedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_batch_context_with_kv_cache
# NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache
"7.5": [],
"8.0": ["fa2", "cudnn", "cudnn-native"],
"8.6": ["fa2", "cudnn", "cudnn-native"],
"8.9": ["fa2", "cudnn", "cudnn-native"],
"9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
"10.0": ["fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"10.3": ["fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"12.0": ["fa2", "cudnn", "cudnn-native"],
"8.0": ["fa2", "auto", "cudnn", "cudnn-native"],
"8.6": ["fa2", "auto", "cudnn", "cudnn-native"],
"8.9": ["fa2", "auto", "cudnn", "cudnn-native"],
"9.0": ["fa2", "fa3", "auto", "cudnn", "cudnn-native"],
"10.0": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"10.3": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"12.0": ["fa2", "auto", "cudnn", "cudnn-native"],
},
"BatchPrefillWithRaggedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_ragged_attention_deepseek
Expand Down
22 changes: 14 additions & 8 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,14 +1040,20 @@ def plan(
self._cached_module = self._jit_module
else:
if self._backend == "auto":
self._backend = determine_attention_backend(
self.device,
PosEncodingMode[pos_encoding_mode].value,
False, # use_fp16_qk_reduction
False, # use_custom_mask
q_data_type,
kv_data_type,
)
if {
torch.float8_e4m3fn,
torch.float8_e5m2,
} & {q_data_type, kv_data_type}:
self._backend = determine_attention_backend(
self.device,
PosEncodingMode[pos_encoding_mode].value,
False, # use_fp16_qk_reduction
False, # use_custom_mask
q_data_type,
kv_data_type,
)
else:
self._backend = "fa2"
self._cached_module = get_batch_prefill_module(
self._backend,
q_data_type,
Expand Down
Loading