diff --git a/flash_attn/cute/benchmark_mask_mod.py b/flash_attn/cute/benchmark_mask_mod.py new file mode 100644 index 00000000000..071b4e02a58 --- /dev/null +++ b/flash_attn/cute/benchmark_mask_mod.py @@ -0,0 +1,714 @@ +""" +FlashAttention benchmarking script with Flex Attention-style +mask mod support and varlen sequences. +""" + +from dataclasses import dataclass +import math +from pickle import FALSE +from typing import Any, Dict, Optional, Tuple + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import numpy as np +import torch + +from flash_fwd import FlashAttentionForwardSm90 +from mask_definitions import ( + MASK_FUNCTIONS, + random_doc_id_tensor, + create_cute_sliding_window_mask, + create_flex_sliding_window_mask, +) +from block_sparsity import compute_block_sparsity + + +@dataclass +class BenchmarkConfig: + """Benchmark configuration""" + + # Model parameters + headdim: int + headdim_v: int + nheads: int + nheads_kv: int + dtype: torch.dtype + + # Sequence parameters + batch_size: int = 2 + seqlen_q: int = 8192 + seqlen_k: int = 8192 + + # Varlen parameters + use_varlen: bool = False + min_seqlen_q: Optional[int] = None # If None, use seqlen_q // 2 + max_seqlen_q: Optional[int] = None # If None, use seqlen_q + min_seqlen_k: Optional[int] = None # If None, use seqlen_k // 2 + max_seqlen_k: Optional[int] = None # If None, use seqlen_k + + # Mask parameters + use_mask_mod: bool = True + mask_mod_name: str = "causal" + has_buffers: bool = mask_mod_name == "document" + + # Sliding window parameter (used when mask_mod_name == "sliding_window") + window_size: int = 128 + + # Attention parameters + causal: bool = False + is_local: bool = False + window_left: Optional[int] = 128 # For base Flash Attention local + window_right: Optional[int] = 0 # For base Flash Attention local + softcap: Optional[float] = None + use_learnable_sink: bool = False + + # Kernel configuration + tile_m: int 
= 128 + tile_n: int = 128 + num_stages: int = 2 + num_threads: int = 384 + intra_wg_overlap: bool = True + mma_pv_is_rs: bool = True + + # Benchmark parameters + warmup_iters: int = 5 + benchmark_iters: int = 20 + verbose: bool = False + seed: int = 42 + + +class FlashAttentionBenchmark: + def __init__(self, config: BenchmarkConfig): + self.config = config + + torch.manual_seed(config.seed) + np.random.seed(config.seed) + + # Verify SM90 compute capability + compute_capability = torch.cuda.get_device_capability() + assert compute_capability >= (9, 0), ( + f"Requires SM90+, got SM{compute_capability[0]}{compute_capability[1]}" + ) + # causal overrides use_mask_mod + if config.causal: + config.use_mask_mod = False + + if config.use_mask_mod: + if config.mask_mod_name == "sliding_window": + # Use factory function for custom window size + self.mask_mod_cute = create_cute_sliding_window_mask(config.window_size) + self.mask_mod_flex = create_flex_sliding_window_mask(config.window_size) + else: + self.mask_mod_cute, self.mask_mod_flex = MASK_FUNCTIONS[config.mask_mod_name] + else: + self.mask_mod_cute = None + self.mask_mod_flex = None + + self._validate_config() + + def _validate_config(self): + config = self.config + + assert config.headdim <= 256, "headdim must be <= 256" + assert config.headdim_v <= 256, "headdim_v must be <= 256" + assert config.nheads % config.nheads_kv == 0, "nheads must be divisible by nheads_kv" + + alignment = 16 // config.dtype.itemsize + assert config.headdim % alignment == 0, f"headdim must be divisible by {alignment}" + assert config.headdim_v % alignment == 0, f"headdim_v must be divisible by {alignment}" + + # Validate is_local configuration + if config.is_local: + assert config.window_left is not None or config.window_right is not None, ( + "When is_local=True, at least one of window_left or window_right must be set" + ) + assert not config.use_mask_mod, ( + "Cannot use both is_local and use_mask_mod simultaneously" + ) + assert not 
config.causal, "Cannot use both is_local and causal simultaneously" + + # Validate mask_mod configuration + if config.use_mask_mod and config.mask_mod_name == "sliding_window": + assert config.window_size > 0, ( + "window_size must be positive when using sliding_window mask" + ) + + def _generate_varlen_seqlens(self, min_len: int, max_len: int) -> Tuple[torch.Tensor, int]: + """Generate random sequence lengths and compute cumulative lengths.""" + seqlens = torch.randint( + min_len, max_len + 1, (self.config.batch_size,), dtype=torch.int32, device="cuda" + ) + cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device="cuda"), + torch.cumsum(seqlens, dtype=torch.int32, dim=0), + ] + ) + + total_tokens = cu_seqlens[-1].item() + return cu_seqlens, total_tokens + + def _create_tensors(self) -> Dict[str, torch.Tensor]: + config = self.config + device = "cuda" + + if config.use_varlen: + # Set defaults for varlen range + min_q = config.min_seqlen_q if config.min_seqlen_q is not None else config.seqlen_q // 2 + max_q = config.max_seqlen_q if config.max_seqlen_q is not None else config.seqlen_q + min_k = config.min_seqlen_k if config.min_seqlen_k is not None else config.seqlen_k // 2 + max_k = config.max_seqlen_k if config.max_seqlen_k is not None else config.seqlen_k + + # Generate cu_seqlens + cu_seqlens_q, total_q = self._generate_varlen_seqlens(min_q, max_q) + cu_seqlens_k, total_k = self._generate_varlen_seqlens(min_k, max_k) + + # Varlen shape: (total_tokens, nheads, headdim) + q = torch.randn( + total_q, config.nheads, config.headdim, dtype=config.dtype, device=device + ) + k = torch.randn( + total_k, config.nheads_kv, config.headdim, dtype=config.dtype, device=device + ) + v = torch.randn( + total_k, config.nheads_kv, config.headdim_v, dtype=config.dtype, device=device + ) + out = torch.empty( + total_q, config.nheads, config.headdim_v, dtype=config.dtype, device=device + ) + lse = torch.empty(config.nheads, total_q, dtype=torch.float32, device=device) 
+ + tensors = { + "q": q.contiguous(), + "k": k.contiguous(), + "v": v.contiguous(), + "out": out.contiguous(), + "lse": lse.contiguous(), + "cu_seqlens_q": cu_seqlens_q.contiguous(), + "cu_seqlens_k": cu_seqlens_k.contiguous(), + } + + if config.verbose: + print(f"Varlen: total_q={total_q}, total_k={total_k}") + print(f"Q seqlens: {cu_seqlens_q[1:] - cu_seqlens_q[:-1]}") + print(f"K seqlens: {cu_seqlens_k[1:] - cu_seqlens_k[:-1]}") + else: + # Standard shape: (batch, seqlen, nheads, headdim) + q = torch.randn( + config.batch_size, + config.seqlen_q, + config.nheads, + config.headdim, + dtype=config.dtype, + device=device, + ) + k = torch.randn( + config.batch_size, + config.seqlen_k, + config.nheads_kv, + config.headdim, + dtype=config.dtype, + device=device, + ) + v = torch.randn( + config.batch_size, + config.seqlen_k, + config.nheads_kv, + config.headdim_v, + dtype=config.dtype, + device=device, + ) + out = torch.empty( + config.batch_size, + config.seqlen_q, + config.nheads, + config.headdim_v, + dtype=config.dtype, + device=device, + ) + lse = torch.empty( + config.batch_size, + config.nheads, + config.seqlen_q, + dtype=torch.float32, + device=device, + ) + + + tensors = { + "q": q.contiguous(), + "k": k.contiguous(), + "v": v.contiguous(), + "out": out.contiguous(), + "lse": lse.contiguous(), + } + + if config.use_learnable_sink: + learnable_sink = torch.rand(config.nheads, dtype=torch.bfloat16, device=device) + + tensors["learnable_sink"] = learnable_sink.contiguous() + + # Compute block sparsity when using mask_mod + if config.use_mask_mod: + if config.mask_mod_name == "document": + doc_id = random_doc_id_tensor( + config.batch_size, config.nheads, config.seqlen_q, device=device + ) + tensors["buffers"] = [doc_id.contiguous()] + full_cnt, full_idx, mask_cnt, mask_idx = compute_block_sparsity( + config=self.config, + mask_mod_flex=self.mask_mod_flex, + device=device, + cu_seqlens_q=tensors.get("cu_seqlens_q"), + cu_seqlens_k=tensors.get("cu_seqlens_k"), + 
buffers=tensors.get("buffers"), + ) + + if all(t is not None for t in [full_cnt, full_idx, mask_cnt, mask_idx]): + tensors["full_block_cnt"] = full_cnt.contiguous() + tensors["full_block_idx"] = full_idx.contiguous() + tensors["mask_block_cnt"] = mask_cnt.contiguous() + tensors["mask_block_idx"] = mask_idx.contiguous() + + if config.verbose: + total_full = full_cnt.sum().item() + total_partial = mask_cnt.sum().item() + + if config.use_varlen: + # Compute max possible blocks across all sequences + max_blocks = 0 + for i in range(config.batch_size): + seq_len_q = ( + tensors["cu_seqlens_q"][i + 1] - tensors["cu_seqlens_q"][i] + ).item() + seq_len_k = ( + tensors["cu_seqlens_k"][i + 1] - tensors["cu_seqlens_k"][i] + ).item() + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n + max_blocks += n_blocks_q * n_blocks_k * config.nheads + else: + n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n + n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m + max_blocks = n_blocks_k * n_blocks_q * config.nheads * config.batch_size + + skipped = max_blocks - total_full - total_partial + print( + f"Block stats: Full={total_full}, Partial={total_partial}, " + f"Skipped={skipped}/{max_blocks}" + ) + + return tensors + + def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]: + config = self.config + + dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + } + cute_dtype = dtype_map[config.dtype] + + qhead_per_kvhead = config.nheads // config.nheads_kv + kernel = FlashAttentionForwardSm90( + cute_dtype, + config.headdim, + config.headdim_v, + qhead_per_kvhead, + is_causal=config.causal, + is_local=config.is_local, + pack_gqa=False, + tile_m=config.tile_m, + tile_n=config.tile_n, + num_stages=config.num_stages, + num_threads=config.num_threads, + intra_wg_overlap=config.intra_wg_overlap, + 
mma_pv_is_rs=config.mma_pv_is_rs, + mask_mod=self.mask_mod_cute, + Q_in_regs=False, + has_buffers=config.has_buffers, + ) + + softmax_scale = 1.0 / math.sqrt(config.headdim) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Convert tensors to cute + q_cute = from_dlpack(tensors["q"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["q"].ndim - 1 + ) + k_cute = from_dlpack(tensors["k"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["k"].ndim - 1 + ) + v_cute = from_dlpack(tensors["v"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["v"].ndim - 1 + ) + out_cute = from_dlpack(tensors["out"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["out"].ndim - 1 + ) + lse_cute = from_dlpack(tensors["lse"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=tensors["lse"].ndim - 1 + ) + + # Varlen tensors + cu_seqlens_q_cute = ( + from_dlpack(tensors["cu_seqlens_q"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "cu_seqlens_q" in tensors + else None + ) + cu_seqlens_k_cute = ( + from_dlpack(tensors["cu_seqlens_k"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "cu_seqlens_k" in tensors + else None + ) + learnable_sink_cute = ( + from_dlpack(tensors["learnable_sink"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "learnable_sink" in tensors + else None + ) + + # Block sparsity tensors + full_block_cnt_cute = ( + from_dlpack(tensors["full_block_cnt"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + if "full_block_cnt" in tensors + else None + ) + full_block_idx_cute = ( + from_dlpack(tensors["full_block_idx"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + if "full_block_idx" in tensors + else None + ) + mask_block_cnt_cute = ( + from_dlpack(tensors["mask_block_cnt"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + 
) + if "mask_block_cnt" in tensors + else None + ) + mask_block_idx_cute = ( + from_dlpack(tensors["mask_block_idx"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + if "mask_block_idx" in tensors + else None + ) + + if "buffers" in tensors: + buffers_cute = [] + for i in range(len(tensors["buffers"])): + buf = from_dlpack(tensors["buffers"][i].detach(), assumed_align=4) + buffers_cute.append(buf.mark_layout_dynamic(leading_dim=2)) + + else: + buffers_cute = None + + # Window parameters for is_local + window_left_cute = ( + cutlass.Int32(config.window_left) if config.window_left is not None else None + ) + window_right_cute = ( + cutlass.Int32(config.window_right) if config.window_right is not None else None + ) + + compiled = cute.compile( + kernel, + q_cute, + k_cute, + v_cute, + out_cute, + lse_cute, + softmax_scale, + current_stream, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + None, # seqused_q + None, # seqused_k + None, # page_table + window_left_cute, + window_right_cute, + learnable_sink_cute, # learnable_sink + full_block_cnt_cute, + full_block_idx_cute, + mask_block_cnt_cute, + mask_block_idx_cute, + buffers_cute, + # None, + ) + + args = ( + q_cute, + k_cute, + v_cute, + out_cute, + lse_cute, + softmax_scale, + current_stream, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + None, + None, + None, + window_left_cute, + window_right_cute, + learnable_sink_cute, + full_block_cnt_cute, + full_block_idx_cute, + mask_block_cnt_cute, + mask_block_idx_cute, + buffers_cute, + # None, + ) + + return compiled, args + + def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float: + config = self.config + + # Estimate sparsity for known mask patterns + if config.is_local: + # Local attention with window_left and window_right + window_left = config.window_left if config.window_left is not None else 0 + window_right = config.window_right if config.window_right is not None else 0 + total_window = window_left + window_right + 1 # +1 for current 
position + sparsity_ratio = min(1.0, total_window / config.seqlen_k) + elif config.use_mask_mod: + if config.mask_mod_name in ["identity", "identity_partial"]: + sparsity_ratio = 1.0 + elif config.mask_mod_name in ["causal", "block_causal"]: + sparsity_ratio = 0.5 + elif config.mask_mod_name == "sliding_window": + # Use configured window size + sparsity_ratio = min(1.0, config.window_size / config.seqlen_k) + elif config.mask_mod_name == "block_diagonal": + block_size = 64 + num_blocks = (config.seqlen_k + block_size - 1) // block_size + sparsity_ratio = 1.0 / num_blocks if num_blocks > 1 else 1.0 + elif config.mask_mod_name == "document": + vals = tensors["buffers"][0] + val_mask = torch.ones_like(vals, dtype=torch.bool) + val_mask[..., 1:] = vals[..., 1:] != vals[..., :-1] + total = torch.where(val_mask, vals.square(), 0).sum() + sparsity_ratio = total / (config.seqlen_q * config.seqlen_k) + else: + sparsity_ratio = 1.0 + elif config.causal: + sparsity_ratio = 0.5 + else: + sparsity_ratio = 1.0 + + if config.use_varlen: + # Compute FLOPs per sequence and sum + total_flops = 0 + cu_q = tensors["cu_seqlens_q"] + cu_k = tensors["cu_seqlens_k"] + for i in range(config.batch_size): + seq_len_q = (cu_q[i + 1] - cu_q[i]).item() + seq_len_k = (cu_k[i + 1] - cu_k[i]).item() + + # Adjust sparsity for local attention in varlen case + if config.is_local: + window_left = config.window_left if config.window_left is not None else 0 + window_right = config.window_right if config.window_right is not None else 0 + total_window = window_left + window_right + 1 + seq_sparsity = min(1.0, total_window / seq_len_k) + elif config.use_mask_mod and config.mask_mod_name == "sliding_window": + seq_sparsity = min(1.0, config.window_size / seq_len_k) + else: + seq_sparsity = sparsity_ratio + + num_cells = int(seq_len_q * seq_len_k * seq_sparsity) + + if config.headdim == config.headdim_v: + flops_this_seq = 4 * config.nheads * num_cells * config.headdim + else: + flops_this_seq = ( + 2 * 
config.nheads * num_cells * config.headdim + + 2 * config.nheads * num_cells * config.headdim_v + ) + total_flops += flops_this_seq + return total_flops + else: + num_cells = int(config.seqlen_q * config.seqlen_k * sparsity_ratio) + if config.headdim == config.headdim_v: + flops_per_batch = 4 * config.nheads * num_cells * config.headdim + else: + flops_per_batch = ( + 2 * config.nheads * num_cells * config.headdim + + 2 * config.nheads * num_cells * config.headdim_v + ) + return flops_per_batch * config.batch_size + + def benchmark(self) -> Dict[str, Any]: + config = self.config + + tensors = self._create_tensors() + compiled_kernel, args = self._compile_kernel(tensors) + + # Warmup + for _ in range(config.warmup_iters): + compiled_kernel(*args) + torch.cuda.synchronize() + + # Benchmark + times = [] + for _ in range(config.benchmark_iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + compiled_kernel(*args) + end.record() + torch.cuda.synchronize() + + times.append(start.elapsed_time(end)) + + times_tensor = torch.tensor(times) + mean_time = times_tensor.mean().item() + std_time = times_tensor.std().item() if len(times) > 1 else 0.0 + + total_flops = self._calculate_flops(tensors) + tflops = total_flops / (mean_time * 1e-3) / 1e12 + + # Bandwidth calculation + bytes_per_element = config.dtype.itemsize + if config.use_varlen: + total_q = tensors["q"].shape[0] + total_k = tensors["k"].shape[0] + memory_accessed = ( + total_q * config.nheads * config.headdim * bytes_per_element + + total_k * config.nheads_kv * config.headdim * bytes_per_element + + total_k * config.nheads_kv * config.headdim_v * bytes_per_element + + total_q * config.nheads * config.headdim_v * bytes_per_element + ) + else: + memory_accessed = ( + config.batch_size + * config.seqlen_q + * config.nheads + * config.headdim + * bytes_per_element + + config.batch_size + * config.seqlen_k + * config.nheads_kv + * config.headdim + * 
bytes_per_element + + config.batch_size + * config.seqlen_k + * config.nheads_kv + * config.headdim_v + * bytes_per_element + + config.batch_size + * config.seqlen_q + * config.nheads + * config.headdim_v + * bytes_per_element + ) + bandwidth_gbps = memory_accessed / (mean_time * 1e-3) / 1e9 + + results = { + "mean_time_ms": mean_time, + "std_time_ms": std_time, + "tflops": tflops, + "bandwidth_gbps": bandwidth_gbps, + } + + if config.verbose: + self._print_results(results) + + return results + + def _print_results(self, results: Dict[str, Any]): + config = self.config + + # Basic configuration + if config.use_varlen: + print( + f"Shape: B={config.batch_size} (varlen), HD={config.headdim}, " + f"NH={config.nheads}, NKV={config.nheads_kv}" + ) + else: + print( + f"Shape: B={config.batch_size}, Q={config.seqlen_q}, K={config.seqlen_k}, " + f"HD={config.headdim}, NH={config.nheads}, NKV={config.nheads_kv}" + ) + + # Attention pattern + attn_info = [] + if config.causal: + attn_info.append("causal") + if config.is_local: + window_info = f"local(L={config.window_left},R={config.window_right})" + attn_info.append(window_info) + if config.use_mask_mod: + if config.mask_mod_name == "sliding_window": + attn_info.append(f"mask_mod={config.mask_mod_name}(w={config.window_size})") + else: + attn_info.append(f"mask_mod={config.mask_mod_name}") + if config.use_varlen: + attn_info.append("varlen") + if attn_info: + print(f"Attention: {', '.join(attn_info)}") + + # Performance metrics + print(f"Time: {results['mean_time_ms']:.3f} ± {results['std_time_ms']:.3f} ms") + print(f"Throughput: {results['tflops']:.2f} TFLOPS") + print(f"Bandwidth: {results['bandwidth_gbps']:.1f} GB/s") + + +if __name__ == "__main__": + B = 2 + config = BenchmarkConfig( + headdim=128, + headdim_v=128, + nheads=16, + nheads_kv=16, + dtype=torch.bfloat16, + batch_size=B, + # batch_size=1, + seqlen_q=16384 // B, + # seqlen_q=128, + seqlen_k=16384 // B, + # seqlen_k=192, + use_varlen=False, + 
use_mask_mod=True, + mask_mod_name="identity", + window_size=128, # Configurable window size for mask_mod + use_learnable_sink=False, + causal=False, + is_local=False, + verbose=True, + ) + + # Example 2: Base Flash Attention Local + # config = BenchmarkConfig( + # headdim=64, + # headdim_v=64, + # nheads=64, + # nheads_kv=8, + # dtype=torch.bfloat16, + # batch_size=2, + # seqlen_q=8192, + # seqlen_k=8192, + # use_varlen=False, + # use_mask_mod=False, + # causal=False, + # is_local=True, + # window_left=128, # Left window size for base local attention + # window_right=0, # Right window size for base local attention + # verbose=True, + # ) + + benchmark = FlashAttentionBenchmark(config) + results = benchmark.benchmark() diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py new file mode 100644 index 00000000000..ce05cae1438 --- /dev/null +++ b/flash_attn/cute/block_sparsity.py @@ -0,0 +1,372 @@ +""" +Computes block-sparse attention masks for Flex Attention. + +This utility generates block sparsity patterns based on common attention masking +strategies (e.g., causal, sliding window). The resulting tensors define which +blocks are fully computed, which are partially computed (requiring a mask), and +which are skipped entirely. This is a temporary solution intended to be replaced +by a more robust preprocessing kernel in the future. +""" + +from typing import Tuple, Optional, Callable, List +import torch + +# placeholder +Config = type("Config", (), {}) + +def compute_block_sparsity( + config: Config, + mask_mod_flex: Optional[Callable], + device: str, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + buffers: Optional[List[torch.Tensor]] = None, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Computes block sparsity tensors from a given masking function. 
+ + This function serves as the main entry point for generating block-sparse masks. + It dispatches to specialized handlers for variable-length and fixed-length + sequences. + + Args: + config: A configuration object containing model and tiling parameters. + mask_mod_flex: The mask function for generic flex attention patterns. + device: The device to create tensors on (e.g., 'cuda'). + cu_seqlens_q: Cumulative sequence lengths for Q (for varlen). + cu_seqlens_k: Cumulative sequence lengths for K (for varlen). + buffers: A list of auxiliary tensors, e.g., for document masking. + + Returns: + A tuple of four tensors: + - `full_block_cnt`: (batch, nheads, n_blocks_q) - Count of full n blocks per m block. + - `full_block_idx`: (batch, nheads, n_blocks_q, max_n_blocks) - Indices of full n blocks. + - `mask_block_cnt`: (batch, nheads, n_blocks_q) - Count of partial n blocks per m block. + - `mask_block_idx`: (batch, nheads, n_blocks_q, max_n_blocks) - Indices of partial n blocks. + Returns (None, None, None, None) if masking is disabled. 
+ """ + if not config.use_mask_mod or mask_mod_flex is None: + return None, None, None, None + + if cu_seqlens_q is not None: + # Handle variable-length sequences + return _compute_varlen_sparsity(config, mask_mod_flex, device, cu_seqlens_q, cu_seqlens_k) + else: + # Handle fixed-length sequences + return _compute_sparsity(config, device, buffers) + +## --------------------------------------------------------------------------- +## Fixed-Length Sequence Kernels +## --------------------------------------------------------------------------- + +def _compute_sparsity( + config: Config, device: str, buffers: Optional[List[torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes block sparsity for fixed-length sequences.""" + n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m + n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n + + # Pre-allocate output tensors + full_block_cnt = torch.zeros((config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32) + mask_block_cnt = torch.zeros((config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32) + full_block_idx = torch.zeros((config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32) + mask_block_idx = torch.zeros((config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32) + + # --- Identity Mask --- + # All blocks are fully computed. + if config.mask_mod_name == "identity": + k_blocks = torch.arange(n_blocks_k, device=device) + for q_block_idx in range(n_blocks_q): + full_block_cnt[:, :, q_block_idx] = n_blocks_k + full_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks + + # --- Identity Partial Mask --- + # All blocks are partially computed (masked). 
+ elif config.mask_mod_name == "identity_partial": + k_blocks = torch.arange(n_blocks_k, device=device) + for q_block_idx in range(n_blocks_q): + mask_block_cnt[:, :, q_block_idx] = n_blocks_k + mask_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks + + # --- Block Causal Mask --- + elif config.mask_mod_name == "block_causal": + k_blocks = torch.arange(n_blocks_k, device=device) + for q_block_idx in range(n_blocks_q): + causal_indices = k_blocks[k_blocks <= q_block_idx] + num_causal_indices = len(causal_indices) + if num_causal_indices > 0: + full_block_cnt[:, :, q_block_idx] = num_causal_indices + full_block_idx[:, :, q_block_idx, :num_causal_indices] = causal_indices + + # --- Causal and Sliding Window Masks --- + elif config.mask_mod_name in ["causal", "sliding_window"]: + q_block_indices = torch.arange(n_blocks_q, device=device) + k_block_indices = torch.arange(n_blocks_k, device=device) + + q_starts = q_block_indices * config.tile_m + q_ends = torch.minimum((q_block_indices + 1) * config.tile_m, torch.tensor(config.seqlen_q, device=device)) + k_starts = k_block_indices * config.tile_n + k_ends = torch.minimum((k_block_indices + 1) * config.tile_n, torch.tensor(config.seqlen_k, device=device)) + + # Expand dims for broadcasting: (n_blocks_q, 1) and (1, n_blocks_k) + q_starts, q_ends = q_starts.unsqueeze(1), q_ends.unsqueeze(1) + k_starts, k_ends = k_starts.unsqueeze(0), k_ends.unsqueeze(0) + + offset = config.seqlen_k - config.seqlen_q + + if config.mask_mod_name == "causal": + is_full = (k_ends - 1) <= (q_starts + offset) + # min(k_pos) <= max(q_pos) AND not is_full. + is_partial = (k_starts <= (q_ends - 1 + offset)) & ~is_full + + else: # sliding_window + window_size = getattr(config, 'window_size', 1024) + is_full = (k_ends - 1 <= q_starts + offset) & (k_starts >= q_ends - 1 + offset - (window_size - 1)) + # A block is EMPTY if no (q, k) pairs satisfy the constraint. 
+ is_empty = (k_starts > q_ends - 1 + offset) | (k_ends - 1 < q_starts + offset - (window_size - 1)) + # A block is PARTIAL if it's not empty and not full. + is_partial = ~is_empty & ~is_full + + # Populate indices based on the computed block classifications + for q_block_idx in range(n_blocks_q): + full_indices = k_block_indices[is_full[q_block_idx]] + if len(full_indices) > 0: + full_block_cnt[:, :, q_block_idx] = len(full_indices) + full_block_idx[:, :, q_block_idx, :len(full_indices)] = full_indices + + partial_indices = k_block_indices[is_partial[q_block_idx]] + if len(partial_indices) > 0: + mask_block_cnt[:, :, q_block_idx] = len(partial_indices) + mask_block_idx[:, :, q_block_idx, :len(partial_indices)] = partial_indices + + elif config.mask_mod_name == "document": + raise NotImplementedError("Block sparsity for document masking not yet implemented") + + return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + +## --------------------------------------------------------------------------- +## Variable-Length Sequence Kernels +## --------------------------------------------------------------------------- + +def _compute_varlen_sparsity( + config: Config, + mask_mod_flex: Callable, + device: str, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes block sparsity for variable-length sequences.""" + assert cu_seqlens_k is not None, "cu_seqlens_k is required for varlen attention" + assert cu_seqlens_q.shape[0] == config.batch_size + 1 + assert cu_seqlens_k.shape[0] == config.batch_size + 1 + + # In varlen, each sequence can have a different number of Q blocks. + # We pad up to the maximum number of Q blocks in the batch. 
+ max_m_blocks = 0 + for seq_idx in range(config.batch_size): + seq_len_q = (cu_seqlens_q[seq_idx + 1] - cu_seqlens_q[seq_idx]).item() + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + max_m_blocks = max(max_m_blocks, n_blocks_q) + + # The number of K blocks is determined by the total length of all sequences. + total_k_len = cu_seqlens_k[-1].item() + max_n_blocks = (total_k_len + config.tile_n - 1) // config.tile_n + + # Pre-allocate padded output tensors + full_block_cnt = torch.zeros((config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32) + mask_block_cnt = torch.zeros((config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32) + full_block_idx = torch.zeros((config.batch_size, config.nheads, max_m_blocks, max_n_blocks), device=device, dtype=torch.int32) + mask_block_idx = torch.zeros((config.batch_size, config.nheads, max_m_blocks, max_n_blocks), device=device, dtype=torch.int32) + + # Process each sequence in the batch individually + for seq_idx in range(config.batch_size): + seq_start_q = cu_seqlens_q[seq_idx].item() + seq_end_q = cu_seqlens_q[seq_idx + 1].item() + seq_len_q = seq_end_q - seq_start_q + + seq_start_k = cu_seqlens_k[seq_idx].item() + seq_end_k = cu_seqlens_k[seq_idx + 1].item() + seq_len_k = seq_end_k - seq_start_k + + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n + + # Global block indices are relative to the start of the entire batch tensor + first_m_block_global = seq_start_q // config.tile_m + first_n_block_global = seq_start_k // config.tile_n + + common_args = { + "full_block_cnt": full_block_cnt, "full_block_idx": full_block_idx, + "mask_block_cnt": mask_block_cnt, "mask_block_idx": mask_block_idx, + "seq_idx": seq_idx, "n_blocks_q": n_blocks_q, "n_blocks_k": n_blocks_k, + "seq_start_q": seq_start_q, "seq_end_q": seq_end_q, + "seq_start_k": seq_start_k, "seq_end_k": seq_end_k, + 
from typing import Callable  # BUGFIX: used in annotations below but missing from the file's imports


def _classify_varlen_block(
    m_local: int, n_local: int, seq_start_q: int, seq_end_q: int,
    seq_start_k: int, seq_end_k: int, tile_m: int, tile_n: int,
    is_full_fn: Callable, is_partial_fn: Callable,
) -> Tuple[bool, bool]:
    """Classify one (m, n) tile of a varlen sequence.

    Returns ``(is_full, is_partial)``:
      * full    -- every position in the tile is unmasked and the tile lies
                   entirely inside the sequence (kernel may skip masking);
      * partial -- the tile needs element-wise masking (the predicate is only
                   partially satisfied, or the tile overhangs the end of the
                   sequence);
      * ``(False, False)`` -- the tile is completely masked / skipped.

    ``is_full_fn`` / ``is_partial_fn`` receive sequence-local coordinates
    ``(m_start, m_end, n_start, n_end)`` with exclusive end bounds.
    """
    # Unclamped tile ends -- these may overhang the end of the sequence.
    m_end_raw = seq_start_q + (m_local + 1) * tile_m
    n_end_raw = seq_start_k + (n_local + 1) * tile_n

    m_start_global = seq_start_q + m_local * tile_m
    m_end_global = min(m_end_raw, seq_end_q)
    n_start_global = seq_start_k + n_local * tile_n
    n_end_global = min(n_end_raw, seq_end_k)

    # Use sequence-local coordinates for the logical check.
    m_start_local = m_start_global - seq_start_q
    m_end_local = m_end_global - seq_start_q
    n_start_local = n_start_global - seq_start_k
    n_end_local = n_end_global - seq_start_k

    is_full = is_full_fn(m_start_local, m_end_local, n_start_local, n_end_local)
    is_partial = (
        is_partial_fn(m_start_local, m_end_local, n_start_local, n_end_local)
        and not is_full
    )

    # Any block that touches the sequence boundary is partial because it
    # requires masking.  BUGFIX: compare the *unclamped* tile ends against the
    # sequence ends; the previous code compared the clamped ends, which by
    # construction can never exceed the sequence end, so the check was dead
    # (always False) and ragged-edge tiles were misreported as full.
    at_boundary = (m_end_raw > seq_end_q) or (n_end_raw > seq_end_k)

    return is_full and not at_boundary, is_partial or (is_full and at_boundary)


def _store_row_blocks(cnt_t, idx_t, seq_idx, m_local, blocks, device, head=slice(None)):
    """Write one row's block count and indices into the padded output tensors.

    :param head: ``slice(None)`` to broadcast the same row to every head
        (head-uniform masks), or an int head index (per-head sparsity).
    No-op when ``blocks`` is empty (tensors are pre-zeroed by the caller).
    """
    if blocks:
        cnt_t[seq_idx, head, m_local] = len(blocks)
        idx_t[seq_idx, head, m_local, :len(blocks)] = torch.tensor(
            blocks, device=device, dtype=torch.int32
        )


def _scan_varlen_blocks(
    full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx,
    seq_idx, n_blocks_q, n_blocks_k,
    seq_start_q, seq_end_q, seq_start_k, seq_end_k,
    first_n_block_global, tile_m, tile_n, device,
    is_full_fn, is_partial_fn,
):
    """Classify every tile of one sequence with the given predicates and store
    the per-row full/partial block lists (broadcast across all heads).

    Shared scan loop for the causal and sliding-window paths, which previously
    duplicated this code verbatim.
    """
    for m_local in range(n_blocks_q):
        full_blocks, partial_blocks = [], []
        for n_local in range(n_blocks_k):
            is_full, is_partial = _classify_varlen_block(
                m_local, n_local, seq_start_q, seq_end_q, seq_start_k, seq_end_k,
                tile_m, tile_n, is_full_fn, is_partial_fn,
            )
            # Block indices are global (relative to the whole packed batch).
            n_block_global = first_n_block_global + n_local
            if is_full:
                full_blocks.append(n_block_global)
            elif is_partial:
                partial_blocks.append(n_block_global)
        _store_row_blocks(full_block_cnt, full_block_idx, seq_idx, m_local, full_blocks, device)
        _store_row_blocks(mask_block_cnt, mask_block_idx, seq_idx, m_local, partial_blocks, device)


def _compute_causal_varlen_blocks(
    full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx,
    seq_idx, n_blocks_q, n_blocks_k,
    seq_start_q, seq_end_q, seq_start_k, seq_end_k,
    first_n_block_global, tile_m, tile_n, device, **kwargs
):
    """Computes causal block sparsity for a single varlen sequence."""
    # Full: even the first row of the tile sees every column (m_start >= n_end - 1).
    is_full_fn = lambda m_start, m_end, n_start, n_end: (m_start >= n_end - 1)
    # Partial: at least the last row sees the first column.
    is_partial_fn = lambda m_start, m_end, n_start, n_end: (m_end - 1 >= n_start)
    _scan_varlen_blocks(
        full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx,
        seq_idx, n_blocks_q, n_blocks_k,
        seq_start_q, seq_end_q, seq_start_k, seq_end_k,
        first_n_block_global, tile_m, tile_n, device,
        is_full_fn, is_partial_fn,
    )


def _compute_sliding_window_varlen_blocks(
    full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx,
    seq_idx, n_blocks_q, n_blocks_k,
    seq_start_q, seq_end_q, seq_start_k, seq_end_k,
    first_n_block_global, tile_m, tile_n, window_size, device, **kwargs
):
    """Computes sliding-window (causal, width ``window_size``) block sparsity
    for a single varlen sequence.

    A row ``m`` attends columns ``[m - window_size + 1, m]``.
    """
    # Full: every row in the tile sees the whole column range.  The causal
    # upper bound is tightest at the first row (n_end - 1 <= m_start); the
    # window lower bound is tightest at the LAST row.
    # BUGFIX: the lower bound previously used m_start, which let tiles whose
    # last rows exclude their leftmost columns be misclassified as full.
    is_full_fn = lambda m_start, m_end, n_start, n_end: (
        (n_end - 1 <= m_start) and (n_start >= (m_end - 1) - window_size + 1)
    )
    # Partial: the tile intersects the window of at least one row.
    is_partial_fn = lambda m_start, m_end, n_start, n_end: (
        not ((n_start > m_end - 1) or (n_end - 1 < m_start - window_size + 1))
    )
    _scan_varlen_blocks(
        full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx,
        seq_idx, n_blocks_q, n_blocks_k,
        seq_start_q, seq_end_q, seq_start_k, seq_end_k,
        first_n_block_global, tile_m, tile_n, device,
        is_full_fn, is_partial_fn,
    )


def _compute_identity_varlen_blocks(
    full_block_cnt, full_block_idx, seq_idx, n_blocks_q,
    n_blocks_k, first_n_block_global, device, **kwargs
):
    """Computes identity (all-attend) block sparsity for a single varlen sequence.

    NOTE(review): tiles that overhang the sequence end are still reported as
    full here (this path receives no mask tensors); presumably the kernel
    applies seqlen masking for them -- confirm against the consumer.
    """
    n_blocks_global = torch.arange(
        first_n_block_global, first_n_block_global + n_blocks_k,
        device=device, dtype=torch.int32
    )
    for m_local in range(n_blocks_q):
        full_block_cnt[seq_idx, :, m_local] = n_blocks_k
        full_block_idx[seq_idx, :, m_local, :n_blocks_k] = n_blocks_global


def _compute_generic_varlen_blocks(
    full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx,
    mask_mod_flex, seq_idx, num_heads, n_blocks_q, n_blocks_k,
    seq_len_q, seq_len_k, first_n_block_global,
    tile_m, tile_n, nheads_kv, device, **kwargs
):
    """Generic sampling-based block classification for a varlen sequence.

    NOTE: approximate -- only five points per tile (four corners + centre) are
    probed, so a mask that varies strictly inside a tile without touching a
    sample point can be misclassified.
    """
    # NOTE(review): sparsity is computed per Q head; the KV head mapping
    # (h_q // (num_heads // nheads_kv)) computed by the old code was unused.
    for h_q in range(num_heads):
        for m_local in range(n_blocks_q):
            m_start_local = m_local * tile_m
            m_end_local = min((m_local + 1) * tile_m, seq_len_q)

            full_blocks, partial_blocks = [], []
            for n_local in range(n_blocks_k):
                n_start_local = n_local * tile_n
                n_end_local = min((n_local + 1) * tile_n, seq_len_k)

                # Sample points within the block (corners and center) to classify it.
                # Coordinates are sequence-local, as required by mask_mod_flex.
                sample_positions = [
                    (m_start_local, n_start_local), (m_start_local, n_end_local - 1),
                    (m_end_local - 1, n_start_local), (m_end_local - 1, n_end_local - 1),
                    ((m_start_local + m_end_local) // 2, (n_start_local + n_end_local) // 2),
                ]

                unmasked_count = sum(
                    1 for q_pos, k_pos in sample_positions
                    if mask_mod_flex(seq_idx, h_q, q_pos, k_pos, seq_len_q, seq_len_k)
                )

                n_block_global = first_n_block_global + n_local
                if unmasked_count == len(sample_positions):  # all samples unmasked -> full
                    full_blocks.append(n_block_global)
                elif unmasked_count > 0:  # some unmasked -> partial
                    partial_blocks.append(n_block_global)

            _store_row_blocks(full_block_cnt, full_block_idx, seq_idx, m_local,
                              full_blocks, device, head=h_q)
            _store_row_blocks(mask_block_cnt, mask_block_idx, seq_idx, m_local,
                              partial_blocks, device, head=h_q)
utils_basic @@ -54,7 +54,8 @@ def __init__( num_stages: int = 1, num_threads: int = 128, Q_in_regs: bool = False, - score_mod: cutlass.Constexpr | None = None, + score_mod: Optional[cutlass.Constexpr] = None, + mask_mod: Optional[cutlass.Constexpr] = None, has_buffers: bool = False, ): """Initializes the configuration for a flash attention kernel. @@ -73,6 +74,8 @@ def __init__( :param is_causal: is causal :param score_mod: A callable that takes the attention scores and applies a modification. Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, buffers) -> Any`` + :param mask_mod: A callable that takes the attention scores and returns a boolean representing whether that score should be masked. + Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, buffers) -> Boolean`` """ self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -94,8 +97,9 @@ def __init__( self.num_stages = num_stages self.Q_in_regs = Q_in_regs self.score_mod = score_mod + self.mask_mod = mask_mod self.qk_acc_dtype = Float32 - if cutlass.const_expr(has_buffers): + if const_expr(has_buffers): self.vec_size: cutlass.Constexpr = 1 else: self.vec_size: cutlass.Constexpr = 2 @@ -601,7 +605,7 @@ def __call__( softmax_scale = Float32(softmax_scale) fastdiv_mods = None - if cutlass.const_expr(buffers is not None): + if const_expr(buffers is not None): seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) @@ -938,7 +942,7 @@ def load_V_next(): # hook_fn=load_V_next, A_in_regs=self.Q_in_regs, ) - if cutlass.const_expr(score_mod is not None): + if const_expr(score_mod is not None): self.apply_score_mod( mma_params.thr_mma_qk, batch_idx, @@ -984,10 +988,17 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase): arch = 90 - def __init__(self, *args, intra_wg_overlap: bool = True, mma_pv_is_rs: bool = True, **kwargs): + def 
__init__( + self, + *args, + intra_wg_overlap: bool = True, + mma_pv_is_rs: bool = True, + **kwargs, + ): super().__init__(*args, **kwargs) self.intra_wg_overlap = intra_wg_overlap self.mma_pv_is_rs = mma_pv_is_rs + def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( @@ -1107,19 +1118,26 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, - buffers=None, + full_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) + full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) + mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) + mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) + buffers: Optional[list[cute.Tensor]] = None, ): """Configures and launches the flash attention kernel. mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ + self._check_type( *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) ) + # Assume all strides are divisible by 128 bits except the last stride new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)] @@ -1146,6 +1164,7 @@ def __call__( ) # self.num_mma_regs = 232 # self.num_producer_regs = 40 + self.use_block_sparsity = const_expr(mask_block_cnt is not None and full_block_cnt is not None) self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) if const_expr(self.intra_wg_overlap) else (self.num_mma_warp_groups == 2) self.use_tma_Q = self.arch 
>= 90 and not (self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0) self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa @@ -1255,7 +1274,7 @@ def __call__( window_size_right = Int32(window_size_right) fastdiv_mods = None - if cutlass.const_expr(buffers is not None): + if const_expr(buffers is not None): seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) @@ -1281,6 +1300,10 @@ def __call__( window_size_left, window_size_right, learnable_sink, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, self.sQ_layout, self.sK_layout, self.sV_layout, @@ -1327,6 +1350,10 @@ def kernel( window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], + full_block_cnt: Optional[cute.Tensor], + full_block_idx: Optional[cute.Tensor], + mask_block_cnt: Optional[cute.Tensor], + mask_block_idx: Optional[cute.Tensor], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -1342,7 +1369,7 @@ def kernel( tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], SharedStorage: cutlass.Constexpr[Callable], - buffers=None, + buffers=Optional[list[cute.Tensor]], fastdiv_mods=None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -1436,6 +1463,10 @@ def kernel( pipeline_k, pipeline_v, mbar_ptr_Q, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -1474,6 +1505,10 @@ def kernel( SeqlenInfoCls, AttentionMaskCls, TileSchedulerCls, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, buffers, fastdiv_mods, ) @@ -1493,6 +1528,10 @@ def load( pipeline_k: cutlass.pipeline.PipelineAsync, pipeline_v: cutlass.pipeline.PipelineAsync, mbar_ptr_Q: cutlass.Pointer, + full_block_cnt: 
Optional[cute.Tensor], + full_block_idx: Optional[cute.Tensor], + mask_block_cnt: Optional[cute.Tensor], + mask_block_idx: Optional[cute.Tensor], block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -1527,44 +1566,175 @@ def load( load_V, _, _ = copy_utils.tma_get_copy_fn(tma_atom_V, 0, cute.make_layout(1), gV, sV) load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - # if cute.arch.thread_idx()[0] == 0: - # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) - # First iteration: load both Q & K with the same mbarrier - n_block = n_block_max - 1 - pipeline_k.producer_acquire( - kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 - ) - if const_expr(self.use_tma_Q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) - load_K(src_idx=n_block, producer_state=kv_producer_state) - if const_expr(not self.intra_wg_overlap): - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block, producer_state=kv_producer_state) - kv_producer_state.advance() - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block = n_block_max - 1 - i - 1 - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block, producer_state=kv_producer_state) + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) + # First iteration: load both Q & K with the same mbarrier + n_block = n_block_max - 1 + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + 
load_K(src_idx=n_block, producer_state=kv_producer_state) + + if const_expr(not self.intra_wg_overlap): pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block, producer_state=kv_producer_state) kv_producer_state.advance() - else: - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block_prev = n_block_max - i - 1 - n_block = n_block_prev - 1 - kv_producer_state_prev = kv_producer_state.clone() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 1 - i - 1 + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block_prev = n_block_max - i - 1 + n_block = n_block_prev - 1 + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) + n_block = n_block_min + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) kv_producer_state.advance() - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state_prev) - load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) - n_block = n_block_min - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block, producer_state=kv_producer_state) - kv_producer_state.advance() + else: + # ========================================== + # Flex Attention blocksparsity + # ========================================== + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] 
+ curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(not self.intra_wg_overlap): + if curr_mask_block_cnt > 0: + # First mask block - load with Q + n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1] + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block_mask, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_mask, producer_state=kv_producer_state) + kv_producer_state.advance() + + # Remaining mask blocks + for i in cutlass.range(1, curr_mask_block_cnt): + n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_mask, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_mask, producer_state=kv_producer_state) + kv_producer_state.advance() + + if curr_full_block_cnt > 0: + n_block_full = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + # must load Q if not loaded in mask loop + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block_full, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_full, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full, producer_state=kv_producer_state) + 
pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_full, producer_state=kv_producer_state) + kv_producer_state.advance() + for j in cutlass.range(1, curr_full_block_cnt): + n_block_full = curr_full_block_idx[curr_full_block_cnt - 1 - j] + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_full, producer_state=kv_producer_state) + kv_producer_state.advance() + + else: + # ========================================== + # Overlap path + # ========================================== + + # Load Q with the first K block (whether mask or full) + n_block_first = -1 + if curr_mask_block_cnt > 0: + n_block_first = curr_mask_block_idx[curr_mask_block_cnt - 1] + elif curr_full_block_cnt > 0: + n_block_first = curr_full_block_idx[curr_full_block_cnt - 1] + + if n_block_first >= 0: + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block_first, producer_state=kv_producer_state) + + if curr_mask_block_cnt > 0: + # Staggered loading for remaining mask blocks + for i in cutlass.range(1, curr_mask_block_cnt): + n_block_mask_prev = curr_mask_block_idx[curr_mask_block_cnt - i] + n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_mask, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_mask_prev, producer_state=kv_producer_state_prev) + + # Handle transition from mask to full blocks + if curr_full_block_cnt > 0: + # Load first full block K, last mask block V + n_block_mask_last = curr_mask_block_idx[0] 
+ n_block_full = curr_full_block_idx[curr_full_block_cnt - 1] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev) + else: + # No full blocks, just load last mask block V + n_block_mask_last = curr_mask_block_idx[0] + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state) + kv_producer_state.advance() + + if curr_full_block_cnt > 0: + # Staggered loading for remaining full blocks ( + for j in cutlass.range(1, curr_full_block_cnt): + n_block_full_prev = curr_full_block_idx[curr_full_block_cnt - j] + n_block_full = curr_full_block_idx[curr_full_block_cnt - 1 - j] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_full_prev, producer_state=kv_producer_state_prev) + + # Load last full block V + n_block_full_last = curr_full_block_idx[0] + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_full_last, producer_state=kv_producer_state) + kv_producer_state.advance() tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1601,7 +1771,11 @@ def mma( SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, - buffers=None, + full_block_cnt: Optional[cute.Tensor], + full_block_idx: Optional[cute.Tensor], + mask_block_cnt: Optional[cute.Tensor], + mask_block_idx: Optional[cute.Tensor], + buffers: Optional[list[cute.Tensor]], fastdiv_mods=None, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) @@ -1663,6 +1837,20 @@ def mma( 
tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) + + process_first_half_block = partial( + self.first_half_block_overlap, + mma_qk_fn=mma_qk_fn, + pipeline_k=pipeline_k, + tOrP=tOrP, + smem_copy_params=smem_copy_params, + softmax=softmax, + ) + process_last_half_block = partial( + self.last_half_block_overlap, + pipeline_v=pipeline_v, + mma_pv_fn=mma_pv_fn, + ) while work_tile.is_valid_tile: # if work_tile.is_valid_tile: @@ -1671,18 +1859,31 @@ def mma( seqlen = SeqlenInfoCls(batch_idx) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( - mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, - mask_causal=self.is_causal, mask_local=self.is_local, + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, + buffers=buffers, ) score_mod_fn = None if const_expr(self.score_mod is not None): score_mod_fn = partial( self.apply_score_mod, - thr_mma_qk, batch_idx, head_idx, m_block, - softmax_scale=softmax_scale, buffers=buffers, fastdiv_mods=fastdiv_mods, + thr_mma_qk=thr_mma_qk, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=m_block, + softmax_scale=softmax_scale, + buffers=buffers, + fastdiv_mods=fastdiv_mods, ) mma_one_n_block = partial( - mma_one_n_block_all, softmax=softmax, score_mod_fn=score_mod_fn + mma_one_n_block_all, + softmax=softmax, + score_mod_fn=score_mod_fn, ) # Load Q if not TMA_Q if const_expr(not self.use_tma_Q): @@ -1705,87 +1906,226 @@ def mma( # We also need masking on S if it's causal, for the last several blocks. 
# softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True O_should_accumulate = False - # First iteration with seqlen masking - if const_expr(self.intra_wg_overlap): - pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) - acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) - pipeline_k.consumer_release(kv_consumer_state) - # Use vectorized score modification - if cutlass.const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block_max - 1) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - softmax.online_softmax(acc_S, is_first=True) - tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tOrP_cur = tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - tOrP_cur.store(tOrP_acc.load().to(self.dtype)) - if const_expr(not self.mma_pv_is_rs): - tPrP = smem_thr_copy_P.retile(tOrP_cur) - cute.copy(smem_thr_copy_P, tPrP, tPsP) - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV - # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - # acc_O.fill(0.0) - else: - self.warp_scheduler_barrier_sync() - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1, - mma_pv_fn=partial(mma_pv_fn, zero_init=True), - is_first_n_block=True, - mask_fn=partial(mask_fn, mask_seqlen=True), - ) - O_should_accumulate = True - # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) - 
n_block_max -= 1 - # Next couple of iterations with causal masking - if const_expr(self.is_causal or self.is_local): - n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( - seqlen, m_block, n_block_min - ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) - for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + + + # ========================================== + # MAINLOOP + # ========================================== + if const_expr(not self.use_block_sparsity): + # ========================================== + # No block-sparsity (original path) + # ========================================== + # First iteration with seqlen masking + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_first_half_block( + n_block=n_block_max - 1, + kv_consumer_state=kv_consumer_state, + mask_fn=mask_fn, + is_first_block=True, + ) + # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + # acc_O.fill(0.0) + else: + self.warp_scheduler_barrier_sync() kv_consumer_state = mma_one_n_block( kv_consumer_state, - n_block=n_block_max - 1 - n_tile, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_seqlen=False), + n_block=n_block_max - 1, + mma_pv_fn=partial(mma_pv_fn, zero_init=True), + is_first_n_block=True, + mask_fn=partial(mask_fn, mask_seqlen=True), ) O_should_accumulate = True - n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) - # The remaining iterations have no masking - n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( - seqlen, m_block, n_block_min - ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) - for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): - 
kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1 - n_tile, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + O_should_accumulate = True + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min ) - O_should_accumulate = True - # Separate iterations with local masking on the left - if const_expr(self.is_local and block_info.window_size_left is not None): - n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) - for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=n_block_max - 1 - n_tile, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_seqlen=False), ) O_should_accumulate = True - # Last 
"half" iteration - if const_expr(self.intra_wg_overlap): - pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) - mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=not O_should_accumulate, wg_wait=0) - pipeline_v.consumer_release(kv_consumer_state) - kv_consumer_state.advance() + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + O_should_accumulate = True + # Last "half" iteration + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + else: + self.warp_scheduler_barrier_arrive() + else: - self.warp_scheduler_barrier_arrive() + # ========================================== + # Block sparsity + # ========================================== + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + + # first masked and full blocks + mask_n_block = 0 + full_n_block = 0 + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + + if const_expr(not self.intra_wg_overlap): + # ========================================== + # Non-overlap path + # ========================================== + if curr_mask_block_cnt > 0: 
+ self.warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), + is_first_n_block=True, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + is_first_n_block=False, + ) + if curr_full_block_cnt == 0: + self.warp_scheduler_barrier_arrive() + + if curr_full_block_cnt > 0: + if curr_mask_block_cnt == 0: + self.warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=True), + is_first_n_block=True, + ) + O_should_accumulate = True + else: + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=True), + is_first_n_block=False, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), + is_first_n_block=False, + ) + self.warp_scheduler_barrier_arrive() + else: + # ========================================== + # Overlap path + # ========================================== + + # Process first block + if curr_mask_block_cnt > 0: + kv_consumer_state = process_first_half_block( + 
n_block=mask_n_block, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod), + is_first_block=True, + ) + + # Process remaining mask blocks + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + + # Process full blocks + if curr_full_block_cnt > 0: + # If no mask blocks, first full block is the overall first + if curr_mask_block_cnt == 0: + kv_consumer_state = process_first_half_block( + n_block=full_n_block, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=None), + is_first_block=True, + ) + + else: + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), + ) + O_should_accumulate = True + + # Process remaining full blocks + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), + ) + O_should_accumulate = True + + # Final PV gemm for last block + if curr_mask_block_cnt > 0 or curr_full_block_cnt > 0: + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + + if curr_mask_block_cnt + curr_full_block_cnt == 0: + softmax.reset() + acc_O.fill(0.0) + sink_val = None if const_expr(learnable_sink is not None): @@ -1815,6 +2155,74 @@ def mma( tile_scheduler.advance_to_next_work() 
work_tile = tile_scheduler.get_current_work() + @cute.jit + def first_half_block_overlap( + self, + n_block: Int32, + mma_qk_fn: Callable, + kv_consumer_state, + pipeline_k, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + mask_fn: Callable = None, + score_mod_fn: Optional[Callable] = None, + is_first_block: bool = False, + ): + """Processes the first half block when using intra-warpgroup-overlap""" + + pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) + acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) + pipeline_k.consumer_release(kv_consumer_state) + + # Apply score modification if present + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S=acc_S, n_block=n_block) + + # Apply mask; mask_seqlen always True for first block + # Caveat: if full block further right than mask block, seqlen masking is redundant; + # however, masking is being applied anyway, so essentially no perf hit + mask_fn(acc_S, n_block=n_block, mask_seqlen=True) + + softmax.online_softmax(acc_S, is_first=is_first_block) + + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + + # if pv gemm not rs + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + # Fence and barrier to make smem store visible to WGMMA + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.sync_warp() + + return kv_consumer_state + + @cute.jit + def last_half_block_overlap( + self, + kv_consumer_state, + pipeline_v, + mma_pv_fn: Callable, + zero_init: bool, + ): + """Processes the final PV GEMM when using intra-warpgroup-overlap""" + + 
pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state))
+        mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0)
+        pipeline_v.consumer_release(kv_consumer_state)
+
+        # Advance state for next iteration
+        kv_consumer_state.advance()
+
+        return kv_consumer_state
+
     @cute.jit
     def mma_one_n_block(
         self,
@@ -1840,10 +2248,13 @@ def mma_one_n_block(
         self.warp_scheduler_barrier_arrive()
         warpgroup.wait_group(0)
         pipeline_k.consumer_release(smem_pipe_read)
+
+        # handle score mods and masking
         if const_expr(score_mod_fn is not None):
             score_mod_fn(acc_S, n_block=n_block)
         if const_expr(mask_fn is not None):
             mask_fn(acc_S, n_block=n_block)
+
         row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
         # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
         tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
@@ -1899,12 +2310,14 @@ def mma_one_n_block_intrawg_overlap(
         self.warp_scheduler_barrier_arrive()
         warpgroup.wait_group(1)
         pipeline_k.consumer_release(smem_pipe_read)
+
+        # handle score mods and masking
         if const_expr(score_mod_fn is not None):
             score_mod_fn(acc_S, n_block=n_block)
-        # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
-        if const_expr(mask_fn is not None):
+        if const_expr(mask_fn is not None):
             mask_fn(acc_S, n_block=n_block)
             # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
+
         row_scale = softmax.online_softmax(acc_S, check_inf=check_inf)
         warpgroup.wait_group(0)
         pipeline_v.consumer_release(smem_pipe_read_v)
@@ -1945,7 +2358,7 @@ def apply_score_mod(
         acc_S,
         n_block,
         softmax_scale,
-        buffers=None,
+        buffers: Optional[list[cute.Tensor]] = None,
         fastdiv_mods=None,
     ):
         # Prepare index tensor
diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index 07a6c48bfbf..0615061a541 100644
--- a/flash_attn/cute/interface.py
+++
b/flash_attn/cute/interface.py
@@ -1,5 +1,6 @@
 # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0.
+# [2025] Added Flex-Attention-style mask_mod and block-sparsity support.
 # Supported features:
 # - BF16 & FP16 dtype
@@ -73,7 +74,12 @@ def _flash_attn_fwd(
     num_threads: int = 384,
     pack_gqa: Optional[bool] = None,
     _compute_capability: Optional[int] = None,
-    score_mod: Callable | None = None,
+    score_mod: Optional[Callable] = None,
+    mask_mod: Optional[Callable] = None,
+    full_block_cnt: Optional[torch.Tensor] = None,
+    full_block_idx: Optional[torch.Tensor] = None,
+    mask_block_cnt: Optional[torch.Tensor] = None,
+    mask_block_idx: Optional[torch.Tensor] = None,
     return_lse: bool = False,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
@@ -135,7 +141,22 @@ def _flash_attn_fwd(
     if learnable_sink is not None:
         assert learnable_sink.shape == (num_head,)
         assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
-    assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink)), "inputs must be on CUDA device"
+    for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]:
+        if t is not None:
+            assert t.dtype == torch.int32, "blocksparse mask tensors must be int32"
+            assert t.is_contiguous(), "blocksparse mask tensors must be contiguous"
+    assert all(
+        t is None or t.is_cuda
+        for t in (
+            q, k, v,
+            cu_seqlens_q, cu_seqlens_k,
+            seqused_q, seqused_k,
+            page_table,
+            learnable_sink,
+            full_block_cnt, full_block_idx,
+            mask_block_cnt, mask_block_idx,
+        )
+    ), "inputs must be on CUDA device"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
     assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 //
q.element_size() @@ -183,6 +204,13 @@ def _flash_attn_fwd( for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] page_table_tensor = from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) if page_table is not None else None + + full_block_cnt_tensor = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if full_block_cnt is not None else None + full_block_idx_tensor = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if full_block_idx is not None else None + mask_block_cnt_tensor = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if mask_block_cnt is not None else None + mask_block_idx_tensor = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if mask_block_idx is not None else None + + if causal: window_size_right = 0 local = window_size_left is not None or window_size_right is not None @@ -202,22 +230,44 @@ def _flash_attn_fwd( # TODO: fix the varlen case if pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None): pack_gqa = False - + + # hash score and mask mods for compile cache + score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else None + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else None + if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) + is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None + use_block_sparsity = full_block_cnt is not None or mask_block_cnt is not None if score_mod is not None: - is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None if is_varlen: raise NotImplementedError("score_mod with buffers is not yet supported for 
varlen sequences. This will be fixed in a future PR.") + if pack_gqa: + raise NotImplementedError("score_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.") + if mask_mod is not None: + if not use_block_sparsity: + raise NotImplementedError("mask_mod requires the use of block sparsity. This will be fixed in a future PR.") + if is_varlen: + raise NotImplementedError("mask_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.") + if pack_gqa: + raise NotImplementedError("mask_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.") + + if use_block_sparsity: + if is_varlen: + raise NotImplementedError("Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR.") + if pack_gqa: + raise NotImplementedError("Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR.") + cute_buffers = None if buffers is not None: cute_buffers = [from_dlpack(buf) for buf in buffers] compile_key = ( - dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, utils.hash_callable(score_mod) if score_mod is not None else None, + dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, + score_mod_hash, mask_mod_hash, buffers is not None, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, page_table is not None, @@ -245,6 +295,9 @@ def _flash_attn_fwd( num_stages=2, num_threads=num_threads, Q_in_regs=False, + intra_wg_overlap=True, + mma_pv_is_rs=True, + mask_mod=mask_mod, score_mod=score_mod, has_buffers=buffers is not None, ) @@ -264,18 +317,21 @@ def _flash_attn_fwd( else: raise ValueError(f"Unsupported compute capability: {compute_capability}. 
Supported: 9.x, 10.x") # TODO: check @can_implement - # TODO caching for buffers; cute_buffers _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, - window_size_left, window_size_right, learnable_sink_tensor, cute_buffers, + window_size_left, window_size_right, learnable_sink_tensor, + full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor, + cute_buffers, ) _flash_attn_fwd.compile_cache[compile_key]( q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, - window_size_left, window_size_right, learnable_sink_tensor, cute_buffers + window_size_left, window_size_right, learnable_sink_tensor, + full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor, + cute_buffers, ) return out, lse @@ -591,6 +647,11 @@ def forward( learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, pack_gqa: Optional[bool] = None, + mask_mod: Optional[Callable] = None, + full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, ): out, lse = _flash_attn_fwd( q, @@ -603,6 +664,11 @@ def forward( learnable_sink=learnable_sink, softcap=softcap, pack_gqa=pack_gqa, + mask_mod=mask_mod, + full_block_cnt=full_block_cnt, + full_block_idx=full_block_idx, + mask_block_cnt=mask_block_cnt, + mask_block_idx=mask_block_idx, ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale @@ -706,6 +772,11 @@ def flash_attn_func( learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, pack_gqa: Optional[bool] = None, + mask_mod: Optional[Callable] = None, 
+ full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, ): return FlashAttnFunc.apply( q, @@ -717,6 +788,11 @@ def flash_attn_func( learnable_sink, softcap, pack_gqa, + mask_mod, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, ) @@ -973,4 +1049,4 @@ def flash_attn_combine( lse = None _flash_attn_fwd_combine(out_partial, lse_partial, out, lse) - return out, lse + return out, lse \ No newline at end of file diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 25c69a69bc0..0d78eb9e948 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -1,6 +1,6 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional +from typing import Optional, Callable from dataclasses import dataclass import cutlass @@ -9,7 +9,6 @@ import flash_attn.cute.utils as utils - @cute.jit def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None: # Bit manipulation, compiles down to the R2P instruction @@ -39,7 +38,6 @@ def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = Fal for r in cutlass.range_constexpr(cute.size(X.shape[0])): X[r, c] = X[r, c] if in_bound else -Float32.inf - @dataclass(frozen=True) class AttentionMask: tile_m: cutlass.Constexpr[int] @@ -55,12 +53,16 @@ class AttentionMask: def apply_mask( self, acc_S: cute.Tensor, - m_block: Int32, - n_block: Int32, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + n_block: cutlass.Int32, thr_mma: cute.TiledMma, mask_seqlen: cutlass.Constexpr[bool], mask_causal: cutlass.Constexpr[bool], mask_local: cutlass.Constexpr[bool] = False, + mask_mod: cutlass.Constexpr[Optional[Callable]] = None, + buffers: Optional[list[cute.Tensor]] = None, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_S_mn = 
utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB) @@ -76,17 +78,55 @@ def apply_mask( COL = 1 if const_expr(not self.swap_AB) else 0 thr_col_offset = tScS_mn[0][COL] seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset - if const_expr(not mask_causal and not mask_local): + if const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): # The compiler now choses not to use R2P r2p = const_expr(False and not self.swap_AB) if const_expr(not r2p): + # traverse column index. for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c] else: mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90) + + elif const_expr(not mask_causal and not mask_local and mask_mod is not None): # FlexAttention mask mod + nrow = const_expr(cute.size(tScS_mn.shape[0])) + ncol = const_expr(cute.size(tScS_mn.shape[1])) + thr_col_offset = tScS_mn[0, 0][1] + + for r in cutlass.range_constexpr(nrow): + global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + + for col in cutlass.range_constexpr(ncol): + col_idx_local = t0ScS_mn[0, col][1] + # Convert to absolute column index + global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n + + cond = cutlass.Boolean( + mask_mod( + batch_idx, + head_idx, + tScS_mn[r, 0][0] + m_block * self.tile_m, + thr_col_offset + t0ScS_mn[0, col][1] + n_block * self.tile_n, + self.seqlen_q, + self.seqlen_k, + buffers, + ) + ) + if const_expr(mask_seqlen): + out_of_bounds = (global_row_idx >= self.seqlen_q) or ( + global_col_idx >= self.seqlen_k + ) + if out_of_bounds: + acc_S_mn[r, col] = -cutlass.Float32.inf + else: + acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf + else: + acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf + + else: # Causal or local if 
const_expr(not self.swap_AB): # If PackGQA, we split the work of compute divmod among threads in the same row @@ -303,9 +343,9 @@ def apply_mask_sm100_transposed( tidx = cute.arch.thread_idx()[0] % 128 seqlenk_row_limit = self.seqlen_k - n_block * self.tile_n - if cutlass.const_expr(not mask_causal and not mask_local): - if cutlass.const_expr(mask_seqlen): - ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + if const_expr(not mask_causal and not mask_local): + if const_expr(mask_seqlen): + ncol = const_expr(cute.size(tScS_t2r.shape)) if tScS_t2r[0][0] >= seqlenk_row_limit: for i in cutlass.range(ncol, unroll_full=True): acc_S[i] = -cutlass.Float32.inf @@ -313,12 +353,12 @@ def apply_mask_sm100_transposed( causal_row_offset = (self.seqlen_q - self.seqlen_k - 1) - m_block * self.tile_m row_idx = tScS_t2r[0][0] + n_block * self.tile_n - if cutlass.const_expr(mask_causal): + if const_expr(mask_causal): col_limit_left = row_idx + causal_row_offset - ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + ncol = const_expr(cute.size(tScS_t2r.shape)) # if tidx == 32 and wg_idx == 1: # cute.printf("row idx = {}, causal_row_offset = {}, col_limit_left = {}, first column = {}, last column = {} ", row_idx, causal_row_offset, col_limit_left, tScS_t2r[0][1], tScS_t2r[ncol - 1][1]) - if cutlass.const_expr(mask_seqlen): + if const_expr(mask_seqlen): if tScS_t2r[0][0] >= seqlenk_row_limit: col_limit_left = self.tile_m for i in cutlass.range(ncol, unroll_full=True): diff --git a/flash_attn/cute/mask_definitions.py b/flash_attn/cute/mask_definitions.py new file mode 100644 index 00000000000..6b206fd6026 --- /dev/null +++ b/flash_attn/cute/mask_definitions.py @@ -0,0 +1,220 @@ +from typing import Callable, Optional + +import random +import math + +import cutlass +import cutlass.cute as cute +import torch + + +MaskModCallable = Optional[ + Callable[ + ["cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32"], + "cutlass.Boolean", + ] +] 
+ + +# Flex Attention mask functions (PyTorch signatures for reference implementation) + + +def flex_identity_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + if torch.is_tensor(q_idx): + return torch.ones_like(q_idx, dtype=torch.bool) + return True + + +def flex_identity_partial_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + if torch.is_tensor(q_idx): + return torch.ones_like(q_idx, dtype=torch.bool) + return True + + +def flex_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + # Right-aligned causal masking + if seqlen_q is not None and seqlen_k is not None: + offset = seqlen_k - seqlen_q + return kv_idx <= q_idx + offset + return kv_idx <= q_idx + + +def flex_block_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + # Right-aligned causal masking + if seqlen_q is not None and seqlen_k is not None: + offset = seqlen_k - seqlen_q + return kv_idx <= q_idx + offset + return kv_idx <= q_idx + + +def create_flex_sliding_window_mask(window_size=1024): + """Factory function to create a sliding window mask with configurable window size""" + def flex_sliding_window_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + # Sliding window: q_idx - window_size <= kv_idx <= q_idx + if seqlen_q is not None and seqlen_k is not None: + offset = seqlen_k - seqlen_q + return (kv_idx <= q_idx + offset) & (kv_idx >= q_idx + offset - window_size) + return (kv_idx <= q_idx) & (kv_idx >= q_idx - window_size) + return flex_sliding_window_mask + + +# Default sliding window mask with window_size=1024 for backward compatibility +def flex_sliding_window_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + window_size = 1024 + if seqlen_q is not None and seqlen_k is not None: + offset = seqlen_k - seqlen_q + # Sliding window: q_pos - window_size < kv_pos <= q_pos + # Note: using strict inequality on the left to match typical sliding window behavior + return (kv_idx <= q_idx + offset) & (kv_idx > q_idx + offset - window_size) + 
return (kv_idx <= q_idx) & (kv_idx > q_idx - window_size) + + +def flex_block_diagonal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None, block_size=64): + return (q_idx // block_size) == (kv_idx // block_size) + + +def flex_mini_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + return (q_idx % 128) >= (kv_idx % 128) + + +def flex_half_identity_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + """Even k-blocks are full blocks, odd k-blocks are masked blocks (both return True)""" + if torch.is_tensor(kv_idx): + return torch.ones_like(kv_idx, dtype=torch.bool) + return True + +def flex_document_mask(b, h, q_idx, kv_idx, doc_id: torch.Tensor): + return doc_id[b, h, q_idx] == doc_id[b, h, kv_idx] + +# CuTe versions for kernel compilation + + +@cute.jit +def cute_identity_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, +) -> cutlass.Boolean: + return cutlass.Boolean(True) + + +@cute.jit +def cute_identity_partial_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, +) -> cutlass.Boolean: + return cutlass.Boolean(True) + + +@cute.jit +def cute_causal_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, +) -> cutlass.Boolean: + # Right-aligned causal masking + offset = seqlen_k - seqlen_q + return cutlass.Boolean(n_idx <= m_idx + offset) + + +@cute.jit +def cute_block_causal_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, +) -> cutlass.Boolean: + # Right-aligned causal masking + offset = seqlen_k - seqlen_q + return cutlass.Boolean(n_idx <= m_idx + offset) + + +def 
create_cute_sliding_window_mask(window_size=1024):
+    """Factory function to create a CuTe sliding window mask with configurable window size"""
+    @cute.jit
+    def cute_sliding_window_mask(
+        batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32,
+        seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers
+    ) -> cutlass.Boolean:
+        offset = seqlen_k - seqlen_q
+
+        return cutlass.Boolean((n_idx <= m_idx + offset) and (n_idx >= m_idx + offset - window_size))
+    return cute_sliding_window_mask
+
+
+# Default sliding window mask with window_size=1024 for backward compatibility
+@cute.jit
+def cute_sliding_window_mask(
+    batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32,
+    seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers
+) -> cutlass.Boolean:
+    window_size = 1024
+    # Right-aligned offset, matching create_cute_sliding_window_mask and the flex reference
+    offset = seqlen_k - seqlen_q
+    return cutlass.Boolean((n_idx <= m_idx + offset) and (n_idx >= m_idx + offset - window_size))
+
+
+@cute.jit
+def cute_document_mask(
+    batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: list,
+):
+    doc_id = buffers[0]
+    return cutlass.Boolean(doc_id[batch, head, m_idx] == doc_id[batch, head, n_idx])
+
+
+@cute.jit
+def cute_block_diagonal_mask(
+    batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32,
+    seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers
+) -> cutlass.Boolean:
+    return cutlass.Boolean((m_idx // 64) == (n_idx // 64))
+
+
+@cute.jit
+def cute_mini_causal_mask(
+    batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32,
+    seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers
+) -> cutlass.Boolean:
+    """Each tile is locally causal-masked"""
+    m_mod = m_idx % 128
+    n_mod = n_idx % 128
+    return cutlass.Boolean(m_mod >= n_mod)
+
+
+@cute.jit
+def cute_half_identity_mask(
+    batch: cutlass.Int32, head: cutlass.Int32,
m_idx: cutlass.Int32, n_idx: cutlass.Int32,
+    seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers=None
+) -> cutlass.Boolean:
+    return cutlass.Boolean(True)
+
+
+def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"):
+    doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device)
+    for b in range(batch):
+        for h in range(nheads):
+            N = seqlen_q
+            n = random.randint(1, max(1, math.ceil(math.sqrt(N // 4))))
+            cuts = sorted(random.sample(range(1, N), n-1))
+            lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))]
+
+            doc_ids = []
+            for i, length in enumerate(lengths):
+                doc_ids += [i for _ in range(length)]
+
+            doc_ids_tensor[b, h, :] = torch.tensor(doc_ids, dtype=torch.int32, device=device)
+    print(f"{doc_ids_tensor.shape = }")
+    return doc_ids_tensor
+
+
+MASK_FUNCTIONS = {
+    "identity": (cute_identity_mask, flex_identity_mask),
+    "identity_partial": (cute_identity_partial_mask, flex_identity_partial_mask),
+    "causal": (cute_causal_mask, flex_causal_mask),
+    "block_causal": (cute_block_causal_mask, flex_block_causal_mask),
+    "sliding_window": (cute_sliding_window_mask, flex_sliding_window_mask),
+    "block_diagonal": (cute_block_diagonal_mask, flex_block_diagonal_mask),
+    "mini_causal": (cute_mini_causal_mask, flex_mini_causal_mask),
+    "half_identity": (cute_half_identity_mask, flex_half_identity_mask),
+    "document": (cute_document_mask, flex_document_mask),
+}
+
+if __name__ == "__main__":
+    doc_ids = random_doc_id_tensor(1, 2, 128)
+    print(f"{doc_ids = }")
\ No newline at end of file
diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py
index a654e90d23e..644936d8d2d 100644
--- a/tests/cute/test_flash_attn.py
+++ b/tests/cute/test_flash_attn.py
@@ -52,6 +52,8 @@
     "seqlen_q,seqlen_k",
     [
         (1, 1),
+        (3, 3),
+        (64, 32),
         (64, 128),
         (128, 192),
         (256, 256),
@@ -82,6 +84,8 @@ def test_flash_attn_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
     batch_size = 9 if
seqlen_k <= 2048 else 2 # batch_size = 1 nheads = 6 @@ -256,8 +260,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @@ -268,8 +272,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) -@pytest.mark.parametrize("d", [128, 192]) -# @pytest.mark.parametrize("d", [192]) +# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -1040,4 +1044,4 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype): # Test with LSE not returned out_no_lse, lse_no_lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=False) assert lse_no_lse is None, "LSE should be None when return_lse=False" - assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), "Output should be the same regardless of return_lse" + assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), "Output should be the same regardless of return_lse" \ No newline at end of file diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py new file mode 100644 index 00000000000..3e6707b5fb9 --- /dev/null +++ b/tests/cute/test_mask_mod.py @@ -0,0 +1,570 @@ +# mask mod test script + +import math + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import pytest +import torch +from torch.nn.attention.flex_attention import 
create_block_mask, flex_attention
import torch.nn.functional as F

from flash_attn.cute.block_sparsity import compute_block_sparsity
from flash_attn.cute.flash_fwd import (
    FlashAttentionForwardSm80,
    FlashAttentionForwardSm90,
)
from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100
from flash_attn.cute.mask_definitions import MASK_FUNCTIONS, flex_causal_mask, create_flex_sliding_window_mask, create_cute_sliding_window_mask
from flash_attn.cute.testing import attention_ref


def create_tensors(
    batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype
):
    """Allocate random Q/K/V inputs plus empty output/LSE buffers on the GPU.

    Returns a dict with keys "q", "k", "v", "out", "lse".  Layouts as
    allocated here: q/out are (batch, seqlen_q, nheads, headdim[/headdim_v]),
    k/v are (batch, seqlen_k, nheads_kv, headdim[/headdim_v]), and lse is
    (batch, nheads, seqlen_q) in float32.  Every tensor is made contiguous
    so it can be handed to the CuTe kernel through DLPack.
    """
    device = "cuda"
    q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype)
    k = torch.randn(
        batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype
    )
    out = torch.empty(
        batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype
    )
    # Log-sum-exp buffer is always float32, independent of the input dtype.
    lse = torch.empty(batch_size, nheads, seqlen_q, device=device, dtype=torch.float32)

    return {
        "q": q.contiguous(),
        "k": k.contiguous(),
        "v": v.contiguous(),
        "out": out.contiguous(),
        "lse": lse.contiguous(),
    }


def compile_and_run_kernel(
    tensors,
    mask_mod_cute,
    causal,
    is_local,
    window_left,
    window_right,
    tile_m,
    tile_n,
    full_block_cnt=None,
    full_block_idx=None,
    mask_block_cnt=None,
    mask_block_idx=None,
):
    """Instantiate, JIT-compile, and launch the CuTe FlashAttention forward kernel.

    Writes the attention result into tensors["out"] (and LSE into
    tensors["lse"]) and returns tensors["out"].  The kernel class is chosen
    from the device compute capability: SM100 for >= (10, 0), SM90 for
    >= (9, 0), otherwise SM80.  The four optional block-sparsity tensors
    (full/mask block counts and indices) are the output of
    compute_block_sparsity and are only passed when a mask mod is active.
    """
    dtype_map = {
        torch.float16: cutlass.Float16,
        torch.bfloat16: cutlass.BFloat16,
        torch.float32: cutlass.Float32,
    }
    cute_dtype = dtype_map[tensors["q"].dtype]

    batch_size, seqlen_q, nheads, headdim = tensors["q"].shape
    _, seqlen_k, nheads_kv, _ = tensors["k"].shape
    headdim_v = tensors["v"].shape[-1]

    # Dispatch on GPU architecture; tuple comparison handles minor versions.
    compute_capability = torch.cuda.get_device_capability()
    if compute_capability >= (10, 0):
        kernel_class = FlashAttentionForwardSm100
    elif compute_capability >= (9, 0):
        kernel_class = FlashAttentionForwardSm90
    else:
        kernel_class = FlashAttentionForwardSm80

    # GQA/MQA ratio; callers guarantee nheads is a multiple of nheads_kv.
    qhead_per_kvhead = nheads // nheads_kv
    # NOTE(review): has_buffers is hard-coded False here, so mask mods that
    # need auxiliary buffers (e.g. document masking) cannot go through this
    # helper — confirm against the kernel's constructor contract.
    kernel = kernel_class(
        cute_dtype,
        headdim,
        headdim_v,
        qhead_per_kvhead,
        is_causal=causal,
        is_local=is_local,
        pack_gqa=False,
        tile_m=tile_m,
        tile_n=tile_n,
        num_stages=2,
        num_threads=384,
        intra_wg_overlap=True,
        mma_pv_is_rs=True,
        mask_mod=mask_mod_cute,
        has_buffers=False,
        Q_in_regs=False,
    )

    # Standard 1/sqrt(d) attention scaling.
    softmax_scale = 1.0 / math.sqrt(headdim)
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Wrap the torch tensors as CuTe tensors via DLPack.  assumed_align
    # matches the allocation alignment (16B for data, 4B for fp32 LSE);
    # the last (head) dimension is marked as the dynamic leading dim.
    q_cute = from_dlpack(tensors["q"].detach(), assumed_align=16).mark_layout_dynamic(
        leading_dim=tensors["q"].ndim - 1
    )
    k_cute = from_dlpack(tensors["k"].detach(), assumed_align=16).mark_layout_dynamic(
        leading_dim=tensors["k"].ndim - 1
    )
    v_cute = from_dlpack(tensors["v"].detach(), assumed_align=16).mark_layout_dynamic(
        leading_dim=tensors["v"].ndim - 1
    )
    out_cute = from_dlpack(
        tensors["out"].detach(), assumed_align=16
    ).mark_layout_dynamic(leading_dim=tensors["out"].ndim - 1)
    lse_cute = from_dlpack(
        tensors["lse"].detach(), assumed_align=4
    ).mark_layout_dynamic(leading_dim=tensors["lse"].ndim - 1)

    # Block-sparsity metadata is optional; pass None through when absent.
    full_block_cnt_cute = (
        from_dlpack(full_block_cnt.detach(), assumed_align=4)
        if full_block_cnt is not None
        else None
    )
    full_block_idx_cute = (
        from_dlpack(full_block_idx.detach(), assumed_align=4)
        if full_block_idx is not None
        else None
    )
    mask_block_cnt_cute = (
        from_dlpack(mask_block_cnt.detach(), assumed_align=4)
        if mask_block_cnt is not None
        else None
    )
    mask_block_idx_cute = (
        from_dlpack(mask_block_idx.detach(), assumed_align=4)
        if mask_block_idx is not None
        else None
    )

    # Window parameters for is_local (local/sliding-window attention).
    window_left_cute = (
        cutlass.Int32(window_left) if window_left is not None else None
    )
    window_right_cute = (
        cutlass.Int32(window_right) if window_right is not None else None
    )

    # Compile once with the concrete tensor layouts.  The positional Nones
    # are features this test does not exercise (varlen, paging, sinks);
    # argument order must match the kernel's __call__ signature exactly.
    compiled = cute.compile(
        kernel,
        q_cute,
        k_cute,
        v_cute,
        out_cute,
        lse_cute,
        softmax_scale,
        current_stream,
        None,  # cu_seqlens_q
        None,  # cu_seqlens_k
        None,  # seqused_q
        None,  # seqused_k
        None,  # page_table
        window_left_cute,
        window_right_cute,
        None,  # learnable_sink
        full_block_cnt_cute,
        full_block_idx_cute,
        mask_block_cnt_cute,
        mask_block_idx_cute,
        None,  # buffers
    )

    # Launch with the same arguments the kernel was compiled against.
    compiled(
        q_cute,
        k_cute,
        v_cute,
        out_cute,
        lse_cute,
        softmax_scale,
        current_stream,
        None,  # cu_seqlens_q
        None,  # cu_seqlens_k
        None,  # seqused_q
        None,  # seqused_k
        None,  # page_table
        window_left_cute,
        window_right_cute,
        None,  # learnable_sink
        full_block_cnt_cute,
        full_block_idx_cute,
        mask_block_cnt_cute,
        mask_block_idx_cute,
        None,  # buffers
    )

    # Synchronize so the caller can read the output immediately.
    torch.cuda.synchronize()
    return tensors["out"]


def compute_reference_flash_attn(
    tensors, causal, window_size, dtype_ref, upcast=True
):
    """Compute a reference output using FlashAttention's attention_ref.

    Q/K/V are cast to dtype_ref before the call; window_size is the
    (left, right) tuple attention_ref expects, with (None, None) meaning
    no windowing.  reorder_ops is always False here, so callers cannot
    request the reordered (PyTorch-op-order) variant through this helper.
    Returns only the attention output; the returned attention weights are
    discarded.
    """
    batch_size, seqlen_q, nheads, headdim = tensors["q"].shape
    _, seqlen_k, nheads_kv, _ = tensors["k"].shape

    q = tensors["q"].to(dtype_ref)
    k = tensors["k"].to(dtype_ref)
    v = tensors["v"].to(dtype_ref)

    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        query_padding_mask=None,
        key_padding_mask=None,
        causal=causal,
        window_size=window_size,
        upcast=upcast,
        reorder_ops=False,
    )

    return out_ref


def compute_reference_flex_attn(
    tensors, mask_mod_flex, mask_mod_name, tile_m, tile_n
):
    """Compute a reference output for custom mask mods.

    Uses flex_attention (or a dense SDPA mask for "block_causal", whose
    block-granular semantics depend on the kernel tile sizes).  GQA/MQA is
    handled by repeating K/V heads so the reference sees full MHA.  Input
    and output use the (batch, seqlen, heads, dim) layout; internally the
    computation runs in (batch, heads, seqlen, dim).
    """
    batch_size, seqlen_q, nheads, headdim = tensors["q"].shape
    _, seqlen_k, nheads_kv, _ = tensors["k"].shape

    # flex_attention / SDPA expect (batch, heads, seqlen, dim).
    q = tensors["q"].transpose(1, 2)
    k = tensors["k"].transpose(1, 2)
    v = tensors["v"].transpose(1, 2)

    # Expand K/V heads for GQA/MQA so the reference matches head-for-head.
    if nheads != nheads_kv:
        repeat_factor = nheads // nheads_kv
        k = k.repeat_interleave(repeat_factor, dim=1)
        v = v.repeat_interleave(repeat_factor, dim=1)

    scale = 1.0 / math.sqrt(headdim)

    # Handle identity (no masking) case with plain SDPA.
    if mask_mod_flex is None:
        out_ref = F.scaled_dot_product_attention(q, k, v, scale=scale)
        return out_ref.transpose(1, 2).contiguous()

    # Wrap mask_mod_flex to pass seqlen_q and seqlen_k, matching the
    # (b, h, q_idx, kv_idx) signature create_block_mask expects.
    def mask_fn(b, h, q_idx, kv_idx):
        return mask_mod_flex(b, h, q_idx, kv_idx, seqlen_q, seqlen_k)

    if mask_mod_name == "block_causal":
        # Block-causal is defined at tile granularity: a K tile is visible
        # iff its block index <= the Q tile's block index.  Build the dense
        # boolean mask explicitly since it depends on tile_m/tile_n.
        n_blocks_q = (seqlen_q + tile_m - 1) // tile_m
        n_blocks_k = (seqlen_k + tile_n - 1) // tile_n

        mask = torch.zeros(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device)

        for q_block in range(n_blocks_q):
            q_start = q_block * tile_m
            q_end = min((q_block + 1) * tile_m, seqlen_q)
            for k_block in range(n_blocks_k):
                if k_block <= q_block:
                    k_start = k_block * tile_n
                    k_end = min((k_block + 1) * tile_n, seqlen_k)
                    mask[q_start:q_end, k_start:k_end] = True

        attn_mask = (
            mask.unsqueeze(0).unsqueeze(0).expand(batch_size, nheads, -1, -1)
        )
        out_ref = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, scale=scale
        )
    else:
        # Generic path: let flex_attention evaluate the mask mod.
        block_mask = create_block_mask(
            mask_fn,
            B=batch_size,
            H=nheads,
            Q_LEN=seqlen_q,
            KV_LEN=seqlen_k,
        ).to(q.device)
        out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale)

    return out_ref.transpose(1, 2).contiguous()


@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (64, 128),
        (128, 192),
        (256, 256),
        (239, 1),
        (799, 3),
        (113, 203),
        (113, 128),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (4096, 4096),
        (4224, 4224),
    ],
)
# @pytest.mark.parametrize("nheads", [4, 16, 32])
@pytest.mark.parametrize("nheads", [16])
@pytest.mark.parametrize("kv_mode", ["mha", "gqa", "mqa"])
# @pytest.mark.parametrize("headdim", [64, 128])
@pytest.mark.parametrize("headdim", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
    "use_mask_mod,is_local,mask_name,window_size,window_left,window_right",
    [
        (False, False, "identity", None, None, None),
        (False, False, "causal", None, None, None),
        (True, False, "identity", None, None, None),
        (True, False, "causal", None, None, None),
        # (True, False, "block_causal", None, None, None),
        # Mask mod sliding window
        (True, False, "sliding_window", 128, None, None),
        (True, False, "sliding_window", 256, None, None),
        (True, False, "sliding_window", 512, None, None),
        # Base local attention
        # (False, True, None, None, 128, 0),
        # (False, True, None, None, 256, 0),
        # (False, True, None, None, 512, 0),
    ],
)
@pytest.mark.parametrize("tile_m,tile_n", [(128, 128),])
def test_mask_mod_output(
    seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype,
    use_mask_mod, is_local, mask_name, window_size, window_left, window_right,
    tile_m, tile_n
):
    """End-to-end correctness test for the CuTe kernel's mask-mod support.

    Runs the kernel once per parameter combination and compares against an
    fp32 reference (attention_ref for identity/causal/local/sliding-window,
    flex_attention otherwise).  Passes when
        kernel_error <= rtol * pt_error + fwd_atol,
    the same tolerance structure as the upstream FlashAttention tests.
    """
    torch.manual_seed(42)

    # Validate configuration: is_local and mask mods are mutually exclusive.
    if is_local:
        assert not use_mask_mod, "Cannot use both is_local and use_mask_mod"
        assert window_left is not None or window_right is not None, \
            "Must specify window_left or window_right for is_local"

    if use_mask_mod and mask_name == "sliding_window":
        assert window_size is not None, "window_size must be specified for sliding_window"
        # Skip if seqlen_k is too small for the window
        # if seqlen_k < window_size // 2:
        #     pytest.skip(f"seqlen_k={seqlen_k} too small for window_size={window_size}")
        # Skip if seqlen_q > seqlen_k (problematic for sliding window)
        if seqlen_q > seqlen_k:
            pytest.skip(f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window")

    if is_local:
        window_left_val = window_left if window_left is not None else 0
        window_right_val = window_right if window_right is not None else 0
        total_window = window_left_val + window_right_val + 1
        # Skip if seqlen_k is too small for the window
        if seqlen_k < total_window // 2:
            pytest.skip(f"seqlen_k={seqlen_k} too small for window={total_window}")
        # Skip if seqlen_q > seqlen_k (problematic for local window)
        if seqlen_q > seqlen_k:
            pytest.skip(f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for is_local")

    # Determine nheads_kv based on mode (MHA/GQA/MQA head layouts).
    if kv_mode == "mha":
        nheads_kv = nheads
    elif kv_mode == "gqa":
        nheads_kv = nheads // 2
    elif kv_mode == "mqa":
        nheads_kv = 1
    else:
        raise ValueError(f"Unknown kv_mode: {kv_mode}")

    batch_size = 2
    headdim_v = headdim

    # Determine mask_mod functions and causal flag.  The CuTe and flex
    # variants of each mask must encode the same predicate.
    if use_mask_mod:
        if mask_name == "sliding_window":
            # Use factory function for custom window size
            mask_mod_cute = create_cute_sliding_window_mask(window_size)
            mask_mod_flex = create_flex_sliding_window_mask(window_size)
        else:
            mask_mod_cute, mask_mod_flex = MASK_FUNCTIONS[mask_name]
        causal = (mask_name == "causal")
    elif is_local:
        # Base local attention - no mask_mod
        mask_mod_cute = None
        mask_mod_flex = None
        causal = False
    else:
        mask_mod_cute = None
        mask_mod_flex = None
        causal = (mask_name == "causal") if mask_name else False

    if causal and seqlen_k < seqlen_q:
        pytest.skip("causal masking requires seqlen_k >= seqlen_q")

    tensors = create_tensors(
        batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype
    )

    # Compute block sparsity for mask_mod: which K tiles are fully visible
    # vs. partially masked per Q tile, as consumed by the kernel.
    full_cnt, full_idx, mask_cnt, mask_idx = None, None, None, None
    if use_mask_mod:
        from dataclasses import dataclass

        # Ad-hoc stand-in for the config object compute_block_sparsity
        # expects; only the fields read by that function are populated.
        @dataclass
        class Config:
            seqlen_q: int
            seqlen_k: int
            nheads: int
            nheads_kv: int
            batch_size: int
            tile_m: int
            tile_n: int
            use_mask_mod: bool
            mask_mod_name: str
            window_size: int = 1024  # default used when the mask has no window
            verbose: bool = False

        config = Config(
            seqlen_q=seqlen_q,
            seqlen_k=seqlen_k,
            nheads=nheads,
            nheads_kv=nheads_kv,
            batch_size=batch_size,
            tile_m=tile_m,
            tile_n=tile_n,
            use_mask_mod=True,
            mask_mod_name=mask_name,
            window_size=window_size if window_size is not None else 1024,
        )

        full_cnt, full_idx, mask_cnt, mask_idx = compute_block_sparsity(
            config=config, mask_mod_flex=mask_mod_flex, device="cuda"
        )

    # Run kernel
    out_cute = compile_and_run_kernel(
        tensors,
        mask_mod_cute,
        causal=causal,
        is_local=is_local,
        window_left=window_left,
        window_right=window_right,
        tile_m=tile_m,
        tile_n=tile_n,
        full_block_cnt=full_cnt,
        full_block_idx=full_idx,
        mask_block_cnt=mask_cnt,
        mask_block_idx=mask_idx,
    )

    # Determine which reference implementation to use
    dtype_ref = torch.bfloat16
    use_flash_attn_ref = False

    # Use FlashAttention reference for causal and local window cases
    if mask_name == "causal" and not use_mask_mod:
        use_flash_attn_ref = True
        window_size_ref = (None, None)  # attention_ref handles causal internally
    elif mask_name == "identity" and not use_mask_mod and not is_local:
        use_flash_attn_ref = True
        window_size_ref = (None, None)  # No window for identity
    elif is_local:
        use_flash_attn_ref = True
        # For is_local, we need to pass the window parameters
        # When window_right=0, this is inherently causal
        window_size_ref = (window_left, window_right)
        if window_right == 0:
            causal = True  # Override causal flag for reference computation
    elif use_mask_mod and mask_name == "sliding_window":
        use_flash_attn_ref = True
        # For sliding window mask_mod, window_size corresponds directly to window_left
        # in attention_ref (number of previous tokens that can be attended to)
        # Sliding window with window_right=0 is inherently causal
        window_size_ref = (window_size, 0)
        causal = True  # Override causal flag for reference computation

    if use_flash_attn_ref:
        # Compute references using FlashAttention's attention_ref:
        # an fp32 "ground truth" and a half-precision reference.
        out_ref_fp32 = compute_reference_flash_attn(
            tensors, causal=causal, window_size=window_size_ref, dtype_ref=torch.float32, upcast=True
        )
        out_ref = compute_reference_flash_attn(
            tensors, causal=causal, window_size=window_size_ref, dtype_ref=dtype_ref, upcast=False
        )

        # Also compute a PyTorch-precision reference in the input dtype.
        # NOTE(review): compute_reference_flash_attn hard-codes
        # reorder_ops=False, so unlike the upstream flash-attn tests this
        # "out_pt" is NOT the reordered-ops variant — confirm intent.
        out_pt = compute_reference_flash_attn(
            tensors, causal=causal, window_size=window_size_ref, dtype_ref=dtype, upcast=False
        )
    else:
        # Use flex_attention for custom mask_mods
        tensors_fp32 = {
            k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v
            for k, v in tensors.items()
        }

        out_ref_fp32 = compute_reference_flex_attn(
            tensors_fp32, mask_mod_flex, mask_name, tile_m, tile_n
        )
        out_ref = compute_reference_flex_attn(
            tensors, mask_mod_flex, mask_name, tile_m, tile_n
        )
        # No separate PyTorch reference here, so pt_error == ref_error
        # in this branch by construction.
        out_pt = out_ref.clone()

    # Check for invalid values
    assert out_cute.shape == out_ref_fp32.shape == out_ref.shape
    assert not torch.isnan(out_cute).any()
    assert not torch.isnan(out_ref_fp32).any()
    assert torch.isfinite(out_cute).all()
    assert torch.isfinite(out_ref_fp32).all()

    # Numerical tolerance (matching flash attention tests): the +0.3-0.3
    # round trip measures fp32 rounding noise at the output's magnitude.
    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()
    rtol = 2

    ref_error = (out_ref - out_ref_fp32).abs().max().item()
    pt_error = (out_pt - out_ref_fp32).abs().max().item()
    cute_error = (out_cute - out_ref_fp32).abs().max().item()

    # Build description string for the diagnostic printout.
    if is_local:
        mask_desc = f"is_local(L={window_left},R={window_right})"
    elif use_mask_mod:
        mask_desc = f"mask_mod={mask_name}"
        if mask_name == "sliding_window" and window_size is not None:
            mask_desc += f"(w={window_size})"
    else:
        mask_desc = mask_name if mask_name else "identity"

    print(
        f"\n{mask_desc} @ Q={seqlen_q}, K={seqlen_k}, H={nheads}/{nheads_kv} ({kv_mode}), "
        f"D={headdim}, M={tile_m}, N={tile_n}"
    )
    print(f" Reference implementation: {'FlashAttention' if use_flash_attn_ref else 'FlexAttention'}")
    print(f" Reference vs FP32: {ref_error:.2e}")
    print(f" PyTorch vs FP32: {pt_error:.2e}")
    print(f" Kernel vs FP32: {cute_error:.2e}")
    print(f" Tolerance: rtol={rtol} * {pt_error:.2e} + {fwd_atol:.2e}")
    print(f" Error ratio: {cute_error / max(pt_error, 1e-10):.2f}")

    # Debug: show some sample values if error is large
    if cute_error > 1e-2:
        print(f" DEBUG: Sample kernel output: {out_cute[0, 0, 0, :5]}")
        print(f" DEBUG: Sample reference output: {out_ref_fp32[0, 0, 0, :5]}")
        print(f" DEBUG: Max diff location: {(out_cute - out_ref_fp32).abs().argmax()}")
        max_diff_idx = (out_cute - out_ref_fp32).abs().argmax()
        max_diff_coords = torch.unravel_index(max_diff_idx, out_cute.shape)
        print(f" DEBUG: Max diff at coords: {max_diff_coords}")
        print(f" DEBUG: Kernel value: {out_cute[max_diff_coords]:.6f}")
        print(f" DEBUG: Reference value: {out_ref_fp32[max_diff_coords]:.6f}")

    # Use the same assertion logic as FlashAttention tests
    assert cute_error <= rtol * pt_error + fwd_atol, (
        f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}"
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])