diff --git a/benchmarks/README.md b/benchmarks/README.md
index 26a48dab5a..af8f36a1b3 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -13,6 +13,7 @@ Currently supports testing attention, gemm, fused MOE, normalization, quantizati
 - Attention:
   - `BatchDecodeWithPagedKVCacheWrapper` - Decode attention with paged KV cache.
     - Also supports computationally similar `cudnn_batch_decode_with_kv_cache` and `trtllm_batch_decode_with_kv_cache`.
+    - Speculative decode is supported by setting `--s_qo` > 1 (subject to the backend limitations noted below).
   - `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache.
     - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_batch_context_with_kv_cache`.
   - `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache.
@@ -195,7 +196,7 @@ The output CSV will contain detailed metrics including:
 |--------------------------|-------------------------------------------------------------------------------------------------------------|
 | `--page_size`            | Page size for paged attention. Required for paged attention tests.                                           |
 | `--batch_size`           | Number of sequences to process in parallel                                                                   |
-| `--s_qo`                 | Query/output sequence length. Should be 1 for decode tests.                                                  |
+| `--s_qo`                 | Query/output sequence length. For decode, `1` is standard decode and `>1` enables speculative decode on supported backends. |
 | `--s_kv`                 | Key/value sequence length (context length)                                                                   |
 | `--num_qo_heads`         | Number of query/output attention heads                                                                       |
 | `--num_kv_heads`         | Number of key/value attention heads                                                                          |
diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py
index 9b3e8cb206..412d8bcfcd 100644
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -128,7 +128,7 @@ def parse_attention_args(line, parser):
         type=int,
         required=False,
         default=1,
-        help="Max sequence length of the query. Should be 1 for decode.",
+        help="Max sequence length of the query. For decode, 1 is standard decode and >1 enables speculative decode on supported backends.",
     )
     parser.add_argument(
         "--s_kv",
@@ -228,6 +228,44 @@ def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len
     return actual_seq_lens
 
 
+def generate_speculative_causal_mask(batch_size, q_seq_len, device):
+    """
+    Generate a packed causal mask for speculative decode chunks (q_len > 1).
+
+    Returned shape is [batch_size, q_seq_len, num_packed_masks_per_token * 2]
+    with dtype uint16, where num_packed_masks_per_token = ceil(q_seq_len / 32).
+    Each query row i encodes allowed attention to draft-token columns j <= i
+    (lower-triangular including the diagonal) and masks out j > i.
+    The innermost dimension stores packed bits (uint32 words reinterpreted as
+    uint16), matching the mask layout expected by decode APIs.
+    """
+    num_packed_masks_per_token = (q_seq_len + 31) // 32
+    q_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(1)
+    kv_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(0)
+    causal_bool_mask = kv_indices <= q_indices
+
+    padded_seq_len = num_packed_masks_per_token * 32
+    if padded_seq_len > q_seq_len:
+        padding = torch.zeros(
+            q_seq_len, padded_seq_len - q_seq_len, device=device, dtype=torch.bool
+        )
+        causal_bool_mask = torch.cat([causal_bool_mask, padding], dim=1)
+
+    causal_bool_mask = causal_bool_mask.view(q_seq_len, num_packed_masks_per_token, 32)
+    bit_positions = torch.tensor(
+        [1 << i for i in range(32)], device=device, dtype=torch.int64
+    )
+    mask_uint32 = (
+        (causal_bool_mask.to(torch.int64) * bit_positions).sum(dim=-1).to(torch.uint32)
+    )
+    mask_uint32 = (
+        mask_uint32.unsqueeze(0)
+        .expand(batch_size, q_seq_len, num_packed_masks_per_token)
+        .contiguous()
+    )
+    return mask_uint32.view(torch.uint16)
+
+
 def testBatchDecodeWithPagedKVCacheWrapper(args):
     """
     Test BatchDecodeWithPagedKVCacheWrapper API and equivalent cuDNN API.
@@ -279,6 +317,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
     page_size = args.page_size
     batch_size = args.batch_size
     s_qo = args.s_qo
+    speculative_decode = s_qo > 1
     s_kv = args.s_kv
     num_qo_heads = args.num_qo_heads
     num_kv_heads = args.num_kv_heads
@@ -292,6 +331,9 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
     # Check for backend-specific constraints
     if "fa2" in backends:
         remove_fa2 = False
+        if speculative_decode:
+            print("[INFO] FA2 backend does not support speculative decode. Skipping.")
+            remove_fa2 = True
         head_grp_size = (
             num_qo_heads // num_kv_heads
         )  # If 5, FA2 backend is not supported.
@@ -305,6 +347,11 @@
 
     if "fa2_tc" in backends:
         remove_fa2_tc = False
+        if speculative_decode:
+            print(
+                "[INFO] FA2_TC backend does not support speculative decode. Skipping."
+            )
+            remove_fa2_tc = True
         if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
             torch.float8_e4m3fn,
             torch.float8_e5m2,
@@ -316,6 +363,9 @@
 
     if "cudnn" in backends:
         remove_cudnn = False
+        if speculative_decode:
+            print("[INFO] cuDNN backend does not support speculative decode. Skipping.")
+            remove_cudnn = True
         if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
             torch.float8_e4m3fn,
             torch.float8_e5m2,
@@ -325,6 +375,10 @@
         if remove_cudnn:
             backends.remove("cudnn")
 
+    if "auto" in backends and speculative_decode:
+        print("[INFO] auto backend is disabled for speculative decode. Skipping.")
+        backends.remove("auto")
+
     if len(backends) == 0:
         print("[ERROR] No backends to test. Exiting.")
         return res
@@ -347,7 +401,11 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
 
     # Create query tensor
     q = torch.rand(
-        batch_size, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype
+        batch_size * s_qo,
+        num_qo_heads,
+        head_dim_qk,
+        device=device,
+        dtype=q_init_dtype,
     )
     if args.verbose >= 2:
         print(f"[VVERBOSE] {q.shape = }")
@@ -451,8 +509,14 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
     )
 
     ragged_q = (
-        torch.arange(0, batch_size + 1, device=device) * (num_qo_heads * head_dim_qk)
+        torch.arange(0, batch_size + 1, device=device)
+        * (s_qo * num_qo_heads * head_dim_qk)
     ).long()  # For cuDNN
+    speculative_mask = (
+        generate_speculative_causal_mask(batch_size, s_qo, device)
+        if speculative_decode
+        else None
+    )
 
     scale = float(1.0 / (head_dim_qk**0.5))
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
@@ -530,10 +594,11 @@ def run_backend_wrapper(
         block_tables,
         actual_seq_lens_kv,
         ragged_q,
+        speculative_mask,
     ):
         if backend in ["fa2", "fa2_tc", "auto", "trtllm-gen"]:
             return backend_wrappers[backend].run(
-                q, kv_cache, k_scale=k_scale, v_scale=v_scale
+                q, kv_cache, k_scale=k_scale, v_scale=v_scale, q_len_per_req=s_qo
            )
         elif backend == "cudnn":
             return flashinfer.decode.cudnn_batch_decode_with_kv_cache(
@@ -559,6 +624,10 @@
                 max_seq_len=s_kv,
                 bmm1_scale=scale if k_scale is None else k_scale * scale,
                 bmm2_scale=1.0 if v_scale is None else v_scale,
+                kv_layout="HND",
+                backend="auto",
+                q_len_per_req=s_qo,
+                mask=speculative_mask,
             )
         else:
             print(f"[ERROR] Backend {backend} not supported")
@@ -581,6 +650,7 @@
                 block_tables,
                 actual_seq_lens_kv,
                 ragged_q,
+                speculative_mask,
             )
             .detach()
             .clone()
@@ -607,6 +677,7 @@
                 block_tables,
                 actual_seq_lens_kv,
                 ragged_q,
+                speculative_mask,
             ),
         )
 
@@ -643,7 +714,7 @@
             median_time = np.median(backend_times[backend])
             std_time = np.std(backend_times[backend])
             actual_seq_lens_kv_flat = actual_seq_lens_kv.flatten().to("cpu")
-            actual_seq_lens_q_flat = torch.ones_like(actual_seq_lens_kv_flat)
+            actual_seq_lens_q_flat = torch.full_like(actual_seq_lens_kv_flat, s_qo)
             tflops = attention_tflops_per_sec_with_actual_seq_lens(
                 actual_seq_lens_q_flat,
                 actual_seq_lens_kv_flat,