[CuTe DSL] Block sparsity computation kernel #1983
Merged: tridao merged 17 commits into Dao-AILab:main from reubenconducts:rstern/compute-block-sparsity on Nov 12, 2025.
Commits (17, all by reubenconducts):
- c861597: begin block sparsity computation kernel
- 68afe2a: block sparsity computation kernel and benchmark working
- e384e9e: loop range_constexpr
- 9a1dc8c: add fast kernel
- b7f2886: merge fast and regular kernel
- 4f60c99: use TensorSSA approach to mask mod
- 8d087be: update with OOB check
- 06a9900: tests and benchmarks for block sparsity working
- 865576f: remove extraneous files
- de0f3dd: Revert mask.py to previous state - removing unintended changes from b…
- 1863997: remove flex attn test stub
- 396dfec: add sleeps to benchmark
- 65f2899: correct block sparsity benchmark to use torch.compile
- 6065585: Restore missing mask definitions and fix benchmark window_size handling
- 3873f20: move benchmarks into new directory
- fdd8be0: compute_block_sparsity docstring
- a0067fd: streamline compute block sparsity benchmark script
reubenconducts File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
New file (363 lines added): the comparative benchmark script. Full contents:

```python
"""
Comparative benchmark: CuTe DSL vs native PyTorch block sparsity computation.
"""

import itertools
from dataclasses import dataclass
from typing import Callable, List, Optional

import torch
from tabulate import tabulate
from tqdm import tqdm

import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from cutlass.cute.testing import benchmark as cute_benchmark
from flash_attn.cute.compute_block_sparsity import BlockSparsityKernel
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.mask_definitions import (
    get_mask_pair,
    random_doc_id_tensor,
    flex_document_mask,
    cute_document_mask,
)

from torch.nn.attention.flex_attention import create_block_mask
from triton.testing import do_bench

# Configure the torch.compile cache to prevent memory buildup across configs
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class BenchmarkConfig:
    """Configuration for a single benchmark run."""

    batch_size: int
    num_heads: int
    seqlen_q: int
    seqlen_k: int
    mask_name: str
    tile_m: int = 128
    tile_n: int = 128
    use_fast_sampling: bool = False
    aux_tensors_cute: Optional[list] = None


@dataclass(frozen=True)
class BenchmarkResult:
    """Result of a single benchmark run."""

    config: BenchmarkConfig
    cute_time_ms: Optional[float]
    pytorch_time_ms: Optional[float]
    error_message: Optional[str] = None


def benchmark_pytorch_block_sparsity(
    config: BenchmarkConfig,
    mask_fn: Callable,
) -> Optional[float]:
    """
    Benchmark PyTorch block mask creation (compiled).

    Returns: creation time in milliseconds, or None on failure.
    """
    device = "cuda"

    try:
        cbm = torch.compile(create_block_mask)
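        # do_bench warms the compiled callable up before timing, so the
        # torch.compile compilation cost lands outside the measured region.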

        def run_benchmark():
            return cbm(
                mask_fn,
                config.batch_size,
                config.num_heads,
                config.seqlen_q,
                config.seqlen_k,
                device=device,
            )

        creation_time_ms = do_bench(run_benchmark, warmup=10, rep=100)

        return creation_time_ms

    except Exception as e:
        print(f"PyTorch benchmark failed ({config.mask_name}): {e}")
        import traceback
        traceback.print_exc()
        return None


def benchmark_cute_block_sparsity(
    config: BenchmarkConfig,
    mask_fn: Callable,
) -> Optional[float]:
    """
    Benchmark the CuTe block sparsity kernel.

    Returns: creation time in milliseconds, or None on failure.
    """
    device = "cuda"

    try:
        num_m_blocks = (config.seqlen_q + config.tile_m - 1) // config.tile_m
        num_n_blocks = (config.seqlen_k + config.tile_n - 1) // config.tile_n
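        # num_{m,n}_blocks ceil-divide the sequence lengths into the kernel's
        # (tile_m x tile_n) grid. The four int32 tensors below receive the
        # kernel's output; the mask/full split presumably mirrors
        # FlexAttention's BlockMask: "mask" blocks still need mask_mod applied
        # elementwise, while "full" blocks are entirely unmasked.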
        mask_block_cnt = torch.zeros(
            (config.batch_size, config.num_heads, num_m_blocks), device=device, dtype=torch.int32
        )
        mask_block_idx = torch.zeros(
            (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks),
            device=device,
            dtype=torch.int32,
        )
        full_block_cnt = torch.zeros(
            (config.batch_size, config.num_heads, num_m_blocks), device=device, dtype=torch.int32
        )
        full_block_idx = torch.zeros(
            (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks),
            device=device,
            dtype=torch.int32,
        )

        # Convert to CuTe tensors
        mask_cnt_cute = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=2
        )
        mask_idx_cute = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=3
        )
        full_cnt_cute = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=2
        )
        full_idx_cute = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=3
        )

        blocksparse_tensors = BlockSparseTensors(
            mask_block_cnt=mask_cnt_cute,
            mask_block_idx=mask_idx_cute,
            full_block_cnt=full_cnt_cute,
            full_block_idx=full_idx_cute,
        )

        # Create kernel
        use_aux = config.aux_tensors_cute is not None and len(config.aux_tensors_cute) > 0
        kernel = BlockSparsityKernel(
            mask_mod=mask_fn,
            tile_mn=(config.tile_m, config.tile_n),
            compute_full_blocks=True,
            use_aux_tensors=use_aux,
            use_fast_sampling=config.use_fast_sampling,
        )

        # Compile kernel
        compiled_kernel = cute.compile(
            kernel,
            blocksparse_tensors,
            config.seqlen_q,
            config.seqlen_k,
            config.aux_tensors_cute,
        )
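        # cute.compile specializes the kernel up front, so the timing loop
        # below excludes compilation, analogous to the compiled
        # create_block_mask path on the PyTorch side.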

        def generate_tensors():
            from cutlass.cute.testing import JitArguments

            return JitArguments(
                blocksparse_tensors, config.seqlen_q, config.seqlen_k, config.aux_tensors_cute
            )

        creation_time_us = cute_benchmark(
            compiled_kernel,
            workspace_generator=generate_tensors,
            warmup_iterations=10,
            iterations=100,
        )

        torch.cuda.synchronize(device)
        creation_time_ms = creation_time_us / 1000.0

        return creation_time_ms

    except Exception as e:
        print(f"CuTe benchmark failed: {e}")
        return None


def run_benchmark(
    config: BenchmarkConfig,
    pytorch_mask_fn: Callable,
    cute_mask_fn: Callable,
) -> BenchmarkResult:
    """Run benchmarks for both implementations."""
    print(
        f"Benchmarking {config.mask_name} - B={config.batch_size}, H={config.num_heads}, "
        f"M={config.seqlen_q}, N={config.seqlen_k}"
    )

    # Benchmark PyTorch
    pytorch_time = benchmark_pytorch_block_sparsity(config, pytorch_mask_fn)

    # Benchmark CuTe
    cute_time = benchmark_cute_block_sparsity(config, cute_mask_fn)

    return BenchmarkResult(
        config=config,
        cute_time_ms=cute_time,
        pytorch_time_ms=pytorch_time,
    )


def generate_configs(
    batch_sizes: List[int],
    num_heads: List[int],
    seqlens: List[int],
    mask_names: List[str],
) -> List[BenchmarkConfig]:
    """Generate all benchmark configurations."""
    configs = []
    for B, H, S, mask_name in itertools.product(batch_sizes, num_heads, seqlens, mask_names):
        configs.append(
            BenchmarkConfig(
                batch_size=B,
                num_heads=H,
                seqlen_q=S,
                seqlen_k=S,
                mask_name=mask_name,
            )
        )
    return configs


def print_results(results: List[BenchmarkResult]):
    successful_results = [
        r for r in results if r.cute_time_ms is not None and r.pytorch_time_ms is not None
    ]

    if not successful_results:
        print("No successful benchmark results to display")
        return

    headers = ["B", "H", "M", "N", "Mask Type", "CuTe Time (ms)", "PyTorch Time (ms)", "Speedup"]

    rows = []
    for result in successful_results:
        speedup = result.pytorch_time_ms / result.cute_time_ms if result.cute_time_ms > 0 else 0

        rows.append(
            [
                result.config.batch_size,
                result.config.num_heads,
                result.config.seqlen_q,
                result.config.seqlen_k,
                result.config.mask_name,
                f"{result.cute_time_ms:.4f}",
                f"{result.pytorch_time_ms:.4f}",
                f"{speedup:.2f}x",
            ]
        )

    # Sort by batch, head, seqlen, then mask type
    rows.sort(key=lambda x: (x[0], x[1], x[2], x[4]))

    print("\n" + "=" * 100)
    print("CuTe DSL vs PyTorch Block Sparsity Benchmark Results")
    print("=" * 100)
    print(tabulate(rows, headers=headers, tablefmt="github"))
    print("=" * 100)


def main():
    """Run the comparative benchmark."""
    # Configuration
    batch_sizes = [1, 4, 8]
    num_heads = [8, 16]
    seqlens = [1024, 2048, 4096, 8192]
    mask_names = [
        "causal",
        "sliding_window",
        "prefix_lm",
        "dilated_sliding_window",
        "document",
    ]

    device = "cuda"
    max_seqlen = max(seqlens)
    max_batch = max(batch_sizes)
    max_heads = max(num_heads)

    # Create document IDs using the helper from mask_definitions
    doc_ids = random_doc_id_tensor(max_heads, max_batch, max_seqlen, device=device)
    doc_ids_cute = from_dlpack(doc_ids.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2)
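    # doc_ids is allocated once at the maximum B/H/S and shared by every
    # "document" config below, so per-config runs reuse a single tensor.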

    # Generate base configurations
    base_configs = generate_configs(batch_sizes, num_heads, seqlens, mask_names)

    # Update configs with aux tensors for document masking
    configs = []
    for config in base_configs:
        if config.mask_name == "document":
            # Add aux tensors for document masking
            configs.append(
                BenchmarkConfig(
                    batch_size=config.batch_size,
                    num_heads=config.num_heads,
                    seqlen_q=config.seqlen_q,
                    seqlen_k=config.seqlen_k,
                    mask_name=config.mask_name,
                    tile_m=config.tile_m,
                    tile_n=config.tile_n,
                    use_fast_sampling=False,
                    aux_tensors_cute=[doc_ids_cute],
                )
            )
        else:
            configs.append(config)

    # Run benchmarks
    results = []
    print(f"Running {len(configs)} benchmark configurations...")
    for config in tqdm(configs, desc="Benchmarking"):
        try:
            # Get mask pair from mask_definitions
            mask_kwargs = {}
            if config.mask_name == "sliding_window":
                mask_kwargs["window_size"] = 128  # Default window size

            cute_mask_fn, pytorch_mask_fn = get_mask_pair(
                config.mask_name,
                seqlen_q=config.seqlen_q,
                seqlen_k=config.seqlen_k,
                **mask_kwargs,
            )

            # For document masking, create a wrapper that captures doc_ids
            if config.mask_name == "document":
                # PyTorch wrapper
                def pytorch_mask_fn(b, h, q, kv):
                    return flex_document_mask(b, h, q, kv, doc_ids)

                # CuTe wrapper - reuse cute_document_mask with aux_tensors
                cute_mask_fn = cute_document_mask

            result = run_benchmark(config, pytorch_mask_fn, cute_mask_fn)
            results.append(result)

        except Exception as e:
            print(f"Failed to run config {config}: {e}")
            results.append(
                BenchmarkResult(
                    config=config,
                    cute_time_ms=None,
                    pytorch_time_ms=None,
                    error_message=str(e),
                )
            )
        finally:
            torch.cuda.empty_cache()
            torch._dynamo.reset()

    print_results(results)


if __name__ == "__main__":
    main()
```
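
For orientation, here is a minimal sketch of exercising the script's API directly for a single configuration. The script's module path is not shown in the PR, so the first import is hypothetical; the other names come from the file above.

```python
# Hypothetical single-config run. "benchmark_block_sparsity" is an assumed
# module name for the file above; the PR does not show its actual path.
from benchmark_block_sparsity import BenchmarkConfig, print_results, run_benchmark
from flash_attn.cute.mask_definitions import get_mask_pair

config = BenchmarkConfig(
    batch_size=1, num_heads=8, seqlen_q=2048, seqlen_k=2048, mask_name="causal"
)
# get_mask_pair returns (cute_mask_fn, pytorch_mask_fn), per the main() loop above
cute_fn, torch_fn = get_mask_pair("causal", seqlen_q=2048, seqlen_k=2048)
result = run_benchmark(config, pytorch_mask_fn=torch_fn, cute_mask_fn=cute_fn)
print_results([result])
```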