diff --git a/benchmarks/cute/benchmark_block_sparsity.py b/benchmarks/cute/benchmark_block_sparsity.py new file mode 100644 index 00000000000..74f220e8795 --- /dev/null +++ b/benchmarks/cute/benchmark_block_sparsity.py @@ -0,0 +1,363 @@ +""" +Comparative benchmark: CuTe DSL vs Native PyTorch block sparsity computation. +""" + +import torch +from dataclasses import dataclass +from typing import Callable, Optional, List +from tabulate import tabulate +from tqdm import tqdm +import itertools + +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.testing import benchmark as cute_benchmark +import cutlass.cute as cute +from flash_attn.cute.compute_block_sparsity import BlockSparsityKernel +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.mask_definitions import ( + get_mask_pair, + random_doc_id_tensor, + flex_document_mask, + cute_document_mask, +) + +from torch.nn.attention.flex_attention import create_block_mask +from triton.testing import do_bench + +# Configure torch.compile cache to prevent memory buildup +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class BenchmarkConfig: + """Configuration for a benchmark run.""" + + batch_size: int + num_heads: int + seqlen_q: int + seqlen_k: int + mask_name: str + tile_m: int = 128 + tile_n: int = 128 + use_fast_sampling: bool = False + aux_tensors_cute: Optional[list] = None + + +@dataclass(frozen=True) +class BenchmarkResult: + """Result of a single benchmark run.""" + + config: BenchmarkConfig + cute_time_ms: Optional[float] + pytorch_time_ms: Optional[float] + error_message: Optional[str] = None + + +def benchmark_pytorch_block_sparsity( + config: BenchmarkConfig, + mask_fn: Callable, +) -> Optional[float]: + """ + Benchmark PyTorch block mask creation (compiled). + Returns: creation_time_ms + """ + device = "cuda" + + try: + cbm = torch.compile(create_block_mask) + + def run_benchmark(): + return cbm( + mask_fn, + config.batch_size, + config.num_heads, + config.seqlen_q, + config.seqlen_k, + device=device, + ) + + creation_time_ms = do_bench(run_benchmark, warmup=10, rep=100) + + return creation_time_ms + + except Exception as e: + print(f"PyTorch benchmark failed ({config.mask_name}): {e}") + import traceback + traceback.print_exc() + return None + + +def benchmark_cute_block_sparsity( + config: BenchmarkConfig, + mask_fn: Callable, +) -> Optional[float]: + """ + Benchmark CuTe block sparsity kernel. + Returns: creation_time_ms + """ + device = "cuda" + + try: + num_m_blocks = (config.seqlen_q + config.tile_m - 1) // config.tile_m + num_n_blocks = (config.seqlen_k + config.tile_n - 1) // config.tile_n + + mask_block_cnt = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks), + device=device, + dtype=torch.int32, + ) + full_block_cnt = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks), + device=device, + dtype=torch.int32, + ) + + # Convert to CuTe tensors + mask_cnt_cute = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + mask_idx_cute = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + full_cnt_cute = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + full_idx_cute = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + + blocksparse_tensors = BlockSparseTensors( + mask_block_cnt=mask_cnt_cute, + mask_block_idx=mask_idx_cute, + full_block_cnt=full_cnt_cute, + full_block_idx=full_idx_cute, + ) + + # Create kernel + use_aux = config.aux_tensors_cute is not None and len(config.aux_tensors_cute) > 0 + kernel = BlockSparsityKernel( + mask_mod=mask_fn, + tile_mn=(config.tile_m, config.tile_n), + compute_full_blocks=True, + use_aux_tensors=use_aux, + use_fast_sampling=config.use_fast_sampling, + ) + + # Compile kernel + compiled_kernel = cute.compile( + kernel, + blocksparse_tensors, + config.seqlen_q, + config.seqlen_k, + config.aux_tensors_cute, + ) + + def generate_tensors(): + from cutlass.cute.testing import JitArguments + + return JitArguments( + blocksparse_tensors, config.seqlen_q, config.seqlen_k, config.aux_tensors_cute + ) + + creation_time_us = cute_benchmark( + compiled_kernel, + workspace_generator=generate_tensors, + warmup_iterations=10, + iterations=100, + ) + + torch.cuda.synchronize(device) + creation_time_ms = creation_time_us / 1000.0 + + return creation_time_ms + + except Exception as e: + print(f"CuTe benchmark failed: {e}") + return None + + +def run_benchmark( + config: BenchmarkConfig, + pytorch_mask_fn: Callable, + cute_mask_fn: Callable, +) -> BenchmarkResult: + """Run benchmarks for both implementations.""" + + print( + f"Benchmarking {config.mask_name} - B={config.batch_size}, H={config.num_heads}, " + f"M={config.seqlen_q}, N={config.seqlen_k}" + ) + + # Benchmark PyTorch + pytorch_time = benchmark_pytorch_block_sparsity(config, pytorch_mask_fn) + + # Benchmark CuTe + cute_time = benchmark_cute_block_sparsity(config, cute_mask_fn) + + return BenchmarkResult( + config=config, + cute_time_ms=cute_time, + pytorch_time_ms=pytorch_time, + ) + + +def generate_configs( + batch_sizes: List[int], + num_heads: List[int], + seqlens: List[int], + mask_names: List[str], +) -> List[BenchmarkConfig]: + """Generate all benchmark configurations.""" + configs = [] + for B, H, S, mask_name in itertools.product(batch_sizes, num_heads, seqlens, mask_names): + configs.append( + BenchmarkConfig( + batch_size=B, + num_heads=H, + seqlen_q=S, + seqlen_k=S, + mask_name=mask_name, + ) + ) + return configs + + +def print_results(results: List[BenchmarkResult]): + successful_results = [ + r for r in results if r.cute_time_ms is not None and r.pytorch_time_ms is not None + ] + + if not successful_results: + print("No successful benchmark results to display") + return + + headers = ["B", "H", "M", "N", "Mask Type", "CuTe Time (ms)", "PyTorch Time (ms)", "Speedup"] + + rows = [] + for result in successful_results: + speedup = result.pytorch_time_ms / result.cute_time_ms if result.cute_time_ms > 0 else 0 + + rows.append( + [ + result.config.batch_size, + result.config.num_heads, + result.config.seqlen_q, + result.config.seqlen_k, + result.config.mask_name, + f"{result.cute_time_ms:.4f}", + f"{result.pytorch_time_ms:.4f}", + f"{speedup:.2f}x", + ] + ) + + # Sort by batch, head, seqlen, then mask type + rows.sort(key=lambda x: (x[0], x[1], x[2], x[4])) + + print("\n" + "=" * 100) + print("CuTe DSL vs PyTorch Block Sparsity Benchmark Results") + print("=" * 100) + print(tabulate(rows, headers=headers, tablefmt="github")) + print("=" * 100) + + +def main(): + """Run the comparative benchmark.""" + + # Configuration + batch_sizes = [1, 4, 8] + num_heads = [8, 16] + seqlens = [1024, 2048, 4096, 8192] + mask_names = [ + "causal", + "sliding_window", + "prefix_lm", + "dilated_sliding_window", + "document", + ] + + device = "cuda" + max_seqlen = max(seqlens) + max_batch = max(batch_sizes) + max_heads = max(num_heads) + + # Create document IDs using the helper from mask_definitions + doc_ids = random_doc_id_tensor(max_heads, max_batch, max_seqlen, device=device) + doc_ids_cute = from_dlpack(doc_ids.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + + # Generate base configurations + base_configs = generate_configs(batch_sizes, num_heads, seqlens, mask_names) + + # Update configs with aux tensors for document masking + configs = [] + for config in base_configs: + if config.mask_name == "document": + # Add aux tensors for document masking + configs.append( + BenchmarkConfig( + batch_size=config.batch_size, + num_heads=config.num_heads, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + mask_name=config.mask_name, + tile_m=config.tile_m, + tile_n=config.tile_n, + use_fast_sampling=False, + aux_tensors_cute=[doc_ids_cute], + ) + ) + else: + configs.append(config) + + # Run benchmarks + results = [] + print(f"Running {len(configs)} benchmark configurations...") + for config in tqdm(configs, desc="Benchmarking"): + try: + # Get mask pair from mask_definitions + mask_kwargs = {} + if config.mask_name == "sliding_window": + mask_kwargs["window_size"] = 128 # Default window size + + cute_mask_fn, pytorch_mask_fn = get_mask_pair( + config.mask_name, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + **mask_kwargs, + ) + + # For document masking, create wrapper that captures doc_ids + if config.mask_name == "document": + # PyTorch wrapper + def pytorch_mask_fn(b, h, q, kv): + return flex_document_mask(b, h, q, kv, doc_ids) + # CuTe wrapper - reuse cute_document_mask with aux_tensors + cute_mask_fn = cute_document_mask + + result = run_benchmark(config, pytorch_mask_fn, cute_mask_fn) + results.append(result) + + except Exception as e: + print(f"Failed to run config {config}: {e}") + results.append( + BenchmarkResult( + config=config, + cute_time_ms=None, + pytorch_time_ms=None, + error_message=str(e), + ) + ) + finally: + torch.cuda.empty_cache() + torch._dynamo.reset() + + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/flash_attn/cute/benchmark_mask_mod.py b/benchmarks/cute/benchmark_mask_mod.py similarity index 98% rename from flash_attn/cute/benchmark_mask_mod.py rename to benchmarks/cute/benchmark_mask_mod.py index 88db8418abc..348d2ee485d 100644 --- a/flash_attn/cute/benchmark_mask_mod.py +++ b/benchmarks/cute/benchmark_mask_mod.py @@ -14,8 +14,8 @@ import numpy as np import torch -from flash_fwd import FlashAttentionForwardSm90 -from mask_definitions import ( +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 +from flash_attn.cute.mask_definitions import ( get_mask_pair, random_doc_id_tensor, ) @@ -74,8 +74,8 @@ class BenchmarkConfig: mma_pv_is_rs: bool = True # Benchmark parameters - warmup_iters: int = 5 - benchmark_iters: int = 20 + warmup_iters: int = 10 + benchmark_iters: int = 25 verbose: bool = False seed: int = 42 @@ -649,16 +649,16 @@ def _print_results(self, results: Dict[str, Any]): dtype=torch.bfloat16, batch_size=B, # batch_size=1, - seqlen_q=16384 // B, + seqlen_q=8192, # seqlen_q=128, - seqlen_k=16384 // B, + seqlen_k=8192, # seqlen_k=192, use_varlen=False, - use_mask_mod=True, + use_mask_mod=False, mask_mod_name="causal", window_size=128, # Configurable window size for mask_mod use_learnable_sink=False, - causal=False, + causal=True, is_local=False, verbose=True, ) diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py new file mode 100644 index 00000000000..bec6fe5701f --- /dev/null +++ b/flash_attn/cute/compute_block_sparsity.py @@ -0,0 +1,403 @@ +from functools import partial +import math +import operator +from typing import Callable, Optional, Tuple, Type + +import cuda.bindings.driver as cuda +import cutlass +from cutlass import Boolean, Constexpr, Float32, Int32, Int8, const_expr +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import torch + +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar + + +class BlockSparsityKernel: + """Block sparsity kernel for FlexAttention. + + This kernel computes `mask_mod` for every token of each block + to determine if an n block is full, masked, or neither. + + Writes block counts and indices to a BlockSparseTensors object. + + When use_fast_sampling=True, uses 5-point sampling (4 corners + center) + which is much faster but only suitable for masks where this is sufficient. + """ + + def __init__( + self, + mask_mod: Callable, + tile_mn: Tuple[int, int], + compute_full_blocks: bool = True, + use_aux_tensors: bool = False, + use_fast_sampling: bool = False, + ): + self.mask_mod = mask_mod + self.tile_mn = tile_mn + self.compute_full_blocks = compute_full_blocks + self.use_aux_tensors = use_aux_tensors + self.use_fast_sampling = use_fast_sampling + + @cute.jit + def __call__( + self, + blocksparse_tensors: BlockSparseTensors, + seqlen_q: Int32, + seqlen_k: Int32, + aux_tensors: Optional[list] = None, + ): + self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors + self.seqlen_q = seqlen_q + self.seqlen_k = seqlen_k + + if const_expr(self.compute_full_blocks): + assert self.full_cnt is not None and self.full_idx is not None, ( + "full block tensors must be provided when computing full blocks" + ) + + batch_size, num_heads, num_m_blocks, num_n_blocks = list(self.mask_idx.shape) + grid = [num_m_blocks, num_heads, batch_size] + + # Fast sampling uses only 5 threads (4 corners + center), full sampling uses 1 thread per row + if const_expr(self.use_fast_sampling): + num_threads = 5 + self.num_warps = 1 + else: + num_threads = self.tile_mn[0] + self.num_warps = (num_threads + 32 - 1) // 32 + + self.kernel( + self.mask_cnt, + self.mask_idx, + self.full_cnt, + self.full_idx, + num_n_blocks, + seqlen_q, + seqlen_k, + aux_tensors, + ).launch(grid=grid, block=[num_threads, 1, 1]) + + @cute.kernel + def kernel( + self, + mask_cnt: cute.Tensor, + mask_idx: cute.Tensor, + full_cnt: cute.Tensor, + full_idx: cute.Tensor, + num_n_blocks: Int32, + seqlen_q: Int32, + seqlen_k: Int32, + aux_tensors: Optional[list] = None, + ): + # Store seqlens as instance variables for use in the kernel + self.seqlen_q = seqlen_q + self.seqlen_k = seqlen_k + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + m_block, head_idx, batch_idx = cute.arch.block_idx() + + ssa = partial(scalar_to_ssa, dtype=Int32) + + @cute.struct + class SharedStorage: + reduction_buffer_smem: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int8, 2 * self.num_warps], 1024 + ] + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage, 16) + + reduction_buffer = storage.reduction_buffer_smem.get_tensor( + cute.make_layout((self.num_warps, 2)) + ) + + num_mask_blocks = Int32(0) + num_full_blocks = Int32(0) + + for n_block in cutlass.range(num_n_blocks, unroll_full=True): + m_base = m_block * self.tile_mn[0] + n_base = n_block * self.tile_mn[1] + + if const_expr(self.use_fast_sampling): + # Fast path: 5-point sampling (4 corners + center) + # Out-of-bounds indices are treated as masked (False) + thread_result = Boolean(False) + thread_is_valid = Boolean(False) + q_idx = Int32(0) + kv_idx = Int32(0) + + if tidx == 0: + # Top-left corner (0, 0) + q_idx = m_base + kv_idx = n_base + elif tidx == 1: + # Top-right corner + q_idx = m_base + kv_idx = n_base + self.tile_mn[1] - 1 + elif tidx == 2: + # Bottom-left corner + q_idx = m_base + self.tile_mn[0] - 1 + kv_idx = n_base + elif tidx == 3: + # Bottom-right corner + q_idx = m_base + self.tile_mn[0] - 1 + kv_idx = n_base + self.tile_mn[1] - 1 + elif tidx == 4: + # Center point + q_idx = m_base + self.tile_mn[0] // 2 + kv_idx = n_base + self.tile_mn[1] // 2 + + # Check bounds and determine if this thread has a valid index pair + if q_idx < self.seqlen_q and kv_idx < self.seqlen_k: + thread_is_valid = Boolean(True) + q_idx_ssa = ssa(q_idx) + kv_idx_ssa = ssa(kv_idx) + thread_result = ssa_to_scalar( + self.mask_mod( + ssa(batch_idx), ssa(head_idx), q_idx_ssa, kv_idx_ssa, aux_tensors + ) + ) + else: + thread_is_valid = Boolean(False) + + # Use vote_any_sync to see if any valid thread found unmasked or masked + # Only count results from threads that checked valid indices + has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid) + has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid) + + else: + # Full path: check all elements in the block + # Track if this thread's row has any masked or unmasked elements + thread_has_unmasked = Boolean(False) + thread_has_masked = Boolean(False) + thread_is_valid = Boolean(False) + + # Each thread handles 1 row + q_idx = m_base + tidx + kv_idx = Int32(0) + if tidx < self.tile_mn[0] and q_idx < self.seqlen_q: + thread_is_valid = Boolean(True) + q_idx_ssa = ssa(q_idx) + + # Loop over all columns in this row + for c in cutlass.range(self.tile_mn[1], unroll_full=True): + kv_idx = n_base + c + kv_idx_ssa = ssa(kv_idx) + + # Only check elements within valid sequence bounds + if kv_idx < self.seqlen_k: + # Direct scalar call + mask_val = ssa_to_scalar( + self.mask_mod( + ssa(batch_idx), + ssa(head_idx), + q_idx_ssa, + kv_idx_ssa, + aux_tensors, + ) + ) + + # Update tracking flags + if mask_val: + thread_has_unmasked = Boolean(True) + else: + thread_has_masked = Boolean(True) + + # Block-level reduction to combine results across all threads + # Only count votes from threads that checked valid indices + warp_has_unmasked_mask = cute.arch.vote_any_sync( + thread_has_unmasked & thread_is_valid + ) + warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid) + + # lane 0 writes the ballot mask to shared memory + lane_id = tidx % 32 + if lane_id == 0: + # Store as Int8 + reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0) + reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0) + + cute.arch.sync_threads() + + # Thread 0 ORs all warp results together + has_unmasked = Boolean(False) + has_masked = Boolean(False) + if tidx == 0: + for w in cutlass.range(self.num_warps): + if reduction_buffer[w, 0]: + has_unmasked = Boolean(True) + if reduction_buffer[w, 1]: + has_masked = Boolean(True) + + # Only thread 0 updates the output arrays (common to both paths) + if tidx == 0: + # Block classification based on what we found: + # - If has_masked and has_unmasked: partial block (needs masking) + # - If only has_unmasked: full block (no masking needed) + # - If only has_masked: skip this block entirely + is_partial = Boolean(has_masked and has_unmasked) + is_full = Boolean(has_unmasked and (not has_masked)) + + if is_partial: + mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block + num_mask_blocks += 1 + elif is_full and const_expr(self.compute_full_blocks): + full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block + num_full_blocks += 1 + + # Only thread 0 writes back the counts + if tidx == 0: + mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks + if const_expr(self.compute_full_blocks): + full_cnt[batch_idx, head_idx, m_block] = num_full_blocks + + +def compute_block_sparsity( + tile_m, + tile_n, + batch_size, + num_heads, + seqlen_q, + seqlen_k, + mask_mod: Callable, + aux_tensors: Optional[list], # list[cute.Tensor] + device, + compute_full_blocks: bool = True, + use_fast_sampling: bool = False, +) -> Tuple[BlockSparseTensors, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Computes block sparsity for a given `mask_mod`. + + Args: + tile_m: The tile size for the m dimension. + tile_n: The tile size for the n dimension. + batch_size: The batch size. + num_heads: The number of heads. + seqlen_q: The sequence length for the query. + seqlen_k: The sequence length for the key. + mask_mod: The `mask_mod` callable to use. + aux_tensors: A list of auxiliary tensors. + device: The device to use. + compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. + use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient. + + Returns: + A tuple of `BlockSparseTensors` and the underlying torch tensors. + """ + num_m_blocks = (seqlen_q + tile_m - 1) // tile_m + num_n_blocks = (seqlen_k + tile_n - 1) // tile_n + + mask_block_cnt = torch.zeros( + (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 + ) + full_block_cnt = torch.zeros( + (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 + ) + + # Convert to cute tensors + mask_cnt_cute = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + mask_idx_cute = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + full_cnt_cute = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + full_idx_cute = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + + blocksparse_tensors = BlockSparseTensors( + mask_block_cnt=mask_cnt_cute, + mask_block_idx=mask_idx_cute, + full_block_cnt=full_cnt_cute, + full_block_idx=full_idx_cute, + ) + + mask_mod_hash = hash_callable(mask_mod) + + compile_key = ( + tile_m, + tile_n, + mask_mod_hash, + compute_full_blocks, + aux_tensors is not None, + use_fast_sampling, + ) + if compile_key not in compute_block_sparsity.compile_cache: + kernel = BlockSparsityKernel( + mask_mod, + tile_mn=(tile_m, tile_n), + compute_full_blocks=True, + use_aux_tensors=aux_tensors is not None, + use_fast_sampling=use_fast_sampling, + ) + + compute_block_sparsity.compile_cache[compile_key] = cute.compile( + kernel, + blocksparse_tensors, + seqlen_q, + seqlen_k, + aux_tensors, + ) + + compute_block_sparsity.compile_cache[compile_key]( + blocksparse_tensors, + seqlen_q, + seqlen_k, + aux_tensors, + ) + + # Return both the BlockSparseTensors (cute) and the underlying torch tensors + return blocksparse_tensors, (full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx) + + +compute_block_sparsity.compile_cache = {} + + +def run(): + """Test the BlockSparsityKernel with a simple causal mask.""" + + print("Testing BlockSparsityKernel...") + + # Configuration + batch_size = 2 + num_heads = 2 + seqlen_q = 16384 + seqlen_k = 16384 + tile_m, tile_n = 128, 128 # Use very small tiles for initial testing + + # Define a simple causal mask function + @cute.jit + def causal_mask(batch_idx, head_idx, q_idx, kv_idx, aux_tensors): + """Simple causal mask: only attend to positions <= current position.""" + return q_idx >= kv_idx + + try: + compute_block_sparsity( + tile_m, + tile_n, + batch_size, + num_heads, + seqlen_q, + seqlen_k, + causal_mask, + None, + device="cuda", + ) + print("Kernel execution completed!") + except Exception as e: + print(f"Kernel execution failed: {e}") + + +if __name__ == "__main__": + run() diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index ce32f567e97..4989067b8c1 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -106,6 +106,8 @@ def _flash_attn_fwd( Args: ... score_mod: A callable that takes the attention scores and applies a modification. + mask_mod: A callable that takes token position information and selectively masks + block_sparse_tensors: A tuple of tensors used for block sparsity. return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate out: Optional pre-allocated output tensor. If None, will be allocated internally. lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. diff --git a/flash_attn/cute/mask_definitions.py b/flash_attn/cute/mask_definitions.py index 0bb0d56751a..bbf2d212c0c 100644 --- a/flash_attn/cute/mask_definitions.py +++ b/flash_attn/cute/mask_definitions.py @@ -153,6 +153,54 @@ def cute_mini_causal_mask( return m_mod >= n_mod +@cute.jit +def cute_prefix_lm_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, +) -> cute.TensorSSA: + """Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.""" + prefix_size_ssa = utils.scalar_to_ssa(512, cutlass.Int32) + both_in_prefix = (m_idx < prefix_size_ssa) & (n_idx < prefix_size_ssa) + causal_part = m_idx >= n_idx + return both_in_prefix | causal_part + + +def flex_prefix_lm_mask(b, h, q_idx, kv_idx): + """Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.""" + prefix_size = 512 + both_in_prefix = (q_idx < prefix_size) & (kv_idx < prefix_size) + causal_part = q_idx >= kv_idx + return both_in_prefix | causal_part + + +@cute.jit +def cute_dilated_sliding_window_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, +) -> cute.TensorSSA: + """Dilated sliding window: every other position in a 256-position window.""" + window_size_ssa = utils.scalar_to_ssa(256, cutlass.Int32) + dilation_ssa = utils.scalar_to_ssa(2, cutlass.Int32) + in_window = (m_idx >= n_idx) & (m_idx - n_idx < window_size_ssa) + dilated = ((m_idx - n_idx) % dilation_ssa) == utils.scalar_to_ssa(0, cutlass.Int32) + return in_window & dilated + + +def flex_dilated_sliding_window_mask(b, h, q_idx, kv_idx): + """Dilated sliding window: every other position in a 256-position window.""" + window_size = 256 + dilation = 2 + in_window = (q_idx >= kv_idx) & (q_idx - kv_idx < window_size) + dilated = ((q_idx - kv_idx) % dilation) == 0 + return in_window & dilated + + def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device) for b in range(batch): @@ -175,6 +223,8 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): STATIC_MASKS = { "block_diagonal": (cute_block_diagonal_mask, flex_block_diagonal_mask), "mini_causal": (cute_mini_causal_mask, flex_mini_causal_mask), + "prefix_lm": (cute_prefix_lm_mask, flex_prefix_lm_mask), + "dilated_sliding_window": (cute_dilated_sliding_window_mask, flex_dilated_sliding_window_mask), "document": (cute_document_mask, flex_document_mask), } diff --git a/tests/cute/test_block_sparsity.py b/tests/cute/test_block_sparsity.py new file mode 100644 index 00000000000..d1ac5318004 --- /dev/null +++ b/tests/cute/test_block_sparsity.py @@ -0,0 +1,422 @@ +"""Tests for block sparsity computation in flash attention.""" + +import pytest +import torch +from torch.nn.attention.flex_attention import create_block_mask + +from flash_attn.cute.mask_definitions import get_mask_pair +from flash_attn.cute.compute_block_sparsity import compute_block_sparsity + + +def _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + window_size=None, + aux_tensors=None, + use_fast_sampling=False, +): + """Call compute_block_sparsity and return torch tensors.""" + cute_mask, _ = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + blocksparse_tensors, torch_tensors = compute_block_sparsity( + tile_m=tile_m, + tile_n=tile_n, + batch_size=batch_size, + num_heads=nheads, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + mask_mod=cute_mask, + aux_tensors=aux_tensors, + device="cuda", + use_fast_sampling=use_fast_sampling, + ) + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = torch_tensors + return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + + +def _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, +): + """Compare block sparsity against reference. Returns (all_match, error_msg).""" + if not isinstance(mask_block_cnt, torch.Tensor): + return False, f"mask_block_cnt is not a tensor: {type(mask_block_cnt)}" + + n_blocks_q = mask_block_cnt.shape[2] + mask_cnt_match = torch.all(mask_block_cnt == mask_block_cnt_ref).item() + full_cnt_match = torch.all(full_block_cnt == full_block_cnt_ref).item() + + if not mask_cnt_match or not full_cnt_match: + error_msg = [] + if not mask_cnt_match: + error_msg.append("Mask counts mismatch") + diff = (mask_block_cnt != mask_block_cnt_ref).nonzero(as_tuple=False) + if len(diff) > 0: + b, h, m = diff[0].tolist() + error_msg.append( + f" First mismatch at [{b},{h},{m}]: " + f"got {mask_block_cnt[b, h, m].item()}, " + f"expected {mask_block_cnt_ref[b, h, m].item()}" + ) + if not full_cnt_match: + error_msg.append("Full counts mismatch") + diff = (full_block_cnt != full_block_cnt_ref).nonzero(as_tuple=False) + if len(diff) > 0: + b, h, m = diff[0].tolist() + error_msg.append( + f" First mismatch at [{b},{h},{m}]: " + f"got {full_block_cnt[b, h, m].item()}, " + f"expected {full_block_cnt_ref[b, h, m].item()}" + ) + return False, "\n".join(error_msg) + + # Compare indices + for b in range(batch_size): + for h in range(nheads): + for m in range(n_blocks_q): + num_mask = mask_block_cnt[b, h, m].item() + num_full = full_block_cnt[b, h, m].item() + + if num_mask > 0: + mask_indices = mask_block_idx[b, h, m, :num_mask].sort()[0] + mask_indices_ref = mask_block_idx_ref[b, h, m, :num_mask].sort()[0] + if not (mask_indices == mask_indices_ref).all(): + return False, f"Mask indices mismatch at [{b},{h},{m}]" + + if num_full > 0: + full_indices = full_block_idx[b, h, m, :num_full].sort()[0] + full_indices_ref = full_block_idx_ref[b, h, m, :num_full].sort()[0] + if not (full_indices == full_indices_ref).all(): + return False, f"Full indices mismatch at [{b},{h},{m}]" + + return True, "" + + +# Test configurations +SEQLEN_PAIRS = [ + # Small aligned + (64, 64), + (128, 128), + (256, 256), + (512, 512), + # Rectangular + (128, 256), + (256, 128), + (512, 256), + (256, 512), + # Large aligned + (1024, 1024), + (2048, 2048), + (4096, 4096), + # Large unaligned + (1000, 1000), + (2000, 2000), + (4000, 4000), + # Edge cases with unaligned seqlens + (113, 203), + (127, 127), + (129, 129), + (255, 255), + (257, 257), + (1023, 1023), + (1025, 1025), + (2047, 2047), + (2049, 2049), +] +TILE_SIZES = [ + # Standard powers of 2 + (32, 32), + (64, 64), + (128, 128), + (256, 256), + # Rectangular + (32, 64), + (64, 32), + (64, 128), + (128, 64), + (128, 256), + (256, 128), + # Unusual sizes + (40, 40), + (48, 48), + (96, 96), + (112, 112), + (32, 128), + (128, 32), + (40, 96), + (96, 40), +] + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize("tile_m,tile_n", TILE_SIZES) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("mask_name", ["block_diagonal", "mini_causal"]) +def test_fixed_length_masks( + seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name +): + """Test fixed-length masks.""" + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + ) + ) + + _, mask_mod_flex = get_mask_pair(mask_name) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize( + "tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)] +) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize( + "mask_name,window_size", + [("causal", None), ("sliding_window", 64), ("sliding_window", 256)], +) +def test_parameterized_masks( + seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name, window_size +): + """Test parameterized masks.""" + if mask_name == "sliding_window" and seqlen_q > seqlen_k: + pytest.skip("Sliding window not supported for seqlen_q > seqlen_k") + + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + window_size=window_size, + ) + ) + + _, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_k,tile_m,tile_n", + [ + (1, 1, 64, 64), + (63, 63, 64, 64), + (65, 65, 64, 64), + (129, 129, 128, 128), + (100, 200, 64, 128), + ], +) +def test_edge_cases(seqlen_q, seqlen_k, tile_m, tile_n): + """Test edge cases with unaligned dimensions.""" + batch_size, nheads = 1, 1 + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + "causal", + ) + ) + + _, mask_mod_flex = get_mask_pair("causal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize( + "tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)] +) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("mask_name", ["causal", "block_diagonal"]) +def test_fast_sampling(seqlen_q, seqlen_k, tile_m, tile_n, nheads, mask_name): + """Test fast sampling mode (5-point sampling).""" + batch_size = 1 + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + use_fast_sampling=True, + ) + ) + + _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}"