benchmark: Enable speculative decode microbenchmarking for paged decode #2628
```diff
@@ -128,7 +128,7 @@ def parse_attention_args(line, parser):
         type=int,
         required=False,
         default=1,
-        help="Max sequence length of the query. Should be 1 for decode.",
+        help="Max sequence length of the query. For decode, 1 is standard decode and >1 enables speculative decode on supported backends.",
     )
     parser.add_argument(
         "--s_kv",
```
```diff
@@ -228,6 +228,44 @@ def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len
     return actual_seq_lens


+def generate_speculative_causal_mask(batch_size, q_seq_len, device):
+    """
+    Generate a packed causal mask for speculative decode chunks (q_len > 1).
+
+    Returned shape is [batch_size, q_seq_len, num_packed_masks_per_token * 2]
+    with dtype uint16, where num_packed_masks_per_token = ceil(q_seq_len / 32).
+    Each query row i encodes allowed attention to draft-token columns j <= i
+    (lower-triangular including the diagonal) and masks out j > i.
+    The innermost dimension stores packed bits (uint32 words reinterpreted as
+    uint16), matching the mask layout expected by decode APIs.
+    """
+    num_packed_masks_per_token = (q_seq_len + 31) // 32
+    q_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(1)
+    kv_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(0)
+    causal_bool_mask = kv_indices <= q_indices
+
+    padded_seq_len = num_packed_masks_per_token * 32
+    if padded_seq_len > q_seq_len:
+        padding = torch.zeros(
+            q_seq_len, padded_seq_len - q_seq_len, device=device, dtype=torch.bool
+        )
+        causal_bool_mask = torch.cat([causal_bool_mask, padding], dim=1)
+
+    causal_bool_mask = causal_bool_mask.view(q_seq_len, num_packed_masks_per_token, 32)
+    bit_positions = torch.tensor(
+        [1 << i for i in range(32)], device=device, dtype=torch.int64
+    )
+    mask_uint32 = (
+        (causal_bool_mask.to(torch.int64) * bit_positions).sum(dim=-1).to(torch.uint32)
+    )
+    mask_uint32 = (
+        mask_uint32.unsqueeze(0)
+        .expand(batch_size, q_seq_len, num_packed_masks_per_token)
+        .contiguous()
+    )
+    return mask_uint32.view(torch.uint16)
+
+
 def testBatchDecodeWithPagedKVCacheWrapper(args):
     """
     Test BatchDecodeWithPagedKVCacheWrapper API and equivalent cuDNN API.
```
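To make the packed layout concrete, here is a small usage sketch (assuming the function above is in scope, PyTorch with uint16/uint32 dtype support, and a little-endian host for the printed halves). For q_seq_len = 3, each query row packs into a single uint32 word, which the uint16 view splits into two halves:

```python
import torch

# Hedged sketch: exercise generate_speculative_causal_mask for q_seq_len = 3.
# num_packed_masks_per_token = ceil(3 / 32) = 1, so the last dim is 1 * 2 = 2
# uint16 halves per query row.
mask = generate_speculative_causal_mask(batch_size=2, q_seq_len=3, device="cpu")
print(mask.shape)  # torch.Size([2, 3, 2])

# Low uint16 half of each row's uint32 word (little-endian host):
# row 0 -> 0b001 (sees draft token 0 only), row 1 -> 0b011, row 2 -> 0b111.
print(mask[0, :, 0])  # tensor([1, 3, 7], dtype=torch.uint16)
```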
```diff
@@ -279,6 +317,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
     page_size = args.page_size
     batch_size = args.batch_size
     s_qo = args.s_qo
+    speculative_decode = s_qo > 1
     s_kv = args.s_kv
     num_qo_heads = args.num_qo_heads
     num_kv_heads = args.num_kv_heads
```
```diff
@@ -292,6 +331,9 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
     # Check for backend-specific constraints
     if "fa2" in backends:
         remove_fa2 = False
+        if speculative_decode:
+            print("[INFO] FA2 backend does not support speculative decode. Skipping.")
+            remove_fa2 = True
         head_grp_size = (
             num_qo_heads // num_kv_heads
         )  # If 5, FA2 backend is not supported.
```
```diff
@@ -305,6 +347,11 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):

     if "fa2_tc" in backends:
         remove_fa2_tc = False
+        if speculative_decode:
+            print(
+                "[INFO] FA2_TC backend does not support speculative decode. Skipping."
+            )
+            remove_fa2_tc = True
         if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
             torch.float8_e4m3fn,
             torch.float8_e5m2,
```
```diff
@@ -316,6 +363,9 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):

     if "cudnn" in backends:
         remove_cudnn = False
+        if speculative_decode:
+            print("[INFO] cuDNN backend does not support speculative decode. Skipping.")
+            remove_cudnn = True
         if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
             torch.float8_e4m3fn,
             torch.float8_e5m2,
```
```diff
@@ -325,6 +375,10 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
         if remove_cudnn:
             backends.remove("cudnn")

+    if "auto" in backends and speculative_decode:
+        print("[INFO] auto backend is disabled for speculative decode. Skipping.")
+        backends.remove("auto")
+
     if len(backends) == 0:
         print("[ERROR] No backends to test. Exiting.")
         return res
```
```diff
@@ -347,7 +401,11 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):

     # Create query tensor
     q = torch.rand(
-        batch_size, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype
+        batch_size * s_qo,
+        num_qo_heads,
+        head_dim_qk,
+        device=device,
+        dtype=q_init_dtype,
     )
     if args.verbose >= 2:
         print(f"[VVERBOSE] {q.shape = }")
```
```diff
@@ -451,8 +509,14 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
     )

     ragged_q = (
-        torch.arange(0, batch_size + 1, device=device) * (num_qo_heads * head_dim_qk)
+        torch.arange(0, batch_size + 1, device=device)
+        * (s_qo * num_qo_heads * head_dim_qk)
     ).long()  # For cuDNN
+    speculative_mask = (
+        generate_speculative_causal_mask(batch_size, s_qo, device)
+        if speculative_decode
+        else None
+    )

     scale = float(1.0 / (head_dim_qk**0.5))
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
```
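As a sanity check on the new ragged offsets (a hedged sketch with made-up sizes, not the benchmark defaults): each request now contributes s_qo * num_qo_heads * head_dim_qk elements to the flattened query, so the per-request stride grows by a factor of s_qo.

```python
import torch

# Hedged sketch: ragged_q offsets for illustrative sizes (not the defaults).
batch_size, s_qo, num_qo_heads, head_dim_qk = 2, 4, 8, 128
ragged_q = (
    torch.arange(0, batch_size + 1) * (s_qo * num_qo_heads * head_dim_qk)
).long()
print(ragged_q)  # tensor([   0, 4096, 8192]) -- one offset per request, plus the end
```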
```diff
@@ -530,10 +594,11 @@ def run_backend_wrapper(
         block_tables,
         actual_seq_lens_kv,
         ragged_q,
+        speculative_mask,
     ):
         if backend in ["fa2", "fa2_tc", "auto", "trtllm-gen"]:
             return backend_wrappers[backend].run(
-                q, kv_cache, k_scale=k_scale, v_scale=v_scale
+                q, kv_cache, k_scale=k_scale, v_scale=v_scale, q_len_per_req=s_qo
             )
```
Review thread on lines 600 to 602:

Contributor: The …

Author (Collaborator): Fair point. Will disallow backend='auto' for speculative decoding.
```diff
         elif backend == "cudnn":
             return flashinfer.decode.cudnn_batch_decode_with_kv_cache(
```
```diff
@@ -559,6 +624,10 @@ def run_backend_wrapper(
                 max_seq_len=s_kv,
                 bmm1_scale=scale if k_scale is None else k_scale * scale,
                 bmm2_scale=1.0 if v_scale is None else v_scale,
+                kv_layout="HND",
+                backend="auto",
+                q_len_per_req=s_qo,
+                mask=speculative_mask,
             )
         else:
             print(f"[ERROR] Backend {backend} not supported")
```
```diff
@@ -581,6 +650,7 @@ def run_backend_wrapper(
                 block_tables,
                 actual_seq_lens_kv,
                 ragged_q,
+                speculative_mask,
             )
             .detach()
             .clone()
```
```diff
@@ -607,6 +677,7 @@ def run_backend_wrapper(
                 block_tables,
                 actual_seq_lens_kv,
                 ragged_q,
+                speculative_mask,
             ),
         )
```
```diff
@@ -643,7 +714,7 @@ def run_backend_wrapper(
         median_time = np.median(backend_times[backend])
         std_time = np.std(backend_times[backend])
         actual_seq_lens_kv_flat = actual_seq_lens_kv.flatten().to("cpu")
-        actual_seq_lens_q_flat = torch.ones_like(actual_seq_lens_kv_flat)
+        actual_seq_lens_q_flat = torch.full_like(actual_seq_lens_kv_flat, s_qo)
         tflops = attention_tflops_per_sec_with_actual_seq_lens(
             actual_seq_lens_q_flat,
             actual_seq_lens_kv_flat,
```
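Why full_like matters for the reported numbers: with s_qo draft tokens per request, each request performs s_qo rows of attention, so per-request FLOPs scale roughly by s_qo. A hedged back-of-envelope version of the accounting (the exact formula inside attention_tflops_per_sec_with_actual_seq_lens may differ):

```python
import numpy as np

# Hedged estimate: two GEMMs (Q·K^T and P·V) at 2 FLOPs per MAC, one shared
# head_dim for both; the benchmark helper's exact accounting may differ.
def approx_tflops_per_sec(seq_lens_q, seq_lens_kv, num_heads, head_dim, time_ms):
    flops = (
        4.0 * np.sum(np.asarray(seq_lens_q) * np.asarray(seq_lens_kv))
        * num_heads * head_dim
    )
    return flops / (time_ms * 1e-3) / 1e12

# With s_qo = 4, each request does ~4x the work of standard decode (s_qo = 1),
# which torch.ones_like would have understated. Prints roughly 21.5 here.
print(approx_tflops_per_sec([4, 4], [8192, 8192], 32, 128, time_ms=0.05))
```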