3 changes: 2 additions & 1 deletion benchmarks/README.md
@@ -13,6 +13,7 @@ Currently supports testing attention, gemm, fused MOE, normalization, quantization...
- Attention:
- `BatchDecodeWithPagedKVCacheWrapper` - Decode attention with paged KV cache.
- Also supports computationally similar `cudnn_batch_decode_with_kv_cache` and `trtllm_batch_decode_with_kv_cache`.
- Speculative decode is supported by setting `--s_qo > 1` (subject to backend limitations noted below).
- `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache.
- Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_batch_context_with_kv_cache`.
- `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache.
@@ -195,7 +196,7 @@ The output CSV will contain detailed metrics including:
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--page_size` | Page size for paged attention. Required for paged attention tests. |
| `--batch_size` | Number of sequences to process in parallel |
-| `--s_qo`                 | Query/output sequence length. Should be 1 for decode tests. |
+| `--s_qo`                 | Query/output sequence length. For decode, `1` is standard decode and `>1` enables speculative decode on supported backends. |
| `--s_kv` | Key/value sequence length (context length) |
| `--num_qo_heads` | Number of query/output attention heads |
| `--num_kv_heads` | Number of key/value attention heads |
81 changes: 76 additions & 5 deletions benchmarks/routines/attention.py
@@ -128,7 +128,7 @@ def parse_attention_args(line, parser):
type=int,
required=False,
default=1,
help="Max sequence length of the query. Should be 1 for decode.",
help="Max sequence length of the query. For decode, 1 is standard decode and >1 enables speculative decode on supported backends.",
)
parser.add_argument(
"--s_kv",
@@ -228,6 +228,44 @@ def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len
return actual_seq_lens


def generate_speculative_causal_mask(batch_size, q_seq_len, device):
"""
Generate a packed causal mask for speculative decode chunks (q_len > 1).

Returned shape is [batch_size, q_seq_len, num_packed_masks_per_token * 2]
with dtype uint16, where num_packed_masks_per_token = ceil(q_seq_len / 32).
Each query row i encodes allowed attention to draft-token columns j <= i
(lower-triangular, diagonal included) and masks out j > i.
The innermost dimension stores packed bits (uint32 words reinterpreted as
uint16), matching the mask layout expected by decode APIs.
"""
num_packed_masks_per_token = (q_seq_len + 31) // 32
q_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(1)
kv_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(0)
causal_bool_mask = kv_indices <= q_indices

padded_seq_len = num_packed_masks_per_token * 32
if padded_seq_len > q_seq_len:
padding = torch.zeros(
q_seq_len, padded_seq_len - q_seq_len, device=device, dtype=torch.bool
)
causal_bool_mask = torch.cat([causal_bool_mask, padding], dim=1)

causal_bool_mask = causal_bool_mask.view(q_seq_len, num_packed_masks_per_token, 32)
bit_positions = torch.tensor(
[1 << i for i in range(32)], device=device, dtype=torch.int64
)
mask_uint32 = (
(causal_bool_mask.to(torch.int64) * bit_positions).sum(dim=-1).to(torch.uint32)
)
mask_uint32 = (
mask_uint32.unsqueeze(0)
.expand(batch_size, q_seq_len, num_packed_masks_per_token)
.contiguous()
)
return mask_uint32.view(torch.uint16)


def testBatchDecodeWithPagedKVCacheWrapper(args):
"""
Test BatchDecodeWithPagedKVCacheWrapper API and equivalent cuDNN API.
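As a sanity check on the packed layout described in the docstring above, a minimal sketch with hypothetical sizes (assumes a PyTorch build with uint16/uint32 tensor support, 2.3+):

```python
import torch

# Minimal sketch: for q_seq_len = 4, row i should attend to columns j <= i.
mask = generate_speculative_causal_mask(batch_size=1, q_seq_len=4, device="cpu")
print(mask.shape, mask.dtype)  # torch.Size([1, 4, 2]) torch.uint16

# Reinterpret the uint16 pairs back as the underlying uint32 words.
words = mask.view(torch.uint32)  # shape [1, 4, 1]
for i in range(4):
    bits = int(words[0, i, 0])
    allowed = [j for j in range(4) if (bits >> j) & 1]
    print(f"row {i}: attends to columns {allowed}")
# row 0: [0]  row 1: [0, 1]  row 2: [0, 1, 2]  row 3: [0, 1, 2, 3]
```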
@@ -279,6 +317,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
page_size = args.page_size
batch_size = args.batch_size
s_qo = args.s_qo
speculative_decode = s_qo > 1
s_kv = args.s_kv
num_qo_heads = args.num_qo_heads
num_kv_heads = args.num_kv_heads
@@ -292,6 +331,9 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
# Check for backend-specific constraints
if "fa2" in backends:
remove_fa2 = False
if speculative_decode:
print("[INFO] FA2 backend does not support speculative decode. Skipping.")
remove_fa2 = True
head_grp_size = (
num_qo_heads // num_kv_heads
) # If 5, FA2 backend is not supported.
@@ -305,6 +347,11 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):

if "fa2_tc" in backends:
remove_fa2_tc = False
if speculative_decode:
print(
"[INFO] FA2_TC backend does not support speculative decode. Skipping."
)
remove_fa2_tc = True
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
@@ -316,6 +363,9 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):

if "cudnn" in backends:
remove_cudnn = False
if speculative_decode:
print("[INFO] cuDNN backend does not support speculative decode. Skipping.")
remove_cudnn = True
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
@@ -325,6 +375,10 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
if remove_cudnn:
backends.remove("cudnn")

if "auto" in backends and speculative_decode:
print("[INFO] auto backend is disabled for speculative decode. Skipping.")
backends.remove("auto")

if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return res
@@ -347,7 +401,11 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):

# Create query tensor
q = torch.rand(
-        batch_size, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype
+        batch_size * s_qo,
+        num_qo_heads,
+        head_dim_qk,
+        device=device,
+        dtype=q_init_dtype,
)
if args.verbose >= 2:
print(f"[VVERBOSE] {q.shape = }")
@@ -451,8 +509,14 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
)

ragged_q = (
-        torch.arange(0, batch_size + 1, device=device) * (num_qo_heads * head_dim_qk)
+        torch.arange(0, batch_size + 1, device=device)
+        * (s_qo * num_qo_heads * head_dim_qk)
).long() # For cuDNN
speculative_mask = (
generate_speculative_causal_mask(batch_size, s_qo, device)
if speculative_decode
else None
)

scale = float(1.0 / (head_dim_qk**0.5))
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
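For the cuDNN path, each ragged offset now advances by `s_qo * num_qo_heads * head_dim_qk` elements per sequence; a small sketch with hypothetical sizes:

```python
import torch

# Hypothetical sizes: batch_size=2, s_qo=4, num_qo_heads=8, head_dim_qk=128.
batch_size, s_qo, num_qo_heads, head_dim_qk = 2, 4, 8, 128
ragged_q = (
    torch.arange(0, batch_size + 1) * (s_qo * num_qo_heads * head_dim_qk)
).long()
print(ragged_q)  # tensor([   0, 4096, 8192]); one boundary per sequence
```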
@@ -530,10 +594,11 @@ def run_backend_wrapper(
block_tables,
actual_seq_lens_kv,
ragged_q,
speculative_mask,
):
if backend in ["fa2", "fa2_tc", "auto", "trtllm-gen"]:
return backend_wrappers[backend].run(
-            q, kv_cache, k_scale=k_scale, v_scale=v_scale
+            q, kv_cache, k_scale=k_scale, v_scale=v_scale, q_len_per_req=s_qo
)
Comment on lines 600 to 602

Contributor (severity: medium):
The BatchDecodeWithPagedKVCacheWrapper.run method (as seen in flashinfer/decode.py) does not currently accept a mask parameter. Consequently, when speculative_decode is enabled, the auto and trtllm-gen backends (which use this wrapper) are executing unmasked attention. This is inconsistent with the trtllm-native path (line 626) which correctly applies the causal mask. This inconsistency will lead to misleading performance results and incorrect outputs if refcheck were enabled for these paths.

Collaborator (Author):
Fair point. Will disallow backend='auto' for speculative decoding.

elif backend == "cudnn":
return flashinfer.decode.cudnn_batch_decode_with_kv_cache(
@@ -559,6 +624,10 @@ def run_backend_wrapper(
max_seq_len=s_kv,
bmm1_scale=scale if k_scale is None else k_scale * scale,
bmm2_scale=1.0 if v_scale is None else v_scale,
kv_layout="HND",
backend="auto",
q_len_per_req=s_qo,
mask=speculative_mask,
)
else:
print(f"[ERROR] Backend {backend} not supported")
@@ -581,6 +650,7 @@ def run_backend_wrapper(
block_tables,
actual_seq_lens_kv,
ragged_q,
speculative_mask,
)
.detach()
.clone()
@@ -607,6 +677,7 @@ def run_backend_wrapper(
block_tables,
actual_seq_lens_kv,
ragged_q,
speculative_mask,
),
)

@@ -643,7 +714,7 @@ def run_backend_wrapper(
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
actual_seq_lens_kv_flat = actual_seq_lens_kv.flatten().to("cpu")
-    actual_seq_lens_q_flat = torch.ones_like(actual_seq_lens_kv_flat)
+    actual_seq_lens_q_flat = torch.full_like(actual_seq_lens_kv_flat, s_qo)
tflops = attention_tflops_per_sec_with_actual_seq_lens(
actual_seq_lens_q_flat,
actual_seq_lens_kv_flat,