diff --git a/flash_attn/cute/barrier.py b/flash_attn/cute/barrier.py index 744e3a56507..c999b180167 100644 --- a/flash_attn/cute/barrier.py +++ b/flash_attn/cute/barrier.py @@ -4,8 +4,9 @@ from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import llvm + @dsl_user_op -def ld_acquire(lock_ptr : cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: +def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() state = llvm.inline_asm( T.i32(), @@ -18,8 +19,11 @@ def ld_acquire(lock_ptr : cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: ) return cutlass.Int32(state) + @dsl_user_op -def red_relaxed(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None) -> None: +def red_relaxed( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() llvm.inline_asm( None, @@ -31,8 +35,11 @@ def red_relaxed(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=N asm_dialect=llvm.AsmDialect.AD_ATT, ) + @dsl_user_op -def red_release(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None) -> None: +def red_release( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() llvm.inline_asm( None, @@ -43,28 +50,22 @@ def red_release(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=N is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) - + + @cute.jit -def wait_eq( - lock_ptr : cute.Pointer, - thread_idx : int | Int32, - flag_offset : int, - val : Int32 -) -> None: +def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None: flag_ptr = lock_ptr + flag_offset if thread_idx == 0: read_val = Int32(0) while read_val != val: read_val = ld_acquire(flag_ptr) + @cute.jit def arrive_inc( - lock_ptr : cute.Pointer, - thread_idx : int | Int32, - flag_offset : int, - val : cutlass.Constexpr[Int32] + lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32] ) -> None: flag_ptr = lock_ptr + flag_offset if thread_idx == 0: red_release(flag_ptr, val) - # red_relaxed(flag_ptr, val) \ No newline at end of file + # red_relaxed(flag_ptr, val) diff --git a/flash_attn/cute/benchmark_mask_mod.py b/flash_attn/cute/benchmark_mask_mod.py index 071b4e02a58..b1aadd89395 100644 --- a/flash_attn/cute/benchmark_mask_mod.py +++ b/flash_attn/cute/benchmark_mask_mod.py @@ -5,7 +5,6 @@ from dataclasses import dataclass import math -from pickle import FALSE from typing import Any, Dict, Optional, Tuple import cuda.bindings.driver as cuda @@ -51,7 +50,7 @@ class BenchmarkConfig: # Mask parameters use_mask_mod: bool = True mask_mod_name: str = "causal" - has_buffers: bool = mask_mod_name == "document" + has_aux_tensors: bool = mask_mod_name == "document" # Sliding window parameter (used when mask_mod_name == "sliding_window") window_size: int = 128 @@ -235,7 +234,6 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]: dtype=torch.float32, device=device, ) - tensors = { "q": q.contiguous(), @@ -244,10 +242,10 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]: "out": out.contiguous(), "lse": lse.contiguous(), } - + if config.use_learnable_sink: learnable_sink = torch.rand(config.nheads, dtype=torch.bfloat16, device=device) - + tensors["learnable_sink"] = learnable_sink.contiguous() # Compute block sparsity when using mask_mod @@ -256,14 +254,14 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]: doc_id = random_doc_id_tensor( config.batch_size, config.nheads, config.seqlen_q, device=device ) - tensors["buffers"] = [doc_id.contiguous()] + tensors["aux_tensors"] = [doc_id.contiguous()] full_cnt, full_idx, mask_cnt, mask_idx = compute_block_sparsity( config=self.config, mask_mod_flex=self.mask_mod_flex, device=device, cu_seqlens_q=tensors.get("cu_seqlens_q"), cu_seqlens_k=tensors.get("cu_seqlens_k"), - buffers=tensors.get("buffers"), + aux_tensors=tensors.get("aux_tensors"), ) if all(t is not None for t in [full_cnt, full_idx, mask_cnt, mask_idx]): @@ -329,7 +327,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] mma_pv_is_rs=config.mma_pv_is_rs, mask_mod=self.mask_mod_cute, Q_in_regs=False, - has_buffers=config.has_buffers, + has_aux_tensors=config.has_aux_tensors, ) softmax_scale = 1.0 / math.sqrt(config.headdim) @@ -405,14 +403,14 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] else None ) - if "buffers" in tensors: - buffers_cute = [] - for i in range(len(tensors["buffers"])): - buf = from_dlpack(tensors["buffers"][i].detach(), assumed_align=4) - buffers_cute.append(buf.mark_layout_dynamic(leading_dim=2)) + if "aux_tensors" in tensors: + aux_tensors_cute = [] + for i in range(len(tensors["aux_tensors"])): + buf = from_dlpack(tensors["aux_tensors"][i].detach(), assumed_align=4) + aux_tensors_cute.append(buf.mark_layout_dynamic(leading_dim=2)) else: - buffers_cute = None + aux_tensors_cute = None # Window parameters for is_local window_left_cute = ( @@ -443,7 +441,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] full_block_idx_cute, mask_block_cnt_cute, mask_block_idx_cute, - buffers_cute, + aux_tensors_cute, # None, ) @@ -467,7 +465,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] full_block_idx_cute, mask_block_cnt_cute, mask_block_idx_cute, - buffers_cute, + aux_tensors_cute, # None, ) @@ -496,7 +494,7 @@ def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float: num_blocks = (config.seqlen_k + block_size - 1) // block_size sparsity_ratio = 1.0 / num_blocks if num_blocks > 1 else 1.0 elif config.mask_mod_name == "document": - vals = tensors["buffers"][0] + vals = tensors["aux_tensors"][0] val_mask = torch.ones_like(vals, dtype=torch.bool) val_mask[..., 1:] = vals[..., 1:] != vals[..., :-1] total = torch.where(val_mask, vals.square(), 0).sum() @@ -573,7 +571,7 @@ def benchmark(self) -> Dict[str, Any]: torch.cuda.synchronize() times.append(start.elapsed_time(end)) - + times_tensor = torch.tensor(times) mean_time = times_tensor.mean().item() std_time = times_tensor.std().item() if len(times) > 1 else 0.0 @@ -683,7 +681,7 @@ def _print_results(self, results: Dict[str, Any]): # seqlen_k=192, use_varlen=False, use_mask_mod=True, - mask_mod_name="identity", + mask_mod_name="causal", window_size=128, # Configurable window size for mask_mod use_learnable_sink=False, causal=False, diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index ce05cae1438..be685dea5d4 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -14,14 +14,17 @@ # placeholder Config = type("Config", (), {}) + def compute_block_sparsity( config: Config, mask_mod_flex: Optional[Callable], device: str, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, - buffers: Optional[List[torch.Tensor]] = None, -) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + aux_tensors: Optional[List[torch.Tensor]] = None, +) -> Tuple[ + Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor] +]: """ Computes block sparsity tensors from a given masking function. @@ -35,7 +38,7 @@ def compute_block_sparsity( device: The device to create tensors on (e.g., 'cuda'). cu_seqlens_q: Cumulative sequence lengths for Q (for varlen). cu_seqlens_k: Cumulative sequence lengths for K (for varlen). - buffers: A list of auxiliary tensors, e.g., for document masking. + aux_tensors: A list of auxiliary tensors, e.g., for document masking. Returns: A tuple of four tensors: @@ -53,25 +56,35 @@ def compute_block_sparsity( return _compute_varlen_sparsity(config, mask_mod_flex, device, cu_seqlens_q, cu_seqlens_k) else: # Handle fixed-length sequences - return _compute_sparsity(config, device, buffers) + return _compute_sparsity(config, device, aux_tensors) + ## --------------------------------------------------------------------------- ## Fixed-Length Sequence Kernels ## --------------------------------------------------------------------------- + def _compute_sparsity( - config: Config, device: str, buffers: Optional[List[torch.Tensor]] + config: Config, device: str, aux_tensors: Optional[List[torch.Tensor]] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Computes block sparsity for fixed-length sequences.""" n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n - + # Pre-allocate output tensors - full_block_cnt = torch.zeros((config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32) - mask_block_cnt = torch.zeros((config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32) - full_block_idx = torch.zeros((config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32) - mask_block_idx = torch.zeros((config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32) - + full_block_cnt = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32 + ) + mask_block_cnt = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32 + ) + # --- Identity Mask --- # All blocks are fully computed. if config.mask_mod_name == "identity": @@ -79,7 +92,7 @@ def _compute_sparsity( for q_block_idx in range(n_blocks_q): full_block_cnt[:, :, q_block_idx] = n_blocks_k full_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks - + # --- Identity Partial Mask --- # All blocks are partially computed (masked). elif config.mask_mod_name == "identity_partial": @@ -104,26 +117,34 @@ def _compute_sparsity( k_block_indices = torch.arange(n_blocks_k, device=device) q_starts = q_block_indices * config.tile_m - q_ends = torch.minimum((q_block_indices + 1) * config.tile_m, torch.tensor(config.seqlen_q, device=device)) + q_ends = torch.minimum( + (q_block_indices + 1) * config.tile_m, torch.tensor(config.seqlen_q, device=device) + ) k_starts = k_block_indices * config.tile_n - k_ends = torch.minimum((k_block_indices + 1) * config.tile_n, torch.tensor(config.seqlen_k, device=device)) + k_ends = torch.minimum( + (k_block_indices + 1) * config.tile_n, torch.tensor(config.seqlen_k, device=device) + ) # Expand dims for broadcasting: (n_blocks_q, 1) and (1, n_blocks_k) q_starts, q_ends = q_starts.unsqueeze(1), q_ends.unsqueeze(1) k_starts, k_ends = k_starts.unsqueeze(0), k_ends.unsqueeze(0) - + offset = config.seqlen_k - config.seqlen_q if config.mask_mod_name == "causal": is_full = (k_ends - 1) <= (q_starts + offset) # min(k_pos) <= max(q_pos) AND not is_full. is_partial = (k_starts <= (q_ends - 1 + offset)) & ~is_full - - else: # sliding_window - window_size = getattr(config, 'window_size', 1024) - is_full = (k_ends - 1 <= q_starts + offset) & (k_starts >= q_ends - 1 + offset - (window_size - 1)) + + else: # sliding_window + window_size = getattr(config, "window_size", 1024) + is_full = (k_ends - 1 <= q_starts + offset) & ( + k_starts >= q_ends - 1 + offset - (window_size - 1) + ) # A block is EMPTY if no (q, k) pairs satisfy the constraint. - is_empty = (k_starts > q_ends - 1 + offset) | (k_ends - 1 < q_starts + offset - (window_size - 1)) + is_empty = (k_starts > q_ends - 1 + offset) | ( + k_ends - 1 < q_starts + offset - (window_size - 1) + ) # A block is PARTIAL if it's not empty and not full. is_partial = ~is_empty & ~is_full @@ -132,22 +153,24 @@ def _compute_sparsity( full_indices = k_block_indices[is_full[q_block_idx]] if len(full_indices) > 0: full_block_cnt[:, :, q_block_idx] = len(full_indices) - full_block_idx[:, :, q_block_idx, :len(full_indices)] = full_indices + full_block_idx[:, :, q_block_idx, : len(full_indices)] = full_indices partial_indices = k_block_indices[is_partial[q_block_idx]] if len(partial_indices) > 0: mask_block_cnt[:, :, q_block_idx] = len(partial_indices) - mask_block_idx[:, :, q_block_idx, :len(partial_indices)] = partial_indices - + mask_block_idx[:, :, q_block_idx, : len(partial_indices)] = partial_indices + elif config.mask_mod_name == "document": raise NotImplementedError("Block sparsity for document masking not yet implemented") return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + ## --------------------------------------------------------------------------- ## Variable-Length Sequence Kernels ## --------------------------------------------------------------------------- + def _compute_varlen_sparsity( config: Config, mask_mod_flex: Callable, @@ -159,7 +182,7 @@ def _compute_varlen_sparsity( assert cu_seqlens_k is not None, "cu_seqlens_k is required for varlen attention" assert cu_seqlens_q.shape[0] == config.batch_size + 1 assert cu_seqlens_k.shape[0] == config.batch_size + 1 - + # In varlen, each sequence can have a different number of Q blocks. # We pad up to the maximum number of Q blocks in the batch. max_m_blocks = 0 @@ -173,62 +196,98 @@ def _compute_varlen_sparsity( max_n_blocks = (total_k_len + config.tile_n - 1) // config.tile_n # Pre-allocate padded output tensors - full_block_cnt = torch.zeros((config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32) - mask_block_cnt = torch.zeros((config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32) - full_block_idx = torch.zeros((config.batch_size, config.nheads, max_m_blocks, max_n_blocks), device=device, dtype=torch.int32) - mask_block_idx = torch.zeros((config.batch_size, config.nheads, max_m_blocks, max_n_blocks), device=device, dtype=torch.int32) + full_block_cnt = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_cnt = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks, max_n_blocks), + device=device, + dtype=torch.int32, + ) + mask_block_idx = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks, max_n_blocks), + device=device, + dtype=torch.int32, + ) # Process each sequence in the batch individually for seq_idx in range(config.batch_size): seq_start_q = cu_seqlens_q[seq_idx].item() seq_end_q = cu_seqlens_q[seq_idx + 1].item() seq_len_q = seq_end_q - seq_start_q - + seq_start_k = cu_seqlens_k[seq_idx].item() seq_end_k = cu_seqlens_k[seq_idx + 1].item() seq_len_k = seq_end_k - seq_start_k - + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n # Global block indices are relative to the start of the entire batch tensor first_m_block_global = seq_start_q // config.tile_m first_n_block_global = seq_start_k // config.tile_n - + common_args = { - "full_block_cnt": full_block_cnt, "full_block_idx": full_block_idx, - "mask_block_cnt": mask_block_cnt, "mask_block_idx": mask_block_idx, - "seq_idx": seq_idx, "n_blocks_q": n_blocks_q, "n_blocks_k": n_blocks_k, - "seq_start_q": seq_start_q, "seq_end_q": seq_end_q, - "seq_start_k": seq_start_k, "seq_end_k": seq_end_k, + "full_block_cnt": full_block_cnt, + "full_block_idx": full_block_idx, + "mask_block_cnt": mask_block_cnt, + "mask_block_idx": mask_block_idx, + "seq_idx": seq_idx, + "n_blocks_q": n_blocks_q, + "n_blocks_k": n_blocks_k, + "seq_start_q": seq_start_q, + "seq_end_q": seq_end_q, + "seq_start_k": seq_start_k, + "seq_end_k": seq_end_k, "first_n_block_global": first_n_block_global, - "tile_m": config.tile_m, "tile_n": config.tile_n, "device": device + "tile_m": config.tile_m, + "tile_n": config.tile_n, + "device": device, } if config.mask_mod_name == "causal": _compute_causal_varlen_blocks(**common_args) elif config.mask_mod_name == "sliding_window": - window_size = getattr(config, 'window_size', 1024) + window_size = getattr(config, "window_size", 1024) _compute_sliding_window_varlen_blocks(**common_args, window_size=window_size) elif config.mask_mod_name == "identity": _compute_identity_varlen_blocks( - full_block_cnt, full_block_idx, seq_idx, - n_blocks_q, n_blocks_k, first_n_block_global, device + full_block_cnt, + full_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + first_n_block_global, + device, ) else: # Generic case relies on sampling the user-provided mask function _compute_generic_varlen_blocks( - **common_args, mask_mod_flex=mask_mod_flex, - seq_len_q=seq_len_q, seq_len_k=seq_len_k, - num_heads=config.nheads, nheads_kv=config.nheads_kv, + **common_args, + mask_mod_flex=mask_mod_flex, + seq_len_q=seq_len_q, + seq_len_k=seq_len_k, + num_heads=config.nheads, + nheads_kv=config.nheads_kv, ) - + return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + def _classify_varlen_block( - m_local: int, n_local: int, seq_start_q: int, seq_end_q: int, - seq_start_k: int, seq_end_k: int, tile_m: int, tile_n: int, - is_full_fn: Callable, is_partial_fn: Callable + m_local: int, + n_local: int, + seq_start_q: int, + seq_end_q: int, + seq_start_k: int, + seq_end_k: int, + tile_m: int, + tile_n: int, + is_full_fn: Callable, + is_partial_fn: Callable, ) -> Tuple[bool, bool]: """Helper to classify a varlen block as full, partial, or empty.""" m_start_global = seq_start_q + m_local * tile_m @@ -241,20 +300,35 @@ def _classify_varlen_block( m_end_local = m_end_global - seq_start_q n_start_local = n_start_global - seq_start_k n_end_local = n_end_global - seq_start_k - + is_full = is_full_fn(m_start_local, m_end_local, n_start_local, n_end_local) - is_partial = is_partial_fn(m_start_local, m_end_local, n_start_local, n_end_local) and not is_full - + is_partial = ( + is_partial_fn(m_start_local, m_end_local, n_start_local, n_end_local) and not is_full + ) + # Any block that touches the sequence boundary is partial because it requires masking. at_boundary = (m_end_global > seq_end_q) or (n_end_global > seq_end_k) - + return is_full and not at_boundary, is_partial or (is_full and at_boundary) + def _compute_causal_varlen_blocks( - full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx, - seq_idx, n_blocks_q, n_blocks_k, - seq_start_q, seq_end_q, seq_start_k, seq_end_k, - first_n_block_global, tile_m, tile_n, device, **kwargs + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + first_n_block_global, + tile_m, + tile_n, + device, + **kwargs, ): """Computes causal block sparsity for a single varlen sequence.""" is_full_fn = lambda m_start, m_end, n_start, n_end: (m_start >= n_end - 1) @@ -264,8 +338,16 @@ def _compute_causal_varlen_blocks( full_blocks, partial_blocks = [], [] for n_local in range(n_blocks_k): is_full, is_partial = _classify_varlen_block( - m_local, n_local, seq_start_q, seq_end_q, seq_start_k, seq_end_k, - tile_m, tile_n, is_full_fn, is_partial_fn + m_local, + n_local, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + tile_m, + tile_n, + is_full_fn, + is_partial_fn, ) n_block_global = first_n_block_global + n_local if is_full: @@ -275,98 +357,157 @@ def _compute_causal_varlen_blocks( if full_blocks: full_block_cnt[seq_idx, :, m_local] = len(full_blocks) - full_block_idx[seq_idx, :, m_local, :len(full_blocks)] = torch.tensor(full_blocks, device=device) + full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, device=device + ) if partial_blocks: mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks) - mask_block_idx[seq_idx, :, m_local, :len(partial_blocks)] = torch.tensor(partial_blocks, device=device) + mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, device=device + ) + def _compute_sliding_window_varlen_blocks( - full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx, - seq_idx, n_blocks_q, n_blocks_k, - seq_start_q, seq_end_q, seq_start_k, seq_end_k, - first_n_block_global, tile_m, tile_n, window_size, device, **kwargs + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + first_n_block_global, + tile_m, + tile_n, + window_size, + device, + **kwargs, ): """Computes sliding window block sparsity for a single varlen sequence.""" - is_full_fn = lambda m_start, m_end, n_start, n_end: \ - (n_end - 1 <= m_start) and (n_start >= m_start - window_size + 1) - is_partial_fn = lambda m_start, m_end, n_start, n_end: \ - not ((n_start > m_end - 1) or (n_end - 1 < m_start - window_size + 1)) + is_full_fn = lambda m_start, m_end, n_start, n_end: (n_end - 1 <= m_start) and ( + n_start >= m_start - window_size + 1 + ) + is_partial_fn = lambda m_start, m_end, n_start, n_end: not ( + (n_start > m_end - 1) or (n_end - 1 < m_start - window_size + 1) + ) for m_local in range(n_blocks_q): full_blocks, partial_blocks = [], [] for n_local in range(n_blocks_k): is_full, is_partial = _classify_varlen_block( - m_local, n_local, seq_start_q, seq_end_q, seq_start_k, seq_end_k, - tile_m, tile_n, is_full_fn, is_partial_fn + m_local, + n_local, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + tile_m, + tile_n, + is_full_fn, + is_partial_fn, ) n_block_global = first_n_block_global + n_local if is_full: full_blocks.append(n_block_global) elif is_partial: partial_blocks.append(n_block_global) - + if full_blocks: full_block_cnt[seq_idx, :, m_local] = len(full_blocks) - full_block_idx[seq_idx, :, m_local, :len(full_blocks)] = torch.tensor(full_blocks, device=device) + full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, device=device + ) if partial_blocks: mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks) - mask_block_idx[seq_idx, :, m_local, :len(partial_blocks)] = torch.tensor(partial_blocks, device=device) + mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, device=device + ) + def _compute_identity_varlen_blocks( - full_block_cnt, full_block_idx, seq_idx, n_blocks_q, - n_blocks_k, first_n_block_global, device, **kwargs + full_block_cnt, + full_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + first_n_block_global, + device, + **kwargs, ): """Computes identity (all-attend) block sparsity for a single varlen sequence.""" n_blocks_global = torch.arange( - first_n_block_global, first_n_block_global + n_blocks_k, - device=device, dtype=torch.int32 + first_n_block_global, first_n_block_global + n_blocks_k, device=device, dtype=torch.int32 ) for m_local in range(n_blocks_q): full_block_cnt[seq_idx, :, m_local] = n_blocks_k full_block_idx[seq_idx, :, m_local, :n_blocks_k] = n_blocks_global + def _compute_generic_varlen_blocks( - full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx, - mask_mod_flex, seq_idx, num_heads, n_blocks_q, n_blocks_k, - seq_len_q, seq_len_k, first_n_block_global, - tile_m, tile_n, nheads_kv, device, **kwargs + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, + mask_mod_flex, + seq_idx, + num_heads, + n_blocks_q, + n_blocks_k, + seq_len_q, + seq_len_k, + first_n_block_global, + tile_m, + tile_n, + nheads_kv, + device, + **kwargs, ): """Generic sampling-based block classification for a varlen sequence.""" qhead_per_kvhead = num_heads // nheads_kv - + for h_q in range(num_heads): h_kv = h_q // qhead_per_kvhead for m_local in range(n_blocks_q): m_start_local = m_local * tile_m m_end_local = min((m_local + 1) * tile_m, seq_len_q) - + full_blocks, partial_blocks = [], [] for n_local in range(n_blocks_k): n_start_local = n_local * tile_n n_end_local = min((n_local + 1) * tile_n, seq_len_k) - + # Sample points within the block (corners and center) to classify it. # Coordinates are sequence-local, as required by mask_mod_flex. sample_positions = [ - (m_start_local, n_start_local), (m_start_local, n_end_local - 1), - (m_end_local - 1, n_start_local), (m_end_local - 1, n_end_local - 1), + (m_start_local, n_start_local), + (m_start_local, n_end_local - 1), + (m_end_local - 1, n_start_local), + (m_end_local - 1, n_end_local - 1), ((m_start_local + m_end_local) // 2, (n_start_local + n_end_local) // 2), ] - + unmasked_count = sum( - 1 for q_pos, k_pos in sample_positions + 1 + for q_pos, k_pos in sample_positions if mask_mod_flex(seq_idx, h_q, q_pos, k_pos, seq_len_q, seq_len_k) ) - + n_block_global = first_n_block_global + n_local - if unmasked_count == len(sample_positions): # All samples unmasked -> full + if unmasked_count == len(sample_positions): # All samples unmasked -> full full_blocks.append(n_block_global) - elif unmasked_count > 0: # Some unmasked -> partial + elif unmasked_count > 0: # Some unmasked -> partial partial_blocks.append(n_block_global) - + if full_blocks: full_block_cnt[seq_idx, h_q, m_local] = len(full_blocks) - full_block_idx[seq_idx, h_q, m_local, :len(full_blocks)] = torch.tensor(full_blocks, device=device) + full_block_idx[seq_idx, h_q, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, device=device + ) if partial_blocks: mask_block_cnt[seq_idx, h_q, m_local] = len(partial_blocks) - mask_block_idx[seq_idx, h_q, m_local, :len(partial_blocks)] = torch.tensor(partial_blocks, device=device) \ No newline at end of file + mask_block_idx[seq_idx, h_q, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, device=device + ) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 4922a1534c9..b49a693dfcd 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -32,12 +32,17 @@ from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd -from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) from flash_attn.cute.fast_math import FastDivmod class FlashAttentionForwardBase: - arch: int = 80 def __init__( @@ -56,7 +61,7 @@ def __init__( Q_in_regs: bool = False, score_mod: Optional[cutlass.Constexpr] = None, mask_mod: Optional[cutlass.Constexpr] = None, - has_buffers: bool = False, + has_aux_tensors: bool = False, ): """Initializes the configuration for a flash attention kernel. @@ -73,9 +78,9 @@ def __init__( :type num_threads: int :param is_causal: is causal :param score_mod: A callable that takes the attention scores and applies a modification. - Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, buffers) -> Any`` + Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Any`` :param mask_mod: A callable that takes the attention scores and returns a boolean representing whether that score should be masked. - Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, buffers) -> Boolean`` + Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Boolean`` """ self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -99,15 +104,22 @@ def __init__( self.score_mod = score_mod self.mask_mod = mask_mod self.qk_acc_dtype = Float32 - if const_expr(has_buffers): + if const_expr(has_aux_tensors): self.vec_size: cutlass.Constexpr = 1 else: self.vec_size: cutlass.Constexpr = 2 @staticmethod def can_implement( - dtype, head_dim, head_dim_v, tile_m, tile_n, num_stages, num_threads, is_causal, - Q_in_regs=False + dtype, + head_dim, + head_dim_v, + tile_m, + tile_n, + num_stages, + num_threads, + is_causal, + Q_in_regs=False, ) -> bool: """Check if the kernel can be implemented with the given parameters. @@ -142,7 +154,9 @@ def can_implement( smem_usage_Q = tile_m * head_dim * 2 smem_usage_K = tile_n * head_dim * num_stages * 2 smem_usage_V = tile_n * head_dim_v * num_stages * 2 - smem_usage_QV = (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + smem_usage_QV = ( + (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + ) smem_usage = smem_usage_QV + smem_usage_K # TODO: sm86 and sm89 smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") @@ -186,22 +200,34 @@ def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: Q/K/V # /////////////////////////////////////////////////////////////////////////////// - sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = self._get_smem_layout_atom() + sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = ( + self._get_smem_layout_atom() + ) self.sQ_layout = cute.tile_to_shape( - sQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1), + sQ_layout_atom, + (self.tile_m, self.tile_hdim), + (0, 1), ) self.sK_layout = cute.tile_to_shape( - sK_layout_atom, (self.tile_n, self.tile_hdim, self.num_stages), (0, 1, 2), + sK_layout_atom, + (self.tile_n, self.tile_hdim, self.num_stages), + (0, 1, 2), ) self.sV_layout = cute.tile_to_shape( - sV_layout_atom, (self.tile_n, self.tile_hdimv, self.num_stages), (0, 1, 2), + sV_layout_atom, + (self.tile_n, self.tile_hdimv, self.num_stages), + (0, 1, 2), ) self.sO_layout = cute.tile_to_shape( - sO_layout_atom, (self.tile_m, self.tile_hdimv), (0, 1), + sO_layout_atom, + (self.tile_m, self.tile_hdimv), + (0, 1), ) if const_expr(sP_layout_atom is not None): self.sP_layout = cute.tile_to_shape( - sP_layout_atom, (self.tile_m, self.tile_n), (0, 1), + sP_layout_atom, + (self.tile_m, self.tile_n), + (0, 1), ) else: self.sP_layout = None @@ -220,28 +246,38 @@ def _setup_attributes(self): ) # atom_universal_copy: universal copy atom for O store atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, ) # tQ_layout and tK_layout: thread layout for QK load tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems - assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" - assert self.num_producer_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" + assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, ( + "num_threads must be divisible by tQK_shape_dim_1" + ) + assert self.num_producer_threads % tQK_shape_dim_1 == 0, ( + "num_threads must be divisible by tQK_shape_dim_1" + ) tQ_layout = cute.make_ordered_layout( - (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), + (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1), + order=(1, 0), ) tK_layout = cute.make_ordered_layout( - (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), + (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), + order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we load Q assert self.tile_m % tQ_layout.shape[0] == 0 tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems tV_layout = cute.make_ordered_layout( - (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), + (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), ) # TODO: need a different layout for O if O dtype is not the same as V dtype # tO_layout: thread layout for O store tO_layout = cute.make_ordered_layout( - (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), + (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we store O assert self.tile_m % tO_layout.shape[0] == 0 @@ -304,7 +340,9 @@ def epilogue( rO = cute.make_fragment_like(acc_O, self.dtype) rO.store(acc_O.load().to(self.dtype)) # Make sure all threads have finished reading V - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + ) smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype) smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) taccOrO = smem_thr_copy_O.retile(rO) @@ -313,7 +351,9 @@ def epilogue( cute.copy(smem_copy_atom_O, taccOrO, taccOsO) cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv)) - pack_gqa = PackGQA(self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead) + pack_gqa = PackGQA( + self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead + ) # Write LSE from rmem -> gmem if const_expr(mLSE is not None): @@ -336,7 +376,10 @@ def epilogue( # Only the thread corresponding to column 0 writes out the lse to gmem if taccOcO[0][1] == 0: for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])): - if t0accOcO[m, 0][0] < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0]: + if ( + t0accOcO[m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0] + ): taccOgLSE[m, 0] = lse[m] else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) @@ -353,19 +396,28 @@ def epilogue( if const_expr(self.use_tma_O): # ensure smem writes are visible to TMA cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) - cute.arch.barrier_arrive(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) store_O, _, _ = copy_utils.tma_get_copy_fn( tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True ) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 4: - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) store_O() cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) else: - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads, + ) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) tOsO = gmem_thr_copy_O.partition_S(sO) tOrO = cute.make_fragment_like(tOsO, self.dtype) @@ -379,12 +431,17 @@ def epilogue( tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) # copy acc O from rmem to gmem for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]: + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0] + ): cute.copy( gmem_tiled_copy_O, tOrO[None, rest_m, None], tOgO[None, rest_m, None], - pred=tOpO[None, rest_m, None] if const_expr(self.check_hdim_v_oob) else None, + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, ) else: pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q) @@ -452,7 +509,9 @@ def load_K( cute.copy( gmem_tiled_copy, tKgK[None, n, None, block], - tKsK[None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + tKsK[ + None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 + ], pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sK smem tiles since we'll mask out the scores anyway. @@ -483,7 +542,11 @@ def load_V( if const_expr(need_predicates or not is_even_n_smem_v): for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked - if is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.tile_n: + if ( + is_even_n_smem_v + or n < cute.size(tVsV.shape[1]) - 1 + or tVcV[0, n, 0][0] < self.tile_n + ): predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None if const_expr(need_predicates): seqlen_limit = seqlen - block * self.tile_n - tVcV[0][0] @@ -491,11 +554,15 @@ def load_V( predicate = cute.make_fragment_like(tVpV[None, 0, None]) for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): - predicate[i, k] = (tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True) and predicate_n + predicate[i, k] = ( + tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True + ) and predicate_n cute.copy( gmem_tiled_copy, tVgV[None, n, None, block], - tVsV[None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + tVsV[ + None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 + ], pred=predicate, ) else: @@ -508,7 +575,6 @@ def load_V( class FlashAttentionForwardSm80(FlashAttentionForwardBase): - def _get_smem_layout_atom(self): sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim) sK_layout_atom = sQ_layout_atom @@ -564,7 +630,7 @@ def __call__( window_size_left: Optional[Int32] = None, window_size_right: Optional[Int32] = None, learnable_sink: Optional[cute.Tensor] = None, - buffers=None, + aux_tensors=None, ): """Configures and launches the flash attention kernel. @@ -572,7 +638,9 @@ def __call__( (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ assert learnable_sink is None, "Learnable sink is not supported in this kernel" - self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE))) + self._check_type( + *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE)) + ) tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() self.num_mma_threads = tiled_mma_pv.size self.num_producer_threads = self.num_threads @@ -583,9 +651,18 @@ def __call__( self._setup_attributes() SharedStorage = self._get_shared_storage_cls() # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) - mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] - mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) for t in (mQ, mK, mV, mO)] + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) + for t in (mQ, mK, mV, mO) + ] mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0])) # grid_dim: (m_block, num_head, batch_size) grid_dim = ( @@ -605,8 +682,10 @@ def __call__( softmax_scale = Float32(softmax_scale) fastdiv_mods = None - if const_expr(buffers is not None): - seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) seqlen_k_divmod = FastDivmod.create(seqlen_k) @@ -634,7 +713,7 @@ def __call__( tiled_mma_qk, tiled_mma_pv, SharedStorage, - buffers, + aux_tensors, fastdiv_mods, ).launch( grid=grid_dim, @@ -667,7 +746,7 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, SharedStorage: cutlass.Constexpr, - buffers=None, + aux_tensors=None, fastdiv_mods=None, ): # Thread index, block index @@ -675,8 +754,12 @@ def kernel( m_block, num_head, batch_size = cute.arch.block_idx() block_info = BlockInfo( - self.tile_m, self.tile_n, self.is_causal, self.is_local, - window_size_left, window_size_right, + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) seqlen = SeqlenInfoQK(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) @@ -735,10 +818,12 @@ def kernel( # Smem copy atom tiling # /////////////////////////////////////////////////////////////////////////////// smem_copy_atom_QK = cute.make_copy_atom( - warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self.dtype, ) smem_copy_atom_V = cute.make_copy_atom( - warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + self.dtype, ) smem_thr_copy_Q = utils.make_tiled_copy_A(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) smem_thr_copy_K = utils.make_tiled_copy_B(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) @@ -773,29 +858,49 @@ def kernel( tVpV = utils.predicate_k(tVcV, limit=mV.shape[1]) # shape: (atom_v_m * rest_m) - softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) softmax.reset() # group parameters for compute_one_n_block mma_params = SimpleNamespace( - thr_mma_qk=thr_mma_qk, thr_mma_pv=thr_mma_pv, - tSrQ=tSrQ, tSrK=tSrK, tOrVt=tOrVt, acc_O=acc_O, + thr_mma_qk=thr_mma_qk, + thr_mma_pv=thr_mma_pv, + tSrQ=tSrQ, + tSrK=tSrK, + tOrVt=tOrVt, + acc_O=acc_O, ) smem_copy_params = SimpleNamespace( smem_thr_copy_Q=smem_thr_copy_Q, smem_thr_copy_K=smem_thr_copy_K, smem_thr_copy_V=smem_thr_copy_V, - tSsQ=tSsQ, tSsK=tSsK, tOsVt=tOsVt, + tSsQ=tSsQ, + tSsK=tSsK, + tOsVt=tOsVt, + ) + load_K = partial( + self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, seqlen=seqlen.seqlen_k + ) + load_V = partial( + self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, seqlen=seqlen.seqlen_k ) - load_K = partial(self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, - seqlen=seqlen.seqlen_k) - load_V = partial(self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, - seqlen=seqlen.seqlen_k) compute_one_n_block = partial( - self.compute_one_n_block, mma_params=mma_params, smem_copy_params=smem_copy_params, - softmax=softmax, load_K=load_K, load_V=load_V, score_mod=self.score_mod, - batch_idx=batch_size, head_idx=num_head, m_block=m_block, buffers=buffers, + self.compute_one_n_block, + mma_params=mma_params, + smem_copy_params=smem_copy_params, + softmax=softmax, + load_K=load_K, + load_V=load_V, + score_mod=self.score_mod, + batch_idx=batch_size, + head_idx=num_head, + m_block=m_block, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) @@ -826,11 +931,11 @@ def preprocess_Q(): for stage in cutlass.range_constexpr(self.num_stages): if const_expr(not self.Q_in_regs or stage > 0): if stage == 0 or n_block - stage >= 0: - load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) + load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) cute.arch.cp_async_commit_group() if const_expr(stage < self.num_stages - 1): if stage == 0 or n_block - stage >= 0: - load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) + load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) cute.arch.cp_async_commit_group() if const_expr(not self.Q_in_regs): preprocess_Q() @@ -844,20 +949,33 @@ def preprocess_Q(): # We need masking on S for the very last block when K and V has length not multiple of tile_n. # We also need masking on S if it's causal, for the last several blocks. mask = AttentionMask( - self.tile_m, self.tile_n, seqlen.seqlen_q, seqlen.seqlen_k, - window_size_left, window_size_right, + self.tile_m, + self.tile_n, + seqlen.seqlen_q, + seqlen.seqlen_k, + window_size_left, + window_size_right, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) mask_fn = partial( - mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, - mask_causal=self.is_causal, mask_local=self.is_local, + mask.apply_mask, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, ) # First iteration with seqlen masking smem_pipe_read = Int32(0) smem_pipe_write = Int32(self.num_stages - 1) - compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, is_first_n_block=True, - check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True)) + compute_one_n_block( + n_block, + smem_pipe_read, + smem_pipe_write, + is_first_n_block=True, + check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=True), + ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # Next couple of iterations with causal masking @@ -867,13 +985,20 @@ def preprocess_Q(): ) for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 2 - n_tile - compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, check_inf=True, - mask_fn=partial(mask_fn, mask_seqlen=False)) + compute_one_n_block( + n_block, + smem_pipe_read, + smem_pipe_write, + check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # The remaining iterations have no masking for n_tile in cutlass.range(n_block, unroll=1): - compute_one_n_block(n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True) + compute_one_n_block( + n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True + ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # TODO: local @@ -888,8 +1013,19 @@ def preprocess_Q(): # reuse sQ's data iterator sO = cute.make_tensor(sQ.iterator, sO_layout) self.epilogue( - acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, - gmem_tiled_copy_O, None, tiled_mma_pv, tidx, m_block, num_head, batch_size + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + None, + tiled_mma_pv, + tidx, + m_block, + num_head, + batch_size, ) @cute.jit @@ -907,7 +1043,7 @@ def compute_one_n_block( batch_idx: cutlass.Int32, head_idx: cutlass.Int32, m_block: cutlass.Int32, - buffers=None, + aux_tensors=None, fastdiv_mods=None, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, @@ -918,6 +1054,7 @@ def compute_one_n_block( This function provides different variants for processing the first n block versus subsequent blocks. """ + def sync(): cute.arch.cp_async_wait_group(self.num_stages * 2 - 2) cute.arch.barrier() @@ -927,18 +1064,29 @@ def sync(): acc_S.fill(0.0) # wait for smem tile QK before mma calculation for S sync() + # need predicates for the first tile def load_V_next(): if self.num_stages == 1 or n_block - self.num_stages + 1 >= 0: - load_V(n_block - self.num_stages + 1, smem_pipe_write, - need_predicates=is_first_n_block and self.num_stages == 1) + load_V( + n_block - self.num_stages + 1, + smem_pipe_write, + need_predicates=is_first_n_block and self.num_stages == 1, + ) cute.arch.cp_async_commit_group() + load_V_next() sm80_utils.gemm( - mma_params.thr_mma_qk, acc_S, mma_params.tSrQ, mma_params.tSrK, + mma_params.thr_mma_qk, + acc_S, + mma_params.tSrQ, + mma_params.tSrK, smem_copy_params.tSsQ, - smem_copy_params.tSsK[None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0], - smem_copy_params.smem_thr_copy_Q, smem_copy_params.smem_thr_copy_K, + smem_copy_params.tSsK[ + None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 + ], + smem_copy_params.smem_thr_copy_Q, + smem_copy_params.smem_thr_copy_K, # hook_fn=load_V_next, A_in_regs=self.Q_in_regs, ) @@ -951,15 +1099,17 @@ def load_V_next(): acc_S, n_block, softmax_scale=softmax.softmax_scale, - buffers=buffers, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) smem_pipe_write = self.advance_pipeline(smem_pipe_write) + def load_K_next(): if n_block - self.num_stages >= 0: load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False) cute.arch.cp_async_commit_group() + # wait for smem tile V for O if const_expr(self.num_stages == 1): sync() @@ -975,8 +1125,13 @@ def load_K_next(): sync() load_K_next() sm80_utils.gemm_rs( - mma_params.thr_mma_pv, mma_params.acc_O, tOrP, mma_params.tOrVt, - smem_copy_params.tOsVt[None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0], + mma_params.thr_mma_pv, + mma_params.acc_O, + tOrP, + mma_params.tOrVt, + smem_copy_params.tOsVt[ + None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 + ], smem_copy_params.smem_thr_copy_V, # hook_fn=load_K_next, ) @@ -985,7 +1140,6 @@ def load_K_next(): class FlashAttentionForwardSm90(FlashAttentionForwardBase): - arch = 90 def __init__( @@ -998,21 +1152,18 @@ def __init__( super().__init__(*args, **kwargs) self.intra_wg_overlap = intra_wg_overlap self.mma_pv_is_rs = mma_pv_is_rs - def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim - ), - self.dtype + sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim), + self.dtype, ) sK_layout_atom = sQ_layout_atom sV_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv ), - self.dtype + self.dtype, ) sO_layout_atom = sV_layout_atom if not self.mma_pv_is_rs: @@ -1020,7 +1171,7 @@ def _get_smem_layout_atom(self): sm90_utils_basic.get_smem_layout_atom( LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n ), - self.dtype + self.dtype, ) else: sP_layout_atom = None @@ -1044,7 +1195,9 @@ def _get_tiled_mma(self): Float32, atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 tiler_mn=(64, self.tile_hdimv), - a_source=warpgroup.OperandSource.RMEM if self.mma_pv_is_rs else warpgroup.OperandSource.SMEM, + a_source=warpgroup.OperandSource.RMEM + if self.mma_pv_is_rs + else warpgroup.OperandSource.SMEM, ) tiled_mma_pv_rs = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -1054,7 +1207,7 @@ def _get_tiled_mma(self): Float32, atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 tiler_mn=(64, self.tile_hdimv), - a_source=warpgroup.OperandSource.RMEM + a_source=warpgroup.OperandSource.RMEM, ) return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs @@ -1066,8 +1219,8 @@ def _get_shared_storage_cls(self): sQ_struct, sK_struct, sV_struct = [ cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], alignment] for layout, alignment in zip( - (self.sQ_layout, self.sK_layout, self.sV_layout), - (sQ_alignment, sK_alignment, sV_alignment) + (self.sQ_layout, self.sK_layout, self.sV_layout), + (sQ_alignment, sK_alignment, sV_alignment), ) ] cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) @@ -1122,7 +1275,7 @@ def __call__( full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) - buffers: Optional[list[cute.Tensor]] = None, + aux_tensors: Optional[list] = None, ): """Configures and launches the flash attention kernel. @@ -1131,14 +1284,22 @@ def __call__( """ self._check_type( - *(t.element_type if t is not None else None - for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) + ) ) # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) - mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)] KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] @@ -1164,10 +1325,20 @@ def __call__( ) # self.num_mma_regs = 232 # self.num_producer_regs = 40 - self.use_block_sparsity = const_expr(mask_block_cnt is not None and full_block_cnt is not None) - self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) if const_expr(self.intra_wg_overlap) else (self.num_mma_warp_groups == 2) - self.use_tma_Q = self.arch >= 90 and not (self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0) - self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa + self.use_block_sparsity = const_expr( + mask_block_cnt is not None and full_block_cnt is not None + ) + self.use_scheduler_barrier = ( + (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) + if const_expr(self.intra_wg_overlap) + else (self.num_mma_warp_groups == 2) + ) + self.use_tma_Q = self.arch >= 90 and not ( + self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0 + ) + self.use_tma_O = ( + self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa + ) # TODO: rescale_O_before_gemm self._setup_attributes() # TODO: we prob don't need most of what's in _setup_attributes @@ -1189,16 +1360,50 @@ def __call__( SharedStorage = self._get_shared_storage_cls() if const_expr(self.pack_gqa): - shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) - stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) - mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) - shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) - stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) - mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) + shape_Q_packed = ( + (self.qhead_per_kvhead, mQ.shape[0]), + mQ.shape[1], + mK.shape[2], + *mQ.shape[3:], + ) + stride_Q_packed = ( + (mQ.stride[2], mQ.stride[0]), + mQ.stride[1], + mQ.stride[2] * self.qhead_per_kvhead, + *mQ.stride[3:], + ) + mQ = cute.make_tensor( + mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) + ) + shape_O_packed = ( + (self.qhead_per_kvhead, mO.shape[0]), + mK.shape[1], + mK.shape[2], + *mO.shape[3:], + ) + stride_O_packed = ( + (mO.stride[2], mO.stride[0]), + mO.stride[1], + mO.stride[2] * self.qhead_per_kvhead, + *mO.stride[3:], + ) + mO = cute.make_tensor( + mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) + ) if const_expr(mLSE is not None): - shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) - stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) - mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) + shape_LSE_packed = ( + (self.qhead_per_kvhead, mLSE.shape[0]), + mK.shape[2], + *mLSE.shape[2:], + ) + stride_LSE_packed = ( + (mLSE.stride[1], mLSE.stride[0]), + mLSE.stride[1] * self.qhead_per_kvhead, + *mLSE.stride[2:], + ) + mLSE = cute.make_tensor( + mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) + ) # TMA gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() @@ -1215,39 +1420,53 @@ def __call__( tma_atom_Q, tma_tensor_Q = None, None if const_expr(self.use_tma_Q): tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_Q, mQ, self.sQ_layout, (self.tile_m, self.tile_hdim), # No mcast + gmem_tiled_copy_Q, + mQ, + self.sQ_layout, + (self.tile_m, self.tile_hdim), # No mcast ) tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mK, cute.select(self.sK_layout, mode=[0, 1]), (self.tile_n, self.tile_hdim), - 1 # No mcast for now + 1, # No mcast for now ) tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mV, cute.select(self.sV_layout, mode=[0, 1]), (self.tile_n, self.tile_hdimv), - 1 # No mcast for now + 1, # No mcast for now ) tma_atom_O, tma_tensor_O = None, None if const_expr(self.use_tma_O): tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_O, mO, self.sO_layout, (self.tile_m, self.tile_hdimv), # No mcast + gmem_tiled_copy_O, + mO, + self.sO_layout, + (self.tile_m, self.tile_hdimv), # No mcast ) if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler else: - TileScheduler = SingleTileScheduler if const_expr(not self.is_causal or self.is_local) else SingleTileLPTScheduler + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_causal or self.is_local) + else SingleTileLPTScheduler + ) tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m), cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), cute.size(mK.shape[0]), mQ.shape[1], mV.shape[1], - total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), tile_shape_mn=(self.tile_m, self.tile_n), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, @@ -1274,8 +1493,10 @@ def __call__( window_size_right = Int32(window_size_right) fastdiv_mods = None - if const_expr(buffers is not None): - seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) seqlen_k_divmod = FastDivmod.create(seqlen_k) @@ -1319,7 +1540,7 @@ def __call__( tile_sched_params, TileScheduler, SharedStorage, - buffers, + aux_tensors, fastdiv_mods, ).launch( grid=grid_dim, @@ -1369,7 +1590,7 @@ def kernel( tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], SharedStorage: cutlass.Constexpr[Callable], - buffers=Optional[list[cute.Tensor]], + aux_tensors=Optional[list[cute.Tensor]], fastdiv_mods=None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -1392,7 +1613,9 @@ def kernel( cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads) # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync - pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) + pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread + ) pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group ) @@ -1421,7 +1644,9 @@ def kernel( if const_expr(not self.Q_in_regs): sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) else: - sV = storage.sQ.get_tensor(sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type) + sV = storage.sQ.get_tensor( + sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type + ) # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma sVt = utils.transpose_view(sV) sP = None @@ -1431,19 +1656,29 @@ def kernel( sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) block_info = BlockInfo( - self.tile_m, self.tile_n, self.is_causal, self.is_local, - window_size_left, window_size_right, + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( - SeqlenInfoQK, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + SeqlenInfoQK, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], seqlen_k_static=mK.shape[0], - mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, - mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, ) AttentionMaskCls = partial( - AttentionMask, self.tile_m, self.tile_n, - window_size_left=window_size_left, window_size_right=window_size_right, + AttentionMask, + self.tile_m, + self.tile_n, + window_size_left=window_size_left, + window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) @@ -1509,7 +1744,7 @@ def kernel( full_block_idx, mask_block_cnt, mask_block_idx, - buffers, + aux_tensors, fastdiv_mods, ) @@ -1545,11 +1780,13 @@ def load( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - # if work_tile.is_valid_tile: + # if work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] - head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) @@ -1561,12 +1798,15 @@ def load( ) # TODO: mcast # TODO check warp_idx if we have 128 producer threads - load_K, _, _ = copy_utils.tma_get_copy_fn(tma_atom_K, 0, cute.make_layout(1), gK, sK) + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), gK, sK + ) load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k) - load_V, _, _ = copy_utils.tma_get_copy_fn(tma_atom_V, 0, cute.make_layout(1), gV, sV) + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, 0, cute.make_layout(1), gV, sV + ) load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) - if const_expr(not self.use_block_sparsity): n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # if cute.arch.thread_idx()[0] == 0: @@ -1575,7 +1815,9 @@ def load( n_block = n_block_max - 1 pipeline_k.producer_acquire( kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, ) if const_expr(self.use_tma_Q): load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) @@ -1614,22 +1856,26 @@ def load( curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] - + if const_expr(not self.intra_wg_overlap): if curr_mask_block_cnt > 0: # First mask block - load with Q n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1] pipeline_k.producer_acquire( kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, ) if const_expr(self.use_tma_Q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_Q( + tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state) + ) load_K(src_idx=n_block_mask, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_mask, producer_state=kv_producer_state) kv_producer_state.advance() - + # Remaining mask blocks for i in cutlass.range(1, curr_mask_block_cnt): n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] @@ -1638,17 +1884,23 @@ def load( pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_mask, producer_state=kv_producer_state) kv_producer_state.advance() - + if curr_full_block_cnt > 0: n_block_full = curr_full_block_idx[curr_full_block_cnt - 1] - if curr_mask_block_cnt == 0: + if curr_mask_block_cnt == 0: # must load Q if not loaded in mask loop pipeline_k.producer_acquire( kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, ) if const_expr(self.use_tma_Q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_Q( + tma_bar_ptr=pipeline_k.producer_get_barrier( + kv_producer_state + ) + ) load_K(src_idx=n_block_full, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_full, producer_state=kv_producer_state) @@ -1666,28 +1918,32 @@ def load( pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_full, producer_state=kv_producer_state) kv_producer_state.advance() - + else: # ========================================== # Overlap path # ========================================== - + # Load Q with the first K block (whether mask or full) n_block_first = -1 if curr_mask_block_cnt > 0: n_block_first = curr_mask_block_idx[curr_mask_block_cnt - 1] elif curr_full_block_cnt > 0: n_block_first = curr_full_block_idx[curr_full_block_cnt - 1] - + if n_block_first >= 0: pipeline_k.producer_acquire( kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, ) if const_expr(self.use_tma_Q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_Q( + tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state) + ) load_K(src_idx=n_block_first, producer_state=kv_producer_state) - + if curr_mask_block_cnt > 0: # Staggered loading for remaining mask blocks for i in cutlass.range(1, curr_mask_block_cnt): @@ -1698,8 +1954,10 @@ def load( pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block_mask, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state_prev) - load_V(src_idx=n_block_mask_prev, producer_state=kv_producer_state_prev) - + load_V( + src_idx=n_block_mask_prev, producer_state=kv_producer_state_prev + ) + # Handle transition from mask to full blocks if curr_full_block_cnt > 0: # Load first full block K, last mask block V @@ -1710,14 +1968,16 @@ def load( pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block_full, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state_prev) - load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev) + load_V( + src_idx=n_block_mask_last, producer_state=kv_producer_state_prev + ) else: # No full blocks, just load last mask block V n_block_mask_last = curr_mask_block_idx[0] pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state) kv_producer_state.advance() - + if curr_full_block_cnt > 0: # Staggered loading for remaining full blocks ( for j in cutlass.range(1, curr_full_block_cnt): @@ -1728,8 +1988,10 @@ def load( pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block_full, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state_prev) - load_V(src_idx=n_block_full_prev, producer_state=kv_producer_state_prev) - + load_V( + src_idx=n_block_full_prev, producer_state=kv_producer_state_prev + ) + # Load last full block V n_block_full_last = curr_full_block_idx[0] pipeline_v.producer_acquire(kv_producer_state) @@ -1775,7 +2037,7 @@ def mma( full_block_idx: Optional[cute.Tensor], mask_block_cnt: Optional[cute.Tensor], mask_block_idx: Optional[cute.Tensor], - buffers: Optional[list[cute.Tensor]], + aux_tensors: Optional[list], fastdiv_mods=None, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) @@ -1820,11 +2082,15 @@ def mma( mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt) mma_one_n_block_all = partial( - self.mma_one_n_block_intrawg_overlap if const_expr(self.intra_wg_overlap) else self.mma_one_n_block, + self.mma_one_n_block_intrawg_overlap + if const_expr(self.intra_wg_overlap) + else self.mma_one_n_block, mma_qk_fn=mma_qk_fn, tiled_mma_pv_rs=tiled_mma_pv_rs, - pipeline_k=pipeline_k, pipeline_v=pipeline_v, - acc_O=acc_O, tOrP=tOrP, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + acc_O=acc_O, + tOrP=tOrP, smem_copy_params=smem_copy_params, check_inf=True, ) @@ -1836,8 +2102,12 @@ def mma( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() - softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) - + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) + process_first_half_block = partial( self.first_half_block_overlap, mma_qk_fn=mma_qk_fn, @@ -1852,7 +2122,7 @@ def mma( mma_pv_fn=mma_pv_fn, ) while work_tile.is_valid_tile: - # if work_tile.is_valid_tile: + # if work_tile.is_valid_tile: # shape: (atom_v_m * rest_m) m_block, head_idx, batch_idx = work_tile.tile_idx @@ -1866,18 +2136,18 @@ def mma( thr_mma=thr_mma_qk, mask_causal=self.is_causal, mask_local=self.is_local, - buffers=buffers, + aux_tensors=aux_tensors, ) score_mod_fn = None if const_expr(self.score_mod is not None): score_mod_fn = partial( self.apply_score_mod, - thr_mma_qk=thr_mma_qk, - batch_idx=batch_idx, - head_idx=head_idx, - m_block=m_block, + thr_mma_qk, + batch_idx, + head_idx, + m_block, softmax_scale=softmax_scale, - buffers=buffers, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) mma_one_n_block = partial( @@ -1887,7 +2157,9 @@ def mma( ) # Load Q if not TMA_Q if const_expr(not self.use_tma_Q): - pack_gqa = PackGQA(self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead) + pack_gqa = PackGQA( + self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead + ) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) @@ -1906,10 +2178,9 @@ def mma( # We also need masking on S if it's causal, for the last several blocks. # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True O_should_accumulate = False - - + # ========================================== - # MAINLOOP + # MAINLOOP # ========================================== if const_expr(not self.use_block_sparsity): # ========================================== @@ -1921,6 +2192,7 @@ def mma( n_block=n_block_max - 1, kv_consumer_state=kv_consumer_state, mask_fn=mask_fn, + score_mod_fn=score_mod_fn, is_first_block=True, ) # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter @@ -1943,7 +2215,9 @@ def mma( seqlen, m_block, n_block_min ) # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) - for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + for n_tile in cutlass.range( + n_block_max - n_block_min_causal_local_mask, unroll=1 + ): kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=n_block_max - 1 - n_tile, @@ -1984,7 +2258,7 @@ def mma( O_should_accumulate = True else: self.warp_scheduler_barrier_arrive() - + else: # ========================================== # Block sparsity @@ -2069,6 +2343,7 @@ def mma( n_block=mask_n_block, kv_consumer_state=kv_consumer_state, mask_fn=partial(mask_fn, mask_mod=self.mask_mod), + score_mod_fn=score_mod_fn, is_first_block=True, ) @@ -2091,6 +2366,7 @@ def mma( n_block=full_n_block, kv_consumer_state=kv_consumer_state, mask_fn=partial(mask_fn, mask_mod=None), + score_mod_fn=score_mod_fn, is_first_block=True, ) @@ -2124,8 +2400,7 @@ def mma( if curr_mask_block_cnt + curr_full_block_cnt == 0: softmax.reset() - acc_O.fill(0.0) - + acc_O.fill(0.0) sink_val = None if const_expr(learnable_sink is not None): @@ -2148,8 +2423,19 @@ def mma( # Epilogue # /////////////////////////////////////////////////////////////////////////////// self.epilogue( - acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, - gmem_tiled_copy_O, tma_atom_O, tiled_mma_pv, tidx, m_block, head_idx, batch_idx, + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + tma_atom_O, + tiled_mma_pv, + tidx, + m_block, + head_idx, + batch_idx, ) tile_scheduler.advance_to_next_work() @@ -2177,7 +2463,7 @@ def first_half_block_overlap( # Apply score modification if present if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S=acc_S, n_block=n_block) + score_mod_fn(acc_S, n_block=n_block) # Apply mask; mask_seqlen always True for first block # Caveat: if full block further right than mask block, seqlen masking is redundant; @@ -2203,7 +2489,7 @@ def first_half_block_overlap( cute.arch.sync_warp() return kv_consumer_state - + @cute.jit def last_half_block_overlap( self, @@ -2213,14 +2499,14 @@ def last_half_block_overlap( zero_init: bool, ): """Processes the final PV GEMM when using intra-warpgroup-overlap""" - + pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) pipeline_v.consumer_release(kv_consumer_state) - + # Advance state for next iteration kv_consumer_state.advance() - + return kv_consumer_state @cute.jit @@ -2248,17 +2534,19 @@ def mma_one_n_block( self.warp_scheduler_barrier_arrive() warpgroup.wait_group(0) pipeline_k.consumer_release(smem_pipe_read) - + # handle score mods and masking if const_expr(score_mod_fn is not None): score_mod_fn(acc_S, n_block=n_block) if const_expr(mask_fn is not None): - mask_fn(acc_S, n_block=n_block) - + mask_fn(acc_S=acc_S, n_block=n_block) + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tOrP_cur = tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) # tOrP.store(tOrP_acc.load().to(self.dtype)) # the "to(self.dtype)" conversion fails to vectorize for block sizes other # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of @@ -2310,19 +2598,21 @@ def mma_one_n_block_intrawg_overlap( self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) pipeline_k.consumer_release(smem_pipe_read) - + # handle score mods and masking if const_expr(score_mod_fn is not None): score_mod_fn(acc_S, n_block=n_block) - if const_expr(mask_fn is not None): - mask_fn(acc_S, n_block=n_block) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - + row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) warpgroup.wait_group(0) pipeline_v.consumer_release(smem_pipe_read_v) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tOrP_cur = tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) # the "to(self.dtype)" conversion fails to vectorize for block sizes other # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of @@ -2358,7 +2648,7 @@ def apply_score_mod( acc_S, n_block, softmax_scale, - buffers=Optional[list[cute.Tensor]], + aux_tensors: Optional[list] = None, fastdiv_mods=None, ): # Prepare index tensor @@ -2375,7 +2665,7 @@ def apply_score_mod( softmax_scale, self.vec_size, self.qk_acc_dtype, - buffers, + aux_tensors, fastdiv_mods, constant_q_idx=None, qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, @@ -2384,8 +2674,10 @@ def apply_score_mod( def warp_scheduler_barrier_sync(self): if const_expr(self.use_scheduler_barrier): cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) - 1 + utils.canonical_warp_group_idx(sync=False), - number_of_threads=2 * self.num_threads_per_warp_group + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + - 1 + + utils.canonical_warp_group_idx(sync=False), + number_of_threads=2 * self.num_threads_per_warp_group, ) def warp_scheduler_barrier_arrive(self): diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 7bf1480bbae..83755896d51 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -37,7 +37,14 @@ from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils from flash_attn.cute.fast_math import FastDivmod -from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + StaticPersistentTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) # class NamedBarrierFwd(enum.IntEnum): @@ -50,7 +57,6 @@ class FlashAttentionForwardSm100: - arch = 100 def __init__( @@ -66,7 +72,7 @@ def __init__( n_block_size: int = 128, is_persistent: bool = True, score_mod: cutlass.Constexpr | None = None, - has_buffers: cutlass.Constexpr = False, + has_aux_tensors: cutlass.Constexpr = False, ): # self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -96,9 +102,11 @@ def __init__( self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = pack_gqa if pack_gqa: - assert m_block_size % self.qhead_per_kvhead == 0, "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" + assert m_block_size % self.qhead_per_kvhead == 0, ( + "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" + ) self.score_mod = score_mod - if cutlass.const_expr(has_buffers): + if cutlass.const_expr(has_aux_tensors): self.vec_size: cutlass.Constexpr = 1 else: self.vec_size: cutlass.Constexpr = 2 @@ -133,11 +141,16 @@ def __init__( ) self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 - self.tmem_o_offset = [self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded for i in range(self.q_stage)] # e.g., 256, 384 + self.tmem_o_offset = [ + self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded + for i in range(self.q_stage) + ] # e.g., 256, 384 self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS self.tmem_s_to_p_offset = self.n_block_size // 2 - self.tmem_p_offset = [self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2)] # 0, 128 + self.tmem_p_offset = [ + self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2) + ] # 0, 128 # vec buffer for row_max & row_sum self.tmem_vec_offset = self.tmem_s_offset @@ -182,8 +195,14 @@ def _setup_attributes(self): # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be # 128 * 160, so that indexing the 0th and 2nd stages will get the right address, # but for the 1st stage we need to add or subtract (depending on phase) 128 x 64. - self.uneven_kv_smem = self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3 - self.uneven_kv_smem_offset = self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 if self.uneven_kv_smem else 0 + self.uneven_kv_smem = ( + self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3 + ) + self.uneven_kv_smem_offset = ( + self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 + if self.uneven_kv_smem + else 0 + ) assert self.uneven_kv_smem_offset % 1024 == 0 @cute.jit @@ -204,7 +223,9 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, - buffers = None # Not typing for now since conversion behaves a lil funny + aux_tensors: Optional[ + list + ] = None, # Not typing for now since conversion behaves a lil funny ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -226,8 +247,14 @@ def __call__( self.v_dtype = mV.element_type self.o_dtype = mO.element_type # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) - mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mO = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) @@ -240,7 +267,11 @@ def __call__( for t in (mK, mV) ] LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if const_expr(mLSE is not None) else None + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if const_expr(mLSE is not None) + else None + ) # (s, d, h, b) -> (d, s, h, b) V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) @@ -266,7 +297,9 @@ def __call__( self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None # This can be tuned self.e2e_freq = 16 - if const_expr(self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa): + if const_expr( + self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa + ): self.e2e_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10 cta_group = tcgen05.CtaGroup.ONE @@ -300,39 +333,108 @@ def __call__( self.epi_tile = self.mma_tiler_pv[:2] sQ_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, self.q_stage, + tiled_mma_qk, + self.mma_tiler_qk, + self.q_dtype, + self.q_stage, ) sK_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage, + tiled_mma_qk, + self.mma_tiler_qk, + self.k_dtype, + self.kv_stage, ) tP_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.acc_stage, + tiled_mma_pv, + self.mma_tiler_pv, + self.q_dtype, + self.acc_stage, ) sV_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage, + tiled_mma_pv, + self.mma_tiler_pv, + self.v_dtype, + self.kv_stage, ) sO_layout = sm100_utils_basic.make_smem_layout_epi( - self.o_dtype, self.o_layout, self.epi_tile, self.epi_stage, + self.o_dtype, + self.o_layout, + self.epi_tile, + self.epi_stage, ) if const_expr(not self.same_hdim_kv_padded): # sK and sV are using the same physical smem so we need to adjust the stride so that they line up - stride_sK = const_expr(max(sK_layout.outer.stride[-1], 0)) # take max to turn tuple to Int32 + stride_sK = const_expr( + max(sK_layout.outer.stride[-1], 0) + ) # take max to turn tuple to Int32 stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0)) - stage_stride = const_expr(max(stride_sK, stride_sV) if not self.uneven_kv_smem else (stride_sK + stride_sV) // 2) - sK_layout = cute.make_composed_layout(sK_layout.inner, 0, cute.make_layout((*sK_layout.outer.shape[:-1], self.kv_stage), stride=(*sK_layout.outer.stride[:-1], stage_stride))) - sV_layout = cute.make_composed_layout(sV_layout.inner, 0, cute.make_layout((*sV_layout.outer.shape[:-1], self.kv_stage), stride=(*sV_layout.outer.stride[:-1], stage_stride))) + stage_stride = const_expr( + max(stride_sK, stride_sV) + if not self.uneven_kv_smem + else (stride_sK + stride_sV) // 2 + ) + sK_layout = cute.make_composed_layout( + sK_layout.inner, + 0, + cute.make_layout( + (*sK_layout.outer.shape[:-1], self.kv_stage), + stride=(*sK_layout.outer.stride[:-1], stage_stride), + ), + ) + sV_layout = cute.make_composed_layout( + sV_layout.inner, + 0, + cute.make_layout( + (*sV_layout.outer.shape[:-1], self.kv_stage), + stride=(*sV_layout.outer.stride[:-1], stage_stride), + ), + ) if const_expr(self.pack_gqa): - shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) - stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) - mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) - shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) - stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) - mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) + shape_Q_packed = ( + (self.qhead_per_kvhead, mQ.shape[0]), + mQ.shape[1], + mK.shape[2], + *mQ.shape[3:], + ) + stride_Q_packed = ( + (mQ.stride[2], mQ.stride[0]), + mQ.stride[1], + mQ.stride[2] * self.qhead_per_kvhead, + *mQ.stride[3:], + ) + mQ = cute.make_tensor( + mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) + ) + shape_O_packed = ( + (self.qhead_per_kvhead, mO.shape[0]), + mK.shape[1], + mK.shape[2], + *mO.shape[3:], + ) + stride_O_packed = ( + (mO.stride[2], mO.stride[0]), + mO.stride[1], + mO.stride[2] * self.qhead_per_kvhead, + *mO.stride[3:], + ) + mO = cute.make_tensor( + mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) + ) if const_expr(mLSE is not None): - shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) - stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) - mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) + shape_LSE_packed = ( + (self.qhead_per_kvhead, mLSE.shape[0]), + mK.shape[2], + *mLSE.shape[2:], + ) + stride_LSE_packed = ( + (mLSE.stride[1], mLSE.stride[0]), + mLSE.stride[1] * self.qhead_per_kvhead, + *mLSE.stride[2:], + ) + mLSE = cute.make_tensor( + mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) + ) # TMA load for Q tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) @@ -386,11 +488,14 @@ def __call__( universal_copy_bits = 128 async_copy_elems = universal_copy_bits // self.o_dtype.width atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.o_dtype, num_bits_per_copy=universal_copy_bits, + cute.nvgpu.CopyUniversalOp(), + self.o_dtype, + num_bits_per_copy=universal_copy_bits, ) tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems tO_layout = cute.make_ordered_layout( - (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1), order=(1, 0), + (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1), + order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we store O assert self.m_block_size % tO_layout.shape[0] == 0 @@ -412,15 +517,25 @@ def __call__( if const_expr(self.is_causal or self.is_local): TileScheduler = SingleTileLPTScheduler else: - TileScheduler = SingleTileScheduler if const_expr(not self.is_persistent) else StaticPersistentTileScheduler + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_persistent) + else StaticPersistentTileScheduler + ) tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), - cute.size(mK.shape[0]) if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], mQ.shape[1], mV.shape[0], # Note that this is different from Sm90 since we transpose mV in Sm100 - total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), tile_shape_mn=self.cta_tiler[:2], mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, @@ -493,8 +608,10 @@ class SharedStorage: window_size_right = Int32(window_size_right) fastdiv_mods = None - if cutlass.const_expr(buffers is not None): - seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) + if cutlass.const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) seqlen_k_divmod = FastDivmod.create(seqlen_k) @@ -530,7 +647,7 @@ class SharedStorage: tiled_mma_qk, tiled_mma_pv, tile_sched_params, - buffers, + aux_tensors, fastdiv_mods, ).launch( grid=grid_dim, @@ -573,8 +690,8 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, tile_sched_params: ParamsBase, - buffers = None, - fastdiv_mods = (None, None), + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), ): """The device kernel implementation of the Fused Multi-Head Attention. @@ -609,28 +726,55 @@ def kernel( if warp_idx == 1: # Init "full" barrier with number of producers, "empty" barrier with number of consumers for i in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id])) - cute.arch.mbarrier_init(mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id])) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id]) + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id]) + ) if warp_idx == 2: for i in cutlass.range_constexpr(2): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4) - cute.arch.mbarrier_init(mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4 + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4 + ) if warp_idx == 3: if const_expr(self.s0_s1_barrier): for i in cutlass.range_constexpr(8): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE + ) if warp_idx == 4: for i in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_corr_epi_full_offset + i, cute.arch.WARP_SIZE * len(self.correction_warp_ids)) - cute.arch.mbarrier_init(mbar_ptr + self.mbar_corr_epi_empty_offset + i, cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_corr_epi_full_offset + i, + cute.arch.WARP_SIZE * len(self.correction_warp_ids), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_corr_epi_empty_offset + i, + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids), + ) if warp_idx == 5: for i in cutlass.range_constexpr(2): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, cute.arch.WARP_SIZE * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids))) - cute.arch.mbarrier_init(mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])) - cute.arch.mbarrier_init(mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, + cute.arch.WARP_SIZE + * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids)), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id]) + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id]) + ) if warp_idx == 6: for i in cutlass.range_constexpr(2): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_2_offset + i, cute.arch.WARP_SIZE * len(self.softmax0_warp_ids)) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_P_full_2_offset + i, + cute.arch.WARP_SIZE * len(self.softmax0_warp_ids), + ) if warp_idx == 7: cute.arch.mbarrier_init( mbar_ptr + self.mbar_tmem_dealloc_offset, @@ -668,43 +812,60 @@ def kernel( tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always # request 512 columns of tmem, so we know that it starts at 0. - tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, - assumed_align=16) + tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) pv_acc_shape = thr_mma_pv.partition_shape_C(self.mma_tiler_pv[:2]) tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) - tStSs = tuple(cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) - for stage in range(2)) - tOtOs = tuple(cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout) - for stage in range(self.q_stage)) + tStSs = tuple( + cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) + for stage in range(2) + ) + tOtOs = tuple( + cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout) + for stage in range(self.q_stage) + ) tP = cute.make_tensor(tStS.iterator, tP_layout.outer) tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] - tOrPs = [cute.make_tensor( - tOrP.iterator - + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage], - tOrP.layout, - ) for stage in range(2)] + tOrPs = [ + cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage], + tOrP.layout, + ) + for stage in range(2) + ] block_info = BlockInfo( # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) - self.cta_tiler[0], self.cta_tiler[1], self.is_causal, self.is_local, - window_size_left, window_size_right, + self.cta_tiler[0], + self.cta_tiler[1], + self.is_causal, + self.is_local, + window_size_left, + window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( SeqlenInfoQK, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], - seqlen_k_static=mK.shape[0] if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], - mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, - mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, + seqlen_k_static=mK.shape[0] + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, ) AttentionMaskCls = partial( - AttentionMask, self.m_block_size, self.n_block_size, - window_size_left=window_size_left, window_size_right=window_size_right, + AttentionMask, + self.m_block_size, + self.n_block_size, + window_size_left=window_size_left, + window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) @@ -745,7 +906,7 @@ def kernel( # MMA # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.mma_warp_id: - # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: + # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) # Alloc tmem buffer tmem_alloc_cols = Int32(self.tmem_alloc_cols) @@ -787,7 +948,9 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - self.epilogue_s2g(mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls, TileSchedulerCls) + self.epilogue_s2g( + mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls, TileSchedulerCls + ) # /////////////////////////////////////////////////////////////////////////////// # Softmax @@ -808,7 +971,7 @@ def kernel( SeqlenInfoCls=SeqlenInfoCls, AttentionMaskCls=AttentionMaskCls, TileSchedulerCls=TileSchedulerCls, - buffers=buffers, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) @@ -817,8 +980,9 @@ def kernel( softmax_loop( stage=stage, tStSi=cute.make_tensor( - tStS.iterator + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), - tStS.layout + tStS.iterator + + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), + tStS.layout, ), ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) @@ -880,7 +1044,6 @@ def load( SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): - q_producer_phase = Int32(1) kv_producer_state = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.kv_stage @@ -893,7 +1056,9 @@ def load( mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) - head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) if const_expr(mPageTable is None): if const_expr(not seqlen.has_cu_seqlens_k): mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] @@ -905,8 +1070,12 @@ def load( else: # Need to keep batch coord None since we'll index into it with page idx mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)] - gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None)) - gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)) + gK = cute.local_tile( + mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None) + ) + gV = cute.local_tile( + mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None) + ) tSgQ = thr_mma_qk.partition_A(gQ) tSgK = thr_mma_qk.partition_B(gK) tOgV = thr_mma_pv.partition_B(gV) @@ -929,26 +1098,40 @@ def load( ) load_Q = partial( - self.load_Q, load_Q_fn, - mbar_ptr + self.mbar_load_q_full_offset, mbar_ptr + self.mbar_load_q_empty_offset, + self.load_Q, + load_Q_fn, + mbar_ptr + self.mbar_load_q_full_offset, + mbar_ptr + self.mbar_load_q_empty_offset, phase=q_producer_phase, ) # We have to use mbarrier directly in the load for KV instead of replying on # pipeline_kv, because we could have different number of TMA bytes for K and V load_K = partial( - self.load_KV, tma_atom_K, tKgK, tKsK, - mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, + self.load_KV, + tma_atom_K, + tKgK, + tKsK, + mbar_ptr + self.mbar_load_kv_full_offset, + mbar_ptr + self.mbar_load_kv_empty_offset, K_or_V="K", ) load_V = partial( - self.load_KV, tma_atom_V, tVgV, tVsV, - mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, + self.load_KV, + tma_atom_V, + tVgV, + tVsV, + mbar_ptr + self.mbar_load_kv_full_offset, + mbar_ptr + self.mbar_load_kv_empty_offset, K_or_V="V", ) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 - page_idx = mPageTable[batch_idx, n_block_max - 1] if const_expr(mPageTable is not None) else None + page_idx = ( + mPageTable[batch_idx, n_block_max - 1] + if const_expr(mPageTable is not None) + else None + ) load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 kv_producer_state.advance() if const_expr(self.q_stage == 2): @@ -958,7 +1141,9 @@ def load( kv_producer_state.advance() for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): n_block = n_block_max - 2 - i - page_idx = mPageTable[batch_idx, n_block] if const_expr(mPageTable is not None) else None + page_idx = ( + mPageTable[batch_idx, n_block] if const_expr(mPageTable is not None) else None + ) # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki kv_producer_state.advance() @@ -1005,7 +1190,7 @@ def mma( self.tmem_s_offset[stage], tSrQs[stage], sA=sQ[None, None, None, stage], - zero_init=True + zero_init=True, ) for stage in range(2) ] @@ -1036,7 +1221,9 @@ def mma( for stage in cutlass.range_constexpr(self.q_stage): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) # 1. wait for Q0 / Q1 - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase + ) # 2. wait for K0 if const_expr(stage == 0): pipeline_kv.consumer_wait(mma_kv_consumer_state) @@ -1049,7 +1236,9 @@ def mma( # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) sK_cur = sK[None, None, None, mma_kv_consumer_state.index] if const_expr(self.uneven_kv_smem): - sK_cur = self.offset_kv_smem(sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase) + sK_cur = self.offset_kv_smem( + sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase + ) gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) # 4. release S0 / S1 with cute.arch.elect_one(): @@ -1078,7 +1267,7 @@ def mma( # the last iteration of the previous work tile has finished. cute.arch.mbarrier_wait( mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, - P_full_O_rescaled_phase + P_full_O_rescaled_phase, ) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) @@ -1091,7 +1280,7 @@ def mma( sB=sV_cur, zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, - mbar_phase=P_full_O_rescaled_phase + mbar_phase=P_full_O_rescaled_phase, ) # 4. release accumulated O0_partial / O1_partial # Don't need to signal O_full to the correction warps anymore since the @@ -1145,8 +1334,7 @@ def mma( for stage in cutlass.range_constexpr(2): # 2. acquire corrected Oi_partial and Pi cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, - P_full_O_rescaled_phase + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase ) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) @@ -1159,7 +1347,7 @@ def mma( sB=sV_cur, zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, - mbar_phase=P_full_O_rescaled_phase + mbar_phase=P_full_O_rescaled_phase, ) # 4. release accumulated O0_partial # We do need O_full here since for the last tile, by the time the softmax warp @@ -1197,8 +1385,8 @@ def softmax_loop( SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, - buffers = None, - fastdiv_mods = (None, None) + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1214,8 +1402,7 @@ def softmax_loop( tidx = cute.arch.thread_idx()[0] % ( cute.arch.WARP_SIZE # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids) - * (len(self.softmax0_warp_ids) - ) + * (len(self.softmax0_warp_ids)) ) tStScale = cute.composition(tStSi, cute.make_layout((self.m_block_size, 1))) @@ -1223,23 +1410,30 @@ def softmax_loop( tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width - tStP_layout = cute.composition(tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP_layout = cute.composition( + tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32)) + ) tStP = cute.make_tensor(tStSi.iterator + self.tmem_s_to_p_offset, tStP_layout) tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32, + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + Float32, ) thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx) tStS_t2r = thr_tmem_load.partition_S(tStSi) tmem_store_scale_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), Float32, + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), + Float32, + ) + thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice( + tidx ) - thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice(tidx) tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32, + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), + Float32, ) thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) tStP_r2t = thr_tmem_store.partition_D(tStP) @@ -1266,9 +1460,13 @@ def softmax_loop( thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, - mask_local=self.is_local + mask_local=self.is_local, + ) + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, + softmax_scale=softmax_scale, ) - softmax = SoftmaxSm100.create(softmax_scale_log2, rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, softmax_scale=softmax_scale) softmax.reset() softmax_step = partial( @@ -1289,15 +1487,24 @@ def softmax_loop( head_idx=head_idx, m_block=self.q_stage * m_block + stage, seqlen=seqlen, - buffers=buffers, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase + ) si_corr_producer_phase ^= 1 # 1 masking iter - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True)) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block_max - 1, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True), + ) n_block_max -= 1 # Next couple of iterations with causal masking if const_expr(self.is_causal or self.is_local): @@ -1306,7 +1513,15 @@ def softmax_loop( ) for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) # The remaining iterations have no masking n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( @@ -1314,13 +1529,23 @@ def softmax_loop( ) for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - n_tile - 1 - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block + ) # Separate iterations with local masking on the left if const_expr(self.is_local and block_info.window_size_left is not None): n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) # Now that we no longer already have the 1st iteration, need mask_seqlen=True here # tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScScale).shape @@ -1330,7 +1555,9 @@ def softmax_loop( # cute.arch.fence_view_async_tmem_store() sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] if const_expr(mLSE is not None or learnable_sink is not None): - sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[0] + sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[ + 0 + ] # if tidx == 0: # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) @@ -1383,8 +1610,8 @@ def softmax_step( head_idx: Int32, m_block: Int32, seqlen, - buffers = None, - fastdiv_mods = (None, None), + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), mask_fn: Optional[Callable] = None, is_first: bool = False, ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]: @@ -1422,8 +1649,8 @@ def softmax_step( m_block, n_block, softmax, - buffers, - fastdiv_mods + aux_tensors, + fastdiv_mods, ) if const_expr(mask_fn is not None): @@ -1446,14 +1673,21 @@ def softmax_step( softmax.scale_subtract_rowmax(tSrS_t2r, row_max) # Sequence barrier wait if const_expr(self.s0_s1_barrier): - cute.arch.mbarrier_wait(mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase) + cute.arch.mbarrier_wait( + mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase + ) tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, Float32) tSrP_r2t = cute.make_tensor( - cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout, + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout, ) # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) - softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None and self.head_dim_padded <= 128, - e2e_freq=self.e2e_freq) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + e2e=mask_fn is None and self.head_dim_padded <= 128, + e2e_freq=self.e2e_freq, + ) # Sequence barrier arrive if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) @@ -1464,12 +1698,16 @@ def softmax_step( cute.arch.fence_view_async_tmem_store() # Notify mma warp that P is ready cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) - for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2])): + for i in cutlass.range_constexpr( + cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2]) + ): cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) cute.arch.fence_view_async_tmem_store() # Notify mma warp that the 2nd half of P is ready cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage) - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase + ) softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) # acc_scale = cute.arch.exp2(acc_scale_) return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1 @@ -1496,11 +1734,14 @@ def correction_loop( tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) tStScale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) - tStScales = tuple(cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout) - for stage in range(2)) + tStScales = tuple( + cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout) + for stage in range(2) + ) tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) tmem_load_v_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype, + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), + self.qk_acc_dtype, ) thr_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]).get_slice(tidx) @@ -1523,16 +1764,23 @@ def correction_loop( n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # Ignore first signal from softmax as no correction is required - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase + ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0) - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase + ) softmax_corr_consumer_phase ^= 1 tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) for i in cutlass.range(n_block_max - n_block_min - 1, unroll=1): for stage in cutlass.range_constexpr(2): # wait for S0 / S1 - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] @@ -1548,7 +1796,9 @@ def correction_loop( thr_mma_pv, tOtOs[stage if self.q_stage == 2 else 0], tidx, scale ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage)) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage) + ) softmax_corr_consumer_phase ^= 1 # o_corr_consumer_phase ^= 1 # End of seqlen_corr_loop_steps @@ -1566,10 +1816,15 @@ def correction_loop( learnable_sink_val = [sink_val] * self.q_stage else: # Each thread might have a different sink value due to different q_head for stage in cutlass.range_constexpr(self.q_stage): - q_head_idx = ((self.q_stage * m_block + stage) * self.m_block_size + tidx) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + q_head_idx = ( + (self.q_stage * m_block + stage) * self.m_block_size + tidx + ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) for stage in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] @@ -1581,14 +1836,24 @@ def correction_loop( cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) if const_expr(learnable_sink is not None): LOG2_E = math.log2(math.e) - row_sum += utils.exp2f(learnable_sink_val[stage] * LOG2_E - row_max * softmax_scale_log2) + row_sum += utils.exp2f( + learnable_sink_val[stage] * LOG2_E - row_max * softmax_scale_log2 + ) acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase + ) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase + ) self.correction_epilogue( - thr_mma_pv, tOtOs[stage], tidx, scale, sO[None, None, stage], + thr_mma_pv, + tOtOs[stage], + tidx, + scale, + sO[None, None, stage], ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) # Signal for the next work tile that O buffers in tmem are already read, so @@ -1599,19 +1864,28 @@ def correction_loop( if const_expr(not seqlen.has_cu_seqlens_q): mLSE_cur = mLSE[None, head_idx, batch_idx] else: - offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + offset = ( + seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + ) mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) for stage in cutlass.range_constexpr(self.q_stage): - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,)) + gLSE = cute.local_tile( + mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,) + ) row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] # if tidx == 0 and stage <= 1: # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) LN2 = math.log(2.0) lse = ( (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2 - if not acc_O_mn_row_is_zero_or_nan else -Float32.inf + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + seqlen_q = ( + seqlen.seqlen_q + if const_expr(not self.pack_gqa) + else seqlen.seqlen_q * self.qhead_per_kvhead ) - seqlen_q = seqlen.seqlen_q if const_expr(not self.pack_gqa) else seqlen.seqlen_q * self.qhead_per_kvhead if tidx < seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size: # This actually just works with PackGQA too gLSE[tidx] = lse @@ -1693,7 +1967,8 @@ def correction_rescale( cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True): tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( - (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), ) tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout) cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i) @@ -1748,7 +2023,9 @@ def correction_epilogue( epi_subtile, use_2cta_instrs=False, ) - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]).get_slice(tidx) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]).get_slice( + tidx + ) thr_tmem_load = tiled_tmem_load.get_slice(tidx) smem_copy_atom = sm100_utils_basic.get_smem_store_op( self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load @@ -1765,14 +2042,16 @@ def correction_epilogue( cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg) for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2): tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( - (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), ) tOrO_frg_cvt = cute.make_fragment(tOrO_frg.shape, self.o_dtype) tOrO_frg_cvt.store(tOrO_frg.load().to(self.o_dtype)) cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i) # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta, + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, ) @cute.jit @@ -1812,7 +2091,9 @@ def epilogue_s2g( cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) else: - tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) + ) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) tOsO = gmem_thr_copy_O.partition_S(sO) cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) @@ -1822,11 +2103,18 @@ def epilogue_s2g( tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it assert not self.pack_gqa - pack_gqa = PackGQA(self.m_block_size, self.head_dim_v_padded, self.check_hdim_v_oob, self.qhead_per_kvhead) + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) for stage in cutlass.range_constexpr(self.q_stage): # wait from corr, issue tma store on smem # 1. wait for O0 / O1 final - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase + ) # 2. copy O0 / O1 to gmem # load acc O from smem to rmem for wider vectorization tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype) @@ -1834,15 +2122,29 @@ def epilogue_s2g( # copy acc O from rmem to gmem if const_expr(not self.pack_gqa): for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size - tOcO[0][0]: + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + - tOcO[0][0] + ): cute.copy( gmem_tiled_copy_O, tOrO[None, rest_m, None], tOgO[None, rest_m, None, self.q_stage * m_block + stage], - pred=tOpO[None, rest_m, None] if self.check_hdim_v_oob else None, + pred=tOpO[None, rest_m, None] + if self.check_hdim_v_oob + else None, ) else: - pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, self.q_stage * m_block + stage, seqlen.seqlen_q) + pack_gqa.store_O( + mO_cur, + tOrO, + gmem_tiled_copy_O, + tidx, + self.q_stage * m_block + stage, + seqlen.seqlen_q, + ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) # Advance to next tile @@ -1886,7 +2188,9 @@ def load_KV( if stage == 0: cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase) with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V]) + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V] + ) tXsX_cur = tXsX[None, stage] if const_expr(self.uneven_kv_smem): # Since this is the producer_state, the phase starts at 1, so we have to invert it @@ -1907,9 +2211,12 @@ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): return sX def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): - load_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) + load_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_kv_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) - load_kv_consumer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])) return cutlass.pipeline.PipelineTmaUmma.create( barrier_storage=load_kv_mbar_ptr, num_stages=self.kv_stage, @@ -1950,7 +2257,7 @@ def apply_score_mod( m_block, n_block, softmax, - buffers=None, + aux_tensors=None, fastdiv_mods=(None, None), ): """Apply score modification for SM100 (constant q_idx).""" @@ -1971,7 +2278,7 @@ def apply_score_mod( head_offset = q_physical - q_idx_logical * self.qhead_per_kvhead head_idx = head_idx * self.qhead_per_kvhead + head_offset - if cutlass.const_expr(buffers is not None): + if cutlass.const_expr(aux_tensors is not None): seqlen_q_divmod, _ = fastdiv_mods _, q_idx_logical = seqlen_q_divmod.divmod(q_idx_logical) @@ -1984,7 +2291,7 @@ def apply_score_mod( softmax.softmax_scale, self.vec_size, self.qk_acc_dtype, - buffers, + aux_tensors, fastdiv_mods, constant_q_idx=q_idx_logical, qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 8c2e5903fc4..e3d2eb0891b 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. # Supported features: # - BF16 & FP16 dtype @@ -51,6 +52,7 @@ def maybe_contiguous(x): torch.float32: cutlass.Float32, } + def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, @@ -83,7 +85,7 @@ def _flash_attn_fwd( return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, - buffers: Optional[list[torch.Tensor]] = None, + aux_tensors: Optional[list[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. @@ -93,7 +95,7 @@ def _flash_attn_fwd( return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate out: Optional pre-allocated output tensor. If None, will be allocated internally. lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. - buffers: Some score_mods will want to read from global buffers. This is how we thread them through to the inner kernel. + aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. """ q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] @@ -127,34 +129,52 @@ def _flash_attn_fwd( else: assert k.shape == (seqlen_k, num_head_kv, head_dim) assert v.shape == (seqlen_k, num_head_kv, head_dim_v) - assert cu_seqlens_k.shape == (batch_size + 1,), "cu_seqlens_k must have shape (batch_size + 1,)" + assert cu_seqlens_k.shape == (batch_size + 1,), ( + "cu_seqlens_k must have shape (batch_size + 1,)" + ) if cu_seqlens_q is not None: - assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)" - assert seqused_q is None or seqused_q.shape == (batch_size,), "seqused_q must have shape (batch_size,)" - assert seqused_k is None or seqused_k.shape == (batch_size,), "seqused_k must have shape (batch_size,)" + assert cu_seqlens_q.shape == (batch_size + 1,), ( + "cu_seqlens_q must have shape (batch_size + 1,)" + ) + assert seqused_q is None or seqused_q.shape == (batch_size,), ( + "seqused_q must have shape (batch_size,)" + ) + assert seqused_k is None or seqused_k.shape == (batch_size,), ( + "seqused_k must have shape (batch_size,)" + ) assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: if t is not None: - assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" - assert t.stride(0) == 1, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" + assert t.dtype == torch.int32, ( + "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" + ) + assert t.stride(0) == 1, ( + "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" + ) if learnable_sink is not None: assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]: if t is not None: assert t.dtype == torch.int32, "blocksparse mask tensors must be int32" - assert t.stride(0) == 1, "blocksparse mask tensors must be contiguous" + # assert t.stride(0) == 1, "blocksparse mask tensors must be contiguous" assert all( t is None or t.is_cuda for t in ( - q, k, v, - cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, page_table, learnable_sink, - full_block_cnt, full_block_idx, - mask_block_cnt, mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, ) ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" @@ -177,20 +197,38 @@ def _flash_attn_fwd( requires_grad = q.requires_grad or k.requires_grad or v.requires_grad if out is None: - out = torch.empty(*q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device) + out = torch.empty( + *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device + ) else: expected_out_shape = (*q_batch_seqlen_shape, num_head, head_dim_v) - assert out.shape == expected_out_shape, f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}" - assert out.dtype == out_torch_dtype, f"out tensor dtype {out.dtype} does not match expected dtype {out_torch_dtype}" - assert out.device == device, f"out tensor device {out.device} does not match input device {device}" + assert out.shape == expected_out_shape, ( + f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}" + ) + assert out.dtype == out_torch_dtype, ( + f"out tensor dtype {out.dtype} does not match expected dtype {out_torch_dtype}" + ) + assert out.device == device, ( + f"out tensor device {out.device} does not match input device {device}" + ) assert out.is_cuda, "out tensor must be on CUDA device" if lse is None: - lse = torch.empty(lse_shape, dtype=torch.float32, device=device) if requires_grad or return_lse else None + lse = ( + torch.empty(lse_shape, dtype=torch.float32, device=device) + if requires_grad or return_lse + else None + ) elif lse is not None: - assert lse.shape == lse_shape, f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}" - assert lse.dtype == torch.float32, f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32" - assert lse.device == device, f"lse tensor device {lse.device} does not match input device {device}" + assert lse.shape == lse_shape, ( + f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}" + ) + assert lse.dtype == torch.float32, ( + f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32" + ) + assert lse.device == device, ( + f"lse tensor device {lse.device} does not match input device {device}" + ) assert lse.is_cuda, "lse tensor must be on CUDA device" dtype = torch2cute_dtype_map[q.dtype] @@ -198,82 +236,156 @@ def _flash_attn_fwd( from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (q, k, v, out) ] - lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) if lse is not None else None - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, learnable_sink_tensor = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None + lse_tensor = ( + from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) + if lse is not None + else None + ) + ( + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + learnable_sink_tensor, + ) = [ + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + if t is not None + else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] - page_table_tensor = from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) if page_table is not None else None - - full_block_cnt_tensor = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if full_block_cnt is not None else None - full_block_idx_tensor = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if full_block_idx is not None else None - mask_block_cnt_tensor = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if mask_block_cnt is not None else None - mask_block_idx_tensor = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if mask_block_idx is not None else None - - - if causal: - window_size_right = 0 - local = window_size_left is not None or window_size_right is not None - if window_size_left is not None or window_size_right is not None: - if window_size_left is None and window_size_right == 0: - causal, local = True, False - else: - causal, local = False, True - compute_capability = torch.cuda.get_device_capability()[0] if _compute_capability is None else _compute_capability + page_table_tensor = ( + from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) + if page_table is not None + else None + ) + + full_block_cnt_tensor = ( + from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + if full_block_cnt is not None + else None + ) + full_block_idx_tensor = ( + from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) + if full_block_idx is not None + else None + ) + mask_block_cnt_tensor = ( + from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + if mask_block_cnt is not None + else None + ) + mask_block_idx_tensor = ( + from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) + if mask_block_idx is not None + else None + ) + use_block_sparsity = full_block_cnt is not None or mask_block_cnt is not None + + if mask_mod is None: + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if window_size_left is not None or window_size_right is not None: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + else: + causal, local = False, True + else: + causal, local = False, False + + compute_capability = ( + torch.cuda.get_device_capability()[0] + if _compute_capability is None + else _compute_capability + ) assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - if compute_capability == 9: # TODO: tune block size according to hdim - if head_dim == head_dim_v == 128 and not causal and not local: + if compute_capability == 9: # TODO: tune block size according to hdim. + if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity: n_block_size = 192 if compute_capability == 10: # TODO: fix the varlen case - if pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None): + if ( + pack_gqa + and (128 % qhead_per_kvhead != 0) + or (cu_seqlens_q is not None or seqused_q is not None) + ): pack_gqa = False - + # hash score and mask mods for compile cache - score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else None - mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else None - + score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False + + print(mask_mod_hash) + if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) - is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None - use_block_sparsity = full_block_cnt is not None or mask_block_cnt is not None + is_varlen = ( + cu_seqlens_q is not None + or cu_seqlens_k is not None + or seqused_q is not None + or seqused_k is not None + ) if score_mod is not None: if is_varlen: - raise NotImplementedError("score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.") - if pack_gqa: - raise NotImplementedError("score_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.") + raise NotImplementedError( + "score_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." + ) if mask_mod is not None: if not use_block_sparsity: - raise NotImplementedError("mask_mod requires the use of block sparsity. This will be fixed in a future PR.") + raise NotImplementedError( + "mask_mod requires the use of block sparsity. This will be fixed in a future PR." + ) if is_varlen: - raise NotImplementedError("mask_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.") + raise NotImplementedError( + "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." + ) if pack_gqa: - raise NotImplementedError("mask_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.") - + raise NotImplementedError( + "mask_mod with aux_tensors is not yet supported with pack_gqa=True. This will be fixed in a future PR." + ) + if use_block_sparsity: if is_varlen: - raise NotImplementedError("Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR.") + raise NotImplementedError( + "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR." + ) if pack_gqa: - raise NotImplementedError("Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR.") - - cute_buffers = None - if buffers is not None: - cute_buffers = [from_dlpack(buf) for buf in buffers] + raise NotImplementedError( + "Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR." + ) + + cute_aux_tensors = None + if aux_tensors is not None: + cute_aux_tensors = [from_dlpack(buf) for buf in aux_tensors] compile_key = ( - dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, - score_mod_hash, mask_mod_hash, - buffers is not None, - lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + score_mod_hash, + mask_mod_hash, + use_block_sparsity, + aux_tensors is not None, + lse is None, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, page_table is not None, - window_size_left is not None, window_size_right is not None, + window_size_left is not None, + window_size_right is not None, learnable_sink is not None, - m_block_size, n_block_size, num_threads, pack_gqa, + m_block_size, + n_block_size, + num_threads, + pack_gqa, compute_capability, ) @@ -299,10 +411,12 @@ def _flash_attn_fwd( mma_pv_is_rs=True, mask_mod=mask_mod, score_mod=score_mod, - has_buffers=buffers is not None, + has_aux_tensors=aux_tensors is not None, ) elif compute_capability == 10: - assert page_size in [None, 128], "Only page_size=128 is supported for paged KV on SM 10.0" + assert page_size in [None, 128], ( + "Only page_size=128 is supported for paged KV on SM 10.0" + ) fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, @@ -310,34 +424,69 @@ def _flash_attn_fwd( is_causal=causal, is_local=local, pack_gqa=pack_gqa, - is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None, + is_persistent=not causal + and not local + and cu_seqlens_q is None + and seqused_q is None, score_mod=score_mod, - has_buffers=buffers is not None, + has_aux_tensors=aux_tensors is not None, ) else: - raise ValueError(f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x") + raise ValueError( + f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x" + ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( - fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + fa_fwd, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, page_table_tensor, - window_size_left, window_size_right, learnable_sink_tensor, - full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor, - buffers=cute_buffers, + window_size_left, + window_size_right, + learnable_sink_tensor, + full_block_cnt_tensor, + full_block_idx_tensor, + mask_block_cnt_tensor, + mask_block_idx_tensor, + cute_aux_tensors, ) _flash_attn_fwd.compile_cache[compile_key]( - q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, page_table_tensor, - window_size_left, window_size_right, learnable_sink_tensor, - full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor, - buffers=cute_buffers, + window_size_left, + window_size_right, + learnable_sink_tensor, + full_block_cnt_tensor, + full_block_idx_tensor, + mask_block_cnt_tensor, + mask_block_idx_tensor, + cute_aux_tensors, ) return out, lse _flash_attn_fwd.compile_cache = {} + def _flash_attn_bwd( q: torch.Tensor, k: torch.Tensor, @@ -407,10 +556,14 @@ def _flash_attn_bwd( else: assert k.shape == (total_k, num_head_kv, head_dim) assert v.shape == (total_k, num_head_kv, head_dim_v) - assert cu_seqlens_k.shape == (batch_size + 1,), "cu_seqlens_k must have shape (batch_size + 1,)" + assert cu_seqlens_k.shape == (batch_size + 1,), ( + "cu_seqlens_k must have shape (batch_size + 1,)" + ) if cu_seqlens_q is not None: - assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)" + assert cu_seqlens_q.shape == (batch_size + 1,), ( + "cu_seqlens_q must have shape (batch_size + 1,)" + ) assert out.shape == (total_q, num_head, head_dim_v) assert dout.shape == (total_q, num_head, head_dim_v) @@ -418,15 +571,21 @@ def _flash_attn_bwd( else: assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) - assert lse.shape == (batch_size, num_head, seqlen_q), "lse must have shape (batch_size, num_head, seqlen_q)" + assert lse.shape == (batch_size, num_head, seqlen_q), ( + "lse must have shape (batch_size, num_head, seqlen_q)" + ) assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" - assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, "inputs must have the same dtype" + assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, ( + "inputs must have the same dtype" + ) for t in [cu_seqlens_q, cu_seqlens_k]: if t is not None: assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32" assert lse.dtype == torch.float32, "lse must be float32" - assert all(t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)), "inputs must be on CUDA device" + assert all( + t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) + ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() @@ -448,12 +607,26 @@ def _flash_attn_bwd( if cu_seqlens_q is None: seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size - dq_accum = torch.empty(batch_size, num_head, seqlen_q_rounded * head_dim_rounded, dtype=torch.float32, device=device) - dpsum = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) - lse_log2 = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) + dq_accum = torch.empty( + batch_size, + num_head, + seqlen_q_rounded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dpsum = torch.empty( + batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device + ) + lse_log2 = torch.empty( + batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device + ) else: - total_q_rounded_padded = (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size - dq_accum = torch.empty(num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device) + total_q_rounded_padded = ( + (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size + ) + dq_accum = torch.empty( + num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device + ) dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) @@ -461,19 +634,45 @@ def _flash_attn_bwd( head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size - dk_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_rounded, dtype=torch.float32, device=device) - dv_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_v_rounded, dtype=torch.float32, device=device) + dk_accum = torch.zeros( + batch_size, + num_head_kv, + seqlen_k_rounded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dv_accum = torch.zeros( + batch_size, + num_head_kv, + seqlen_k_rounded * head_dim_v_rounded, + dtype=torch.float32, + device=device, + ) else: - total_k_rounded_padded = (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size - dk_accum = torch.zeros(num_head_kv, total_k_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device) - dv_accum = torch.zeros(num_head_kv, total_k_rounded_padded * head_dim_v_rounded, dtype=torch.float32, device=device) + total_k_rounded_padded = ( + (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size + ) + dk_accum = torch.zeros( + num_head_kv, + total_k_rounded_padded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dv_accum = torch.zeros( + num_head_kv, + total_k_rounded_padded * head_dim_v_rounded, + dtype=torch.float32, + device=device, + ) dtype = torch2cute_dtype_map[q.dtype] q_tensor, k_tensor, v_tensor, o_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (q, k, v, out, dout, dq, dk, dv) ] - lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=lse.ndim - 1 + ) dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (dq_accum, dpsum, lse_log2) @@ -484,7 +683,9 @@ def _flash_attn_bwd( for t in (dk_accum, dv_accum) ] cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=t.ndim-1) if t is not None else None + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=t.ndim - 1) + if t is not None + else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ] current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -493,23 +694,57 @@ def _flash_attn_bwd( compile_key_pre = (dtype, head_dim_v, m_block_size, num_threads) if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: fa_bwd_pre = FlashAttentionBackwardPreprocess( - dtype, head_dim_v, m_block_size, num_threads=num_threads, + dtype, + head_dim_v, + m_block_size, + num_threads=num_threads, ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile( - fa_bwd_pre, o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, - dq_accum_tensor, cu_seqlens_q_tensor, seqused_q_tensor, current_stream + fa_bwd_pre, + o_tensor, + do_tensor, + dpsum_tensor, + lse_tensor, + lse_log2_tensor, + dq_accum_tensor, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_pre[compile_key_pre]( - o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, dq_accum_tensor, - cu_seqlens_q_tensor, seqused_q_tensor, current_stream + o_tensor, + do_tensor, + dpsum_tensor, + lse_tensor, + lse_log2_tensor, + dq_accum_tensor, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) # Backward kernel: compute dk, dv, dq_accum. compile_key = ( - dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap != 0.0, m_block_size, - n_block_size, num_threads, pack_gqa, num_stages_Q, num_stages_dO, SdP_swapAB, dKV_swapAB, dQ_swapAB, - AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + softcap != 0.0, + m_block_size, + n_block_size, + num_threads, + pack_gqa, + num_stages_Q, + num_stages_dO, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + V_in_regs, ) num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: @@ -557,7 +792,12 @@ def _flash_attn_bwd( _flash_attn_bwd.compile_cache[compile_key] = cute.compile( # fa_bwd_sm80, fa_bwd_sm90, - q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, + q_tensor, + k_tensor, + v_tensor, + do_tensor, + lse_log2_tensor, + dpsum_tensor, dq_accum_tensor, dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, @@ -569,7 +809,12 @@ def _flash_attn_bwd( seqused_k_tensor, ) _flash_attn_bwd.compile_cache[compile_key]( - q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, + q_tensor, + k_tensor, + v_tensor, + do_tensor, + lse_log2_tensor, + dpsum_tensor, dq_accum_tensor, dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, @@ -591,11 +836,21 @@ def _flash_attn_bwd( ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dq_accum_tensor, dq_tensor, softmax_scale, cu_seqlens_q_tensor, - seqused_q_tensor, current_stream + fa_bwd_post, + dq_accum_tensor, + dq_tensor, + softmax_scale, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dq_accum_tensor, dq_tensor, softmax_scale, cu_seqlens_q_tensor, seqused_q_tensor, current_stream + dq_accum_tensor, + dq_tensor, + softmax_scale, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) if qhead_per_kvhead > 1: @@ -607,22 +862,51 @@ def _flash_attn_bwd( ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dk_accum_tensor, dk_tensor, softmax_scale, cu_seqlens_k_tensor, seqused_k_tensor, current_stream + fa_bwd_post, + dk_accum_tensor, + dk_tensor, + softmax_scale, + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dk_accum_tensor, dk_tensor, softmax_scale, cu_seqlens_k_tensor, seqused_k_tensor, current_stream + dk_accum_tensor, + dk_tensor, + softmax_scale, + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, + ) + compile_key_post = ( + dtype, + head_dim_v, + n_block_size, + num_threads, + AtomLayoutNdKV, + dKV_swapAB, ) - compile_key_post = (dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: fa_bwd_post = FlashAttentionBackwardPostprocess( dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), cu_seqlens_k_tensor, seqused_k_tensor, current_stream + fa_bwd_post, + dv_accum_tensor, + dv_tensor, + cutlass.Float32(1.0), + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), cu_seqlens_k_tensor, seqused_k_tensor, current_stream + dv_accum_tensor, + dv_tensor, + cutlass.Float32(1.0), + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, ) return dq, dk, dv @@ -634,7 +918,6 @@ def _flash_attn_bwd( class FlashAttnFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -695,7 +978,6 @@ def backward(ctx, dout, *args): class FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -864,7 +1146,9 @@ def _flash_attn_fwd_combine( # Input validation assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" - assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], "out_partial must be fp16, bf16, or fp32" + assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], ( + "out_partial must be fp16, bf16, or fp32" + ) assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension" @@ -881,7 +1165,11 @@ def _flash_attn_fwd_combine( assert lse.dtype == torch.float32, "lse must be fp32" # Validate optional tensors - for t, name in [(cu_seqlens, "cu_seqlens"), (seqused, "seqused"), (num_splits_dynamic_ptr, "num_splits_dynamic_ptr")]: + for t, name in [ + (cu_seqlens, "cu_seqlens"), + (seqused, "seqused"), + (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"), + ]: if t is not None: assert t.dtype == torch.int32, f"{name} must be int32" assert t.is_cuda, f"{name} must be on CUDA device" @@ -903,16 +1191,28 @@ def _flash_attn_fwd_combine( log_max_splits = max(log_max_splits, 5) # Convert to cute tensors (using kernel-formatted tensors) - out_partial_tensor = from_dlpack(out_partial.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=4) - lse_partial_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse_partial.ndim - 2) + out_partial_tensor = from_dlpack(out_partial.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=4 + ) + lse_partial_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=lse_partial.ndim - 2 + ) out_tensor = from_dlpack(out.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=3) - lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 2) if lse is not None else None + lse_tensor = ( + from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 2) + if lse is not None + else None + ) optional_tensors = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + if t is not None + else None for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) ] - cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = optional_tensors + cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = ( + optional_tensors + ) current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -921,9 +1221,15 @@ def _flash_attn_fwd_combine( dtype_partial = torch2cute_dtype_map[out_partial.dtype] compile_key = ( - dtype, dtype_partial, head_dim, m_block_size, k_block_size, + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, log_max_splits, - cu_seqlens is not None, seqused is not None, lse is not None, + cu_seqlens is not None, + seqused is not None, + lse is not None, ) if compile_key not in _flash_attn_fwd_combine.compile_cache: @@ -938,9 +1244,17 @@ def _flash_attn_fwd_combine( # Check if implementation is supported if not fa_combine.can_implement( - dtype, dtype_partial, head_dim, m_block_size, k_block_size, log_max_splits, num_threads=256 + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, + log_max_splits, + num_threads=256, ): - raise RuntimeError(f"FlashAttention combine kernel cannot be implemented with given parameters") + raise RuntimeError( + f"FlashAttention combine kernel cannot be implemented with given parameters" + ) _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile( fa_combine, @@ -952,7 +1266,7 @@ def _flash_attn_fwd_combine( seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor, - current_stream + current_stream, ) _flash_attn_fwd_combine.compile_cache[compile_key]( @@ -964,7 +1278,7 @@ def _flash_attn_fwd_combine( seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor, - current_stream + current_stream, ) @@ -1019,13 +1333,17 @@ def flash_attn_combine( if is_varlen: # Variable length: (num_splits, total_q, num_heads, head_size) num_splits, total_q, num_heads, head_size = out_partial.shape - assert lse_partial.shape == (num_splits, total_q, num_heads), "lse_partial shape mismatch for varlen" + assert lse_partial.shape == (num_splits, total_q, num_heads), ( + "lse_partial shape mismatch for varlen" + ) batch_size = 1 # Treat as single batch for varlen seqlen = total_q else: # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape - assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), "lse_partial shape mismatch" + assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), ( + "lse_partial shape mismatch" + ) # Determine output dtype if out_dtype is None: @@ -1037,14 +1355,20 @@ def flash_attn_combine( if is_varlen: out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device) else: - out = torch.empty(batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device) + out = torch.empty( + batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device + ) # Create lse output only if requested if return_lse: if is_varlen: - lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose(0, 1) + lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose( + 0, 1 + ) else: - lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device).transpose(1, 2) + lse = torch.empty( + batch_size, num_heads, seqlen, dtype=torch.float32, device=device + ).transpose(1, 2) else: lse = None diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 0d78eb9e948..7b830f42c4e 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -9,6 +9,7 @@ import flash_attn.cute.utils as utils + @cute.jit def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None: # Bit manipulation, compiles down to the R2P instruction @@ -38,6 +39,7 @@ def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = Fal for r in cutlass.range_constexpr(cute.size(X.shape[0])): X[r, c] = X[r, c] if in_bound else -Float32.inf + @dataclass(frozen=True) class AttentionMask: tile_m: cutlass.Constexpr[int] @@ -62,7 +64,7 @@ def apply_mask( mask_causal: cutlass.Constexpr[bool], mask_local: cutlass.Constexpr[bool] = False, mask_mod: cutlass.Constexpr[Optional[Callable]] = None, - buffers: Optional[list[cute.Tensor]] = None, + aux_tensors: Optional[list] = None, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB) @@ -90,20 +92,22 @@ def apply_mask( acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c] else: mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90) - - elif const_expr(not mask_causal and not mask_local and mask_mod is not None): # FlexAttention mask mod + + elif const_expr( + not mask_causal and not mask_local and mask_mod is not None + ): # FlexAttention mask mod nrow = const_expr(cute.size(tScS_mn.shape[0])) ncol = const_expr(cute.size(tScS_mn.shape[1])) thr_col_offset = tScS_mn[0, 0][1] - + for r in cutlass.range_constexpr(nrow): global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m - + for col in cutlass.range_constexpr(ncol): col_idx_local = t0ScS_mn[0, col][1] # Convert to absolute column index global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n - + cond = cutlass.Boolean( mask_mod( batch_idx, @@ -112,7 +116,7 @@ def apply_mask( thr_col_offset + t0ScS_mn[0, col][1] + n_block * self.tile_n, self.seqlen_q, self.seqlen_k, - buffers, + aux_tensors, ) ) if const_expr(mask_seqlen): @@ -126,7 +130,6 @@ def apply_mask( else: acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf - else: # Causal or local if const_expr(not self.swap_AB): # If PackGQA, we split the work of compute divmod among threads in the same row @@ -321,12 +324,11 @@ def apply_mask_sm100( else acc_S[i] ) - @cute.jit def apply_mask_sm100_transposed( self, acc_S: cute.Tensor, - tScS_t2r : cute.Tensor, + tScS_t2r: cute.Tensor, m_block: cutlass.Int32, n_block: cutlass.Int32, wg_idx: cutlass.Int32, @@ -335,9 +337,9 @@ def apply_mask_sm100_transposed( mask_causal: cutlass.Constexpr, mask_local: cutlass.Constexpr, ) -> None: - ''' + """ Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. - ''' + """ assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" tidx = cute.arch.thread_idx()[0] % 128 @@ -352,7 +354,7 @@ def apply_mask_sm100_transposed( else: # Causal or local causal_row_offset = (self.seqlen_q - self.seqlen_k - 1) - m_block * self.tile_m row_idx = tScS_t2r[0][0] + n_block * self.tile_n - + if const_expr(mask_causal): col_limit_left = row_idx + causal_row_offset ncol = const_expr(cute.size(tScS_t2r.shape)) @@ -365,4 +367,4 @@ def apply_mask_sm100_transposed( acc_S[i] = ( -cutlass.Float32.inf if tScS_t2r[i][1] <= col_limit_left else acc_S[i] ) - # TODO: local \ No newline at end of file + # TODO: local diff --git a/flash_attn/cute/mask_definitions.py b/flash_attn/cute/mask_definitions.py index 6b206fd6026..23c4f026b1c 100644 --- a/flash_attn/cute/mask_definitions.py +++ b/flash_attn/cute/mask_definitions.py @@ -1,7 +1,7 @@ from typing import Callable, Optional import random -import math +import math import cutlass import cutlass.cute as cute @@ -10,7 +10,14 @@ MaskModCallable = Optional[ Callable[ - ["cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32"], + [ + "cutlass.Int32", + "cutlass.Int32", + "cutlass.Int32", + "cutlass.Int32", + "cutlass.Int32", + "cutlass.Int32", + ], "cutlass.Boolean", ] ] @@ -49,12 +56,14 @@ def flex_block_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): def create_flex_sliding_window_mask(window_size=1024): """Factory function to create a sliding window mask with configurable window size""" + def flex_sliding_window_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): # Sliding window: q_idx - window_size <= kv_idx <= q_idx if seqlen_q is not None and seqlen_k is not None: offset = seqlen_k - seqlen_q return (kv_idx <= q_idx + offset) & (kv_idx >= q_idx + offset - window_size) return (kv_idx <= q_idx) & (kv_idx >= q_idx - window_size) + return flex_sliding_window_mask @@ -83,32 +92,49 @@ def flex_half_identity_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): return torch.ones_like(kv_idx, dtype=torch.bool) return True + def flex_document_mask(b, h, q_idx, kv_idx, doc_id: torch.Tensor): return doc_id[b, h, q_idx] == doc_id[b, h, kv_idx] + # CuTe versions for kernel compilation @cute.jit def cute_identity_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors: None, ) -> cutlass.Boolean: return cutlass.Boolean(True) @cute.jit def cute_identity_partial_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors: None, ) -> cutlass.Boolean: return cutlass.Boolean(True) @cute.jit def cute_causal_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors: None, ) -> cutlass.Boolean: # Right-aligned causal masking offset = seqlen_k - seqlen_q @@ -117,8 +143,13 @@ def cute_causal_mask( @cute.jit def cute_block_causal_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors: None, ) -> cutlass.Boolean: # Right-aligned causal masking offset = seqlen_k - seqlen_q @@ -127,22 +158,36 @@ def cute_block_causal_mask( def create_cute_sliding_window_mask(window_size=1024): """Factory function to create a CuTe sliding window mask with configurable window size""" + @cute.jit def cute_sliding_window_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors, ) -> cutlass.Boolean: offset = seqlen_k - seqlen_q - return cutlass.Boolean((n_idx <= m_idx + offset) and (n_idx >= m_idx + offset - window_size)) + return cutlass.Boolean( + (n_idx <= m_idx + offset) and (n_idx >= m_idx + offset - window_size) + ) + return cute_sliding_window_mask # Default sliding window mask with window_size=1024 for backward compatibility @cute.jit def cute_sliding_window_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors, ) -> cutlass.Boolean: window_size = 1024 # offset = seqlen_k - seqlen_q @@ -152,24 +197,40 @@ def cute_sliding_window_mask( @cute.jit def cute_document_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: list, + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors: list, ): - doc_id = buffers[0] + doc_id = aux_tensors[0] return cutlass.Boolean(doc_id[batch, head, m_idx] == doc_id[batch, head, n_idx]) - + @cute.jit def cute_block_diagonal_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors, ) -> cutlass.Boolean: return cutlass.Boolean((m_idx // 64) == (n_idx // 64)) @cute.jit def cute_mini_causal_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors, ) -> cutlass.Boolean: """Each tile is locally causal-masked""" m_mod = m_idx % 128 @@ -179,8 +240,12 @@ def cute_mini_causal_mask( @cute.jit def cute_half_identity_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32 + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, ) -> cutlass.Boolean: return cutlass.Boolean(True) @@ -191,17 +256,17 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): for h in range(nheads): N = seqlen_q n = random.randint(1, math.ceil(math.sqrt(N // 4))) - cuts = sorted(random.sample(range(1, N), n-1)) + cuts = sorted(random.sample(range(1, N), n - 1)) lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))] doc_ids = [] for i, length in enumerate(lengths): doc_ids += [i for _ in range(length)] - + doc_ids_tensor[b, h, :] = torch.tensor(doc_ids, dtype=torch.int32, device=device) print(f"{doc_ids_tensor.shape = }") return doc_ids_tensor - + MASK_FUNCTIONS = { "identity": (cute_identity_mask, flex_identity_mask), @@ -217,4 +282,4 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): if __name__ == "__main__": doc_ids = random_doc_id_tensor(1, 2, 128) - print(f"{doc_ids = }") \ No newline at end of file + print(f"{doc_ids = }") diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 72de115732a..0ca08f3f2e3 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -337,7 +337,7 @@ def apply_score_mod_inner( softmax_scale, vec_size: cutlass.Constexpr, qk_acc_dtype: cutlass.Constexpr, - buffers, + aux_tensors, fastdiv_mods, constant_q_idx: cutlass.Constexpr, qhead_per_kvhead: cutlass.Constexpr[int] = 1, @@ -353,7 +353,7 @@ def apply_score_mod_inner( softmax_scale: Scale to apply vec_size: Vector size for processing elements qk_acc_dtype: Data type for accumulator - buffers: Optional buffers for FlexAttention + aux_tensors: Optional aux_tensors for FlexAttention fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping constant_q_idx: If provided, use this constant for all q_idx values If None, compute q_idx per-element @@ -388,7 +388,7 @@ def apply_score_mod_inner( head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset # If we will do loads we mod, in order to not read OOB - if cutlass.const_expr(buffers is not None and fastdiv_mods is not None): + if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None): if cutlass.const_expr(constant_q_idx is None): seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods q_idx_floored = floor_if_packed(index_tensor[i + j][0], qhead_per_kvhead) @@ -421,9 +421,9 @@ def apply_score_mod_inner( else: head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,)) - buffer_args = [] - if cutlass.const_expr(buffers is not None): - buffer_args = buffers + aux_args = [] + if cutlass.const_expr(aux_tensors is not None): + aux_args = aux_tensors post_mod_scores = score_mod( score_ssa, @@ -431,7 +431,7 @@ def apply_score_mod_inner( head_idx_ssa, q_idx=q_idx_ssa, kv_idx=kv_idx_ssa, - buffers=buffer_args, + aux_tensors=aux_args, ) # Write back modified scores diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 644936d8d2d..6c3a679a613 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -7,6 +7,7 @@ import torch from einops import rearrange, repeat + try: from flash_attn.layers.rotary import apply_rotary_emb except ImportError: @@ -19,7 +20,11 @@ pad_input, unpad_input, ) -from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn.cute.interface import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_combine, +) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @@ -77,7 +82,17 @@ ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, deterministic, has_qv, has_learnable_sink, mha_type, dtype + seqlen_q, + seqlen_k, + d, + causal, + local, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, ): if (causal or local) and seqlen_k < seqlen_q: pytest.skip("Causal attention requires seqlen_k >= seqlen_q") @@ -99,26 +114,54 @@ def test_flash_attn_output( # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4) + q_ref = q_ref * softcap / 4 q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) if has_qv: - qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) # window_size = (-1, -1) if not local else (16, 0) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: learnable_sink = None if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] @@ -131,11 +174,13 @@ def test_flash_attn_output( None, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, - softcap=softcap + softcap=softcap, ) out_pt, attn_pt = attention_ref( q_ref, @@ -145,7 +190,9 @@ def test_flash_attn_output( None, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, @@ -197,7 +244,9 @@ def test_flash_attn_output( # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol if ( dtype != torch.float8_e4m3fn @@ -225,7 +274,9 @@ def test_flash_attn_output( # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") @@ -240,12 +291,24 @@ def test_flash_attn_output( print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @@ -300,9 +363,22 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, has_learnable_sink, mha_type, dtype + seqlen_q, + seqlen_k, + d, + add_unused_qkv, + causal, + local, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, ): - if (causal or local): # Right now we only support causal attention with seqlen_k == seqlen_q + if ( + causal or local + ): # Right now we only support causal attention with seqlen_k == seqlen_q seqlen_k = seqlen_q device = "cuda" # set seed @@ -320,25 +396,53 @@ def test_flash_attn_varlen_output( # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. q_ref = (q_ref * softcap / 4).detach().requires_grad_() q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) if has_qv: - qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: learnable_sink = None if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] @@ -349,7 +453,11 @@ def test_flash_attn_varlen_output( # TODO: test zero_lengths key_padding_mask = generate_random_padding_mask( # seqlen_k, batch_size, device, mode="random", zero_lengths=True - seqlen_k, batch_size, device, mode="random", zero_lengths=False + seqlen_k, + batch_size, + device, + mode="random", + zero_lengths=False, ) def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @@ -394,9 +502,20 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): output_pad_fn, dq_pad_fn, dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, - query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) - q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + qv=qv, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] out_ref, attn_ref = attention_ref( q_ref, k_ref, @@ -405,11 +524,13 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): key_padding_mask, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, - softcap=softcap + softcap=softcap, ) out_pt, attn_pt = attention_ref( q_ref, @@ -419,7 +540,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): key_padding_mask, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, @@ -473,8 +596,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # Check that FlashAttention's numerical error is at most 3x the numerical error # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol if ( dtype != torch.float8_e4m3fn @@ -510,7 +634,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # deterministic, # 0, # sm_margin # ) - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad + ) dq = dq_pad_fn(dq_unpad) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) @@ -534,9 +660,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) - # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") @@ -551,12 +678,24 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @@ -664,45 +803,107 @@ def test_flash_attn_kvcache( for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): # has_qv = d == 64 and dv >= 256 has_qv = False - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + q = ( + torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) if has_qv: - qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + qv = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) else: qv = None if varlen_q: - query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) - qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random" + ) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input( + q, query_padding_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = ( + rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + ) else: query_padding_mask = None q_unpad = q qv_unpad = qv cu_seqlens_q, max_seqlen_q = None, None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: learnable_sink = None - seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + seqlen_new = ( + seqlen_q + if seqlen_new_eq_seqlen_q + else torch.randint(1, seqlen_q + 1, (1,)).item() + ) cu_seqlens_k_new = None key_new_padding_mask = None if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + k = ( + torch.randn( + batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + v = ( + torch.randn( + batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) if varlen_q: # k & v are also varlen - key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") - k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + key_new_padding_mask = generate_random_padding_mask( + seqlen_new, batch_size, device, mode="random" + ) + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input( + k, key_new_padding_mask + ) v_unpad, *rest = unpad_input(v, key_new_padding_mask) else: k_unpad, v_unpad = k, v else: k, v, k_unpad, v_unpad = None, None, None, None if page_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + k_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + d, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + v_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + dv, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) page_table = None else: ( @@ -713,13 +914,25 @@ def test_flash_attn_kvcache( v_cache_paged, num_blocks, ) = _generate_block_kvcache( - seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref + seqlen_k, + page_size, + batch_size_cache, + nheads_k, + d, + dv, + device, + dtype, + dtype_ref, ) cache_seqlens = torch.randint( 0 if new_kv else 1, # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough ( - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + ( + seqlen_k + - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + + 1 + ) if new_kv else (seqlen_k + 1) ), @@ -728,15 +941,26 @@ def test_flash_attn_kvcache( device=device, ) if has_leftpad: - cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) - if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) - for i in range(batch_size)]) + cache_leftpad = torch.cat( + [ + torch.randint( + 0, + cache_seqlens[i].item(), + (1,), + dtype=torch.int32, + device=device, + ) + if cache_seqlens[i].item() > 0 + else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size) + ] + ) else: cache_leftpad = None if has_batch_idx: - cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ - :batch_size - ] + cache_batch_idx = torch.randperm( + batch_size_cache, dtype=torch.int32, device=device + )[:batch_size] else: cache_batch_idx = None arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") @@ -744,11 +968,14 @@ def test_flash_attn_kvcache( if not new_kv: key_padding_mask = arange < cache_seqlens_expanded else: - k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + k_new_seqlens = ( + key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + ) key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens if has_leftpad: key_padding_mask = torch.logical_and( - key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + key_padding_mask, + arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k), ) # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 @@ -766,7 +993,11 @@ def test_flash_attn_kvcache( sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) if causal or local: q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + q, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, ) else: q_ro = rearrange( @@ -782,17 +1013,26 @@ def test_flash_attn_kvcache( ) # q_ro = q k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + k, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, ) else: cos, sin = None, None q_ro, k_ro = q, k # k_cache[:, 64:] = -1 - k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() - v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + k_cache_ref = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx] + ).clone() + v_cache_ref = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx] + ).clone() if new_kv: update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + cache_seqlens_expanded <= arange, + arange < cache_seqlens_expanded + k_new_seqlens, ) k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") v_to_update = rearrange(v, "b s ... -> (b s) ...") @@ -801,8 +1041,12 @@ def test_flash_attn_kvcache( v_to_update = v_to_update[indices_k] k_cache_ref[update_mask] = k_to_update v_cache_ref[update_mask] = v_to_update - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + k_cache_rep = repeat( + k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + v_cache_rep = repeat( + v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) out_ref, _ = attention_ref( q_ro, k_cache_rep, @@ -830,7 +1074,7 @@ def test_flash_attn_kvcache( upcast=False, reorder_ops=True, key_leftpad=cache_leftpad, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) q = q.to(dtype) q_unpad = q_unpad.to(dtype) if varlen_q else None @@ -852,7 +1096,9 @@ def test_flash_attn_kvcache( num_splits_vals = [1] # precompute_metadata_vals = [False, True] precompute_metadata_vals = [False] - for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + for num_splits, precompute_metadata in itertools.product( + num_splits_vals, precompute_metadata_vals + ): # if precompute_metadata: # scheduler_metadata = get_scheduler_metadata( # batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, @@ -922,19 +1168,35 @@ def test_flash_attn_kvcache( if new_kv: if page_size is None: k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + k_cache.to(dtype_ref) + if not has_batch_idx + else k_cache.to(dtype_ref)[cache_batch_idx] ) v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + v_cache.to(dtype_ref) + if not has_batch_idx + else v_cache.to(dtype_ref)[cache_batch_idx] ) else: k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + k_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k].to(dtype_ref) v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + v_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k].to(dtype_ref) @@ -943,7 +1205,9 @@ def test_flash_attn_kvcache( if dtype is not torch.float8_e4m3fn: assert torch.equal(v_cache_select, v_cache_ref) else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.allclose( + v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 + ) # breakpoint() # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: if rotary_dim == 0: @@ -952,23 +1216,37 @@ def test_flash_attn_kvcache( # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): # breakpoint() if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 + ) else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 + ) mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + assert (out - out_ref).abs().max().item() <= mult * ( + out_pt - out_ref + ).abs().max().item() + 1e-5 mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + assert (out - out_ref).abs().mean().item() <= mult_mean * ( + out_pt - out_ref + ).abs().mean().item() -def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): +def _generate_block_kvcache( + seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref +): num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 - k_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref - ).to(dtype).to(dtype_ref) - v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref - ).to(dtype).to(dtype_ref) + k_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + v_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", @@ -994,7 +1272,9 @@ def attention_combine_ref(out_partial, lse_partial): """ lse = torch.logsumexp(lse_partial, dim=0) scale = torch.exp(lse_partial - lse) - scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) + scale = torch.where( + torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale + ) out = (scale.unsqueeze(-1) * out_partial).sum(0) return out, lse @@ -1019,13 +1299,25 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype): # batch_size = 1 # nheads = 1 # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) - out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor - lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + out_partial = torch.randn( + num_splits * 2, + batch_size, + nheads, + seqlen, + d, + device=device, + dtype=torch.float32, + ).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn( + num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32 + ).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor # To test short-circuiting based on num_splits - lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") + lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") # Test with LSE returned (default behavior) - out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=True) + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=True + ) out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) out_pt = out_ref.to(dtype) @@ -1039,9 +1331,16 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype): assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) multiple = 2 - assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + assert ( + (out - out_ref).abs().max().item() + <= multiple * (out_pt - out_ref).abs().max().item() + ) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) # Test with LSE not returned - out_no_lse, lse_no_lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=False) + out_no_lse, lse_no_lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=False + ) assert lse_no_lse is None, "LSE should be None when return_lse=False" - assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), "Output should be the same regardless of return_lse" \ No newline at end of file + assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), ( + "Output should be the same regardless of return_lse" + ) diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 3e6707b5fb9..ce3a28b82c6 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -1,23 +1,22 @@ # mask mod test script +# REFACTORED to use _flash_attn_fwd as the kernel entrypoint import math +from typing import Optional, Callable -import cuda.bindings.driver as cuda -import cutlass -import cutlass.cute as cute -from cutlass.cute.runtime import from_dlpack import pytest import torch from torch.nn.attention.flex_attention import create_block_mask, flex_attention import torch.nn.functional as F +from flash_attn.cute.interface import _flash_attn_fwd from flash_attn.cute.block_sparsity import compute_block_sparsity -from flash_attn.cute.flash_fwd import ( - FlashAttentionForwardSm80, - FlashAttentionForwardSm90, +from flash_attn.cute.mask_definitions import ( + MASK_FUNCTIONS, + flex_causal_mask, + create_flex_sliding_window_mask, + create_cute_sliding_window_mask, ) -from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 -from flash_attn.cute.mask_definitions import MASK_FUNCTIONS, flex_causal_mask, create_flex_sliding_window_mask, create_cute_sliding_window_mask from flash_attn.cute.testing import attention_ref @@ -46,169 +45,12 @@ def create_tensors( } -def compile_and_run_kernel( - tensors, - mask_mod_cute, - causal, - is_local, - window_left, - window_right, - tile_m, - tile_n, - full_block_cnt=None, - full_block_idx=None, - mask_block_cnt=None, - mask_block_idx=None, -): - dtype_map = { - torch.float16: cutlass.Float16, - torch.bfloat16: cutlass.BFloat16, - torch.float32: cutlass.Float32, - } - cute_dtype = dtype_map[tensors["q"].dtype] - - batch_size, seqlen_q, nheads, headdim = tensors["q"].shape - _, seqlen_k, nheads_kv, _ = tensors["k"].shape - headdim_v = tensors["v"].shape[-1] - - compute_capability = torch.cuda.get_device_capability() - if compute_capability >= (10, 0): - kernel_class = FlashAttentionForwardSm100 - elif compute_capability >= (9, 0): - kernel_class = FlashAttentionForwardSm90 - else: - kernel_class = FlashAttentionForwardSm80 - - qhead_per_kvhead = nheads // nheads_kv - kernel = kernel_class( - cute_dtype, - headdim, - headdim_v, - qhead_per_kvhead, - is_causal=causal, - is_local=is_local, - pack_gqa=False, - tile_m=tile_m, - tile_n=tile_n, - num_stages=2, - num_threads=384, - intra_wg_overlap=True, - mma_pv_is_rs=True, - mask_mod=mask_mod_cute, - has_buffers=False, - Q_in_regs=False, - ) - - softmax_scale = 1.0 / math.sqrt(headdim) - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - q_cute = from_dlpack(tensors["q"].detach(), assumed_align=16).mark_layout_dynamic( - leading_dim=tensors["q"].ndim - 1 - ) - k_cute = from_dlpack(tensors["k"].detach(), assumed_align=16).mark_layout_dynamic( - leading_dim=tensors["k"].ndim - 1 - ) - v_cute = from_dlpack(tensors["v"].detach(), assumed_align=16).mark_layout_dynamic( - leading_dim=tensors["v"].ndim - 1 - ) - out_cute = from_dlpack( - tensors["out"].detach(), assumed_align=16 - ).mark_layout_dynamic(leading_dim=tensors["out"].ndim - 1) - lse_cute = from_dlpack( - tensors["lse"].detach(), assumed_align=4 - ).mark_layout_dynamic(leading_dim=tensors["lse"].ndim - 1) - - full_block_cnt_cute = ( - from_dlpack(full_block_cnt.detach(), assumed_align=4) - if full_block_cnt is not None - else None - ) - full_block_idx_cute = ( - from_dlpack(full_block_idx.detach(), assumed_align=4) - if full_block_idx is not None - else None - ) - mask_block_cnt_cute = ( - from_dlpack(mask_block_cnt.detach(), assumed_align=4) - if mask_block_cnt is not None - else None - ) - mask_block_idx_cute = ( - from_dlpack(mask_block_idx.detach(), assumed_align=4) - if mask_block_idx is not None - else None - ) - - # Window parameters for is_local - window_left_cute = ( - cutlass.Int32(window_left) if window_left is not None else None - ) - window_right_cute = ( - cutlass.Int32(window_right) if window_right is not None else None - ) - - compiled = cute.compile( - kernel, - q_cute, - k_cute, - v_cute, - out_cute, - lse_cute, - softmax_scale, - current_stream, - None, # cu_seqlens_q - None, # cu_seqlens_k - None, # seqused_q - None, # seqused_k - None, # page_table - window_left_cute, - window_right_cute, - None, # learnable_sink - full_block_cnt_cute, - full_block_idx_cute, - mask_block_cnt_cute, - mask_block_idx_cute, - None, # buffers - ) - - compiled( - q_cute, - k_cute, - v_cute, - out_cute, - lse_cute, - softmax_scale, - current_stream, - None, # cu_seqlens_q - None, # cu_seqlens_k - None, # seqused_q - None, # seqused_k - None, # page_table - window_left_cute, - window_right_cute, - None, # learnable_sink - full_block_cnt_cute, - full_block_idx_cute, - mask_block_cnt_cute, - mask_block_idx_cute, - None, # buffers - ) - - torch.cuda.synchronize() - return tensors["out"] - - -def compute_reference_flash_attn( - tensors, causal, window_size, dtype_ref, upcast=True -): +def compute_reference_flash_attn(tensors, causal, window_size, dtype_ref, upcast=True): """Compute reference using FlashAttention's attention_ref function""" - batch_size, seqlen_q, nheads, headdim = tensors["q"].shape - _, seqlen_k, nheads_kv, _ = tensors["k"].shape - q = tensors["q"].to(dtype_ref) k = tensors["k"].to(dtype_ref) v = tensors["v"].to(dtype_ref) - + out_ref, attn_ref = attention_ref( q, k, @@ -220,13 +62,11 @@ def compute_reference_flash_attn( upcast=upcast, reorder_ops=False, ) - + return out_ref -def compute_reference_flex_attn( - tensors, mask_mod_flex, mask_mod_name, tile_m, tile_n -): +def compute_reference_flex_attn(tensors, mask_mod_flex, mask_mod_name, tile_m, tile_n): """Compute reference using flex_attention for custom mask_mods""" batch_size, seqlen_q, nheads, headdim = tensors["q"].shape _, seqlen_k, nheads_kv, _ = tensors["k"].shape @@ -266,9 +106,7 @@ def mask_fn(b, h, q_idx, kv_idx): k_end = min((k_block + 1) * tile_n, seqlen_k) mask[q_start:q_end, k_start:k_end] = True - attn_mask = ( - mask.unsqueeze(0).unsqueeze(0).expand(batch_size, nheads, -1, -1) - ) + attn_mask = mask.unsqueeze(0).unsqueeze(0).expand(batch_size, nheads, -1, -1) out_ref = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, scale=scale ) @@ -319,11 +157,11 @@ def mask_fn(b, h, q_idx, kv_idx): @pytest.mark.parametrize( "use_mask_mod,is_local,mask_name,window_size,window_left,window_right", [ - (False, False, "identity", None, None, None), - (False, False, "causal", None, None, None), + # (False, False, "identity", None, None, None), + # (False, False, "causal", None, None, None), (True, False, "identity", None, None, None), (True, False, "causal", None, None, None), - # (True, False, "block_causal", None, None, None), + (True, False, "block_causal", None, None, None), # Mask mod sliding window (True, False, "sliding_window", 128, None, None), (True, False, "sliding_window", 256, None, None), @@ -334,39 +172,46 @@ def mask_fn(b, h, q_idx, kv_idx): # (False, True, None, None, 512, 0), ], ) -@pytest.mark.parametrize("tile_m,tile_n", [(128, 128),]) +@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112)]) def test_mask_mod_output( - seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, - use_mask_mod, is_local, mask_name, window_size, window_left, window_right, - tile_m, tile_n + seqlen_q, + seqlen_k, + nheads, + kv_mode, + headdim, + dtype, + use_mask_mod, + is_local, + mask_name, + window_size, + window_left, + window_right, + tile_m, + tile_n, ): torch.manual_seed(42) # Validate configuration if is_local: assert not use_mask_mod, "Cannot use both is_local and use_mask_mod" - assert window_left is not None or window_right is not None, \ + assert window_left is not None or window_right is not None, ( "Must specify window_left or window_right for is_local" - + ) + if use_mask_mod and mask_name == "sliding_window": - assert window_size is not None, "window_size must be specified for sliding_window" - # Skip if seqlen_k is too small for the window - # if seqlen_k < window_size // 2: - # pytest.skip(f"seqlen_k={seqlen_k} too small for window_size={window_size}") - # Skip if seqlen_q > seqlen_k (problematic for sliding window) + assert window_size is not None, ( + "window_size must be specified for sliding_window" + ) if seqlen_q > seqlen_k: - pytest.skip(f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window") - + pytest.skip( + f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window" + ) + if is_local: - window_left_val = window_left if window_left is not None else 0 - window_right_val = window_right if window_right is not None else 0 - total_window = window_left_val + window_right_val + 1 - # Skip if seqlen_k is too small for the window - if seqlen_k < total_window // 2: - pytest.skip(f"seqlen_k={seqlen_k} too small for window={total_window}") - # Skip if seqlen_q > seqlen_k (problematic for local window) if seqlen_q > seqlen_k: - pytest.skip(f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for is_local") + pytest.skip( + f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for is_local" + ) # Determine nheads_kv based on mode if kv_mode == "mha": @@ -378,7 +223,7 @@ def test_mask_mod_output( else: raise ValueError(f"Unknown kv_mode: {kv_mode}") - batch_size = 2 + batch_size = 1 headdim_v = headdim # Determine mask_mod functions and causal flag @@ -389,7 +234,7 @@ def test_mask_mod_output( mask_mod_flex = create_flex_sliding_window_mask(window_size) else: mask_mod_cute, mask_mod_flex = MASK_FUNCTIONS[mask_name] - causal = (mask_name == "causal") + causal = False elif is_local: # Base local attention - no mask_mod mask_mod_cute = None @@ -399,7 +244,7 @@ def test_mask_mod_output( mask_mod_cute = None mask_mod_flex = None causal = (mask_name == "causal") if mask_name else False - + if causal and seqlen_k < seqlen_q: pytest.skip("causal masking requires seqlen_k >= seqlen_q") @@ -443,26 +288,61 @@ class Config: config=config, mask_mod_flex=mask_mod_flex, device="cuda" ) - # Run kernel - out_cute = compile_and_run_kernel( - tensors, - mask_mod_cute, + softmax_scale = 1.0 / math.sqrt(headdim) + + # if full_cnt is not None: + # print(f"Block sparsity info for {mask_name}:") + # print(f" full_cnt shape: {full_cnt.shape}") + # print(f" full_idx shape: {full_idx.shape}") + # print(f" mask_cnt shape: {mask_cnt.shape}") + # print(f" mask_idx shape: {mask_idx.shape}") + # print(f" full_cnt: {full_cnt}") + # print(f" full_idx: {full_idx}") + # print(f" mask_cnt: {mask_cnt}") + # print(f" mask_idx: {mask_idx}") + # if full_cnt[0,0,0] > 0: + # print(f" First Q block - full indices: {full_idx[0,0,0,:full_cnt[0,0,0].item()]}") + # if mask_cnt[0,0,0] > 0: + # print(f" First Q block - mask indices: {mask_idx[0,0,0,:mask_cnt[0,0,0].item()]}") + + out_tuple = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"], + lse=tensors["lse"], + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=softmax_scale, causal=causal, - is_local=is_local, - window_left=window_left, - window_right=window_right, - tile_m=tile_m, - tile_n=tile_n, + softcap=None, + window_size_left=window_left, + window_size_right=window_right, + learnable_sink=None, + m_block_size=tile_m, + n_block_size=tile_n, + num_threads=384, + pack_gqa=False, + _compute_capability=None, + score_mod=None, + mask_mod=mask_mod_cute, full_block_cnt=full_cnt, full_block_idx=full_idx, mask_block_cnt=mask_cnt, mask_block_idx=mask_idx, + return_lse=True, + aux_tensors=None, ) + out_cute = out_tuple[0] + # Determine which reference implementation to use dtype_ref = torch.bfloat16 use_flash_attn_ref = False - + # Use FlashAttention reference for causal and local window cases if mask_name == "causal" and not use_mask_mod: use_flash_attn_ref = True @@ -472,8 +352,6 @@ class Config: window_size_ref = (None, None) # No window for identity elif is_local: use_flash_attn_ref = True - # For is_local, we need to pass the window parameters - # When window_right=0, this is inherently causal window_size_ref = (window_left, window_right) if window_right == 0: causal = True # Override causal flag for reference computation @@ -484,19 +362,31 @@ class Config: # Sliding window with window_right=0 is inherently causal window_size_ref = (window_size, 0) causal = True # Override causal flag for reference computation - + if use_flash_attn_ref: # Compute reference using FlashAttention's attention_ref out_ref_fp32 = compute_reference_flash_attn( - tensors, causal=causal, window_size=window_size_ref, dtype_ref=torch.float32, upcast=True + tensors, + causal=causal, + window_size=window_size_ref, + dtype_ref=torch.float32, + upcast=True, ) out_ref = compute_reference_flash_attn( - tensors, causal=causal, window_size=window_size_ref, dtype_ref=dtype_ref, upcast=False + tensors, + causal=causal, + window_size=window_size_ref, + dtype_ref=dtype_ref, + upcast=False, ) - + # Also compute PyTorch reference for comparison (with reorder_ops for better accuracy) out_pt = compute_reference_flash_attn( - tensors, causal=causal, window_size=window_size_ref, dtype_ref=dtype, upcast=False + tensors, + causal=causal, + window_size=window_size_ref, + dtype_ref=dtype, + upcast=False, ) else: # Use flex_attention for custom mask_mods @@ -504,7 +394,7 @@ class Config: k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v for k, v in tensors.items() } - + out_ref_fp32 = compute_reference_flex_attn( tensors_fp32, mask_mod_flex, mask_name, tile_m, tile_n ) @@ -537,18 +427,20 @@ class Config: mask_desc += f"(w={window_size})" else: mask_desc = mask_name if mask_name else "identity" - + print( f"\n{mask_desc} @ Q={seqlen_q}, K={seqlen_k}, H={nheads}/{nheads_kv} ({kv_mode}), " f"D={headdim}, M={tile_m}, N={tile_n}" ) - print(f" Reference implementation: {'FlashAttention' if use_flash_attn_ref else 'FlexAttention'}") + print( + f" Reference implementation: {'FlashAttention' if use_flash_attn_ref else 'FlexAttention'}" + ) print(f" Reference vs FP32: {ref_error:.2e}") print(f" PyTorch vs FP32: {pt_error:.2e}") print(f" Kernel vs FP32: {cute_error:.2e}") print(f" Tolerance: rtol={rtol} * {pt_error:.2e} + {fwd_atol:.2e}") print(f" Error ratio: {cute_error / max(pt_error, 1e-10):.2f}") - + # Debug: show some sample values if error is large if cute_error > 1e-2: print(f" DEBUG: Sample kernel output: {out_cute[0, 0, 0, :5]}") @@ -567,4 +459,4 @@ class Config: if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) \ No newline at end of file + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index 0d8b2234467..147e5519394 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -9,14 +9,14 @@ @cute.jit -def score_mod_1(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_1(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = tSrS_ssa tSrS_ssa = tmp0 return tSrS_ssa @cute.jit -def score_mod_2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = q_idx tmp1 = kv_idx tmp2 = operator.ge(tmp0, tmp1) @@ -27,7 +27,7 @@ def score_mod_2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_3(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_3(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = tSrS_ssa tmp1 = q_idx tmp2 = kv_idx @@ -40,7 +40,7 @@ def score_mod_3(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_4(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_4(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = tSrS_ssa tmp1 = q_idx tmp2 = kv_idx @@ -54,7 +54,7 @@ def score_mod_4(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_5(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_5(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = tSrS_ssa tmp1 = tmp0 * cute.full_like(tmp0, 2) tSrS_ssa = tmp1 @@ -62,7 +62,7 @@ def score_mod_5(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_6(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_6(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = tSrS_ssa tmp1 = tmp0.to(cutlass.Float32) tmp2 = h_idx @@ -84,7 +84,7 @@ def score_mod_6(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_7(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_7(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = q_idx tmp1 = kv_idx tmp2 = tmp0 - tmp1 @@ -97,7 +97,7 @@ def score_mod_7(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_8(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_8(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = q_idx tmp1 = kv_idx tmp2 = tSrS_ssa @@ -109,7 +109,7 @@ def score_mod_8(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_9(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_9(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = q_idx tmp1 = kv_idx tmp2 = tmp0 - tmp1 @@ -121,8 +121,8 @@ def score_mod_9(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_10(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): - batch_bias = buffers[0] +def score_mod_10(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + batch_bias = aux_tensors[0] # Detect dtype from buffer element type dtype = batch_bias.element_type @@ -137,9 +137,9 @@ def score_mod_10(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_11(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): - head_bias = buffers[0] - pos_bias = buffers[1] +def score_mod_11(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + head_bias = aux_tensors[0] + pos_bias = aux_tensors[1] # Detect dtype from buffer element type dtype = head_bias.element_type @@ -232,8 +232,8 @@ def dual_buffer_mod(score, b, h, q_idx, kv_idx): (score_mod_9, causal_mask_v2_eager), ] -# Test pairs with buffers: (cute_jit_function, eager_reference_function_factory) -TEST_PAIRS_WITH_BUFFERS = [ +# Test pairs with aux_tensors: (cute_jit_function, eager_reference_function_factory) +TEST_PAIRS_WITH_AUX_TENSORS = [ (score_mod_10, batch_bias), (score_mod_11, dual_buffer_bias), ] @@ -248,7 +248,9 @@ def create_tensors( return q, k, v -def run_cute_flash(q, k, v, cute_score_mod, buffers=None, pack_gqa=False) -> torch.Tensor: +def run_cute_flash( + q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False +) -> torch.Tensor: q_transposed, k_transposed, v_transposed = map( lambda x: x.transpose(1, 2), (q, k, v) ) @@ -261,7 +263,7 @@ def run_cute_flash(q, k, v, cute_score_mod, buffers=None, pack_gqa=False) -> tor score_mod=cute_score_mod, out=out, lse=None, - buffers=buffers, + aux_tensors=aux_tensors, pack_gqa=pack_gqa, ) return out.transpose(1, 2) @@ -270,7 +272,9 @@ def run_cute_flash(q, k, v, cute_score_mod, buffers=None, pack_gqa=False) -> tor def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: if dtype is not None: q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) - return flex_attention(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1]) + return flex_attention( + q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] + ) @pytest.mark.parametrize( @@ -301,7 +305,9 @@ def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) -def test_cute_vs_flex_attention(seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair): +def test_cute_vs_flex_attention( + seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair +): torch.random.manual_seed(42) cute_score_mod, eager_score_mod = score_mod_pair @@ -375,8 +381,8 @@ def test_cute_vs_flex_attention(seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_he ) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_BUFFERS) -def test_cute_vs_flex_attention_with_buffers( +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS) +def test_cute_vs_flex_attention_with_aux_tensors( seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair ): torch.random.manual_seed(42) @@ -398,13 +404,13 @@ def test_cute_vs_flex_attention_with_buffers( if cute_score_mod == score_mod_10: buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 - buffers = [buffer] + aux_tensors = [buffer] eager_score_mod = eager_score_mod_factory(buffer) assert buffer.shape == (batch_size,) elif cute_score_mod == score_mod_11: head_bias = torch.randn(num_q_heads, device="cuda", dtype=dtype) * 0.2 pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 - buffers = [head_bias, pos_scale] + aux_tensors = [head_bias, pos_scale] eager_score_mod = eager_score_mod_factory(head_bias, pos_scale) assert head_bias.shape == (num_q_heads,) assert pos_scale.shape == (seqlen_q,) @@ -412,7 +418,9 @@ def test_cute_vs_flex_attention_with_buffers( out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) out_pt = run_flex_reference(q, k, v, eager_score_mod) - out_cute = run_cute_flash(q, k, v, cute_score_mod, buffers=buffers, pack_gqa=pack_gqa) + out_cute = run_cute_flash( + q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa + ) # Basic shape and NaN checks assert out_cute.shape == out_ref_fp32.shape == out_pt.shape @@ -443,7 +451,9 @@ def test_cute_vs_flex_attention_with_buffers( ) -@pytest.mark.xfail(raises=NotImplementedError, reason="Varlen with score_mod not yet supported") +@pytest.mark.xfail( + raises=NotImplementedError, reason="Varlen with score_mod not yet supported" +) def test_varlen_with_score_mod(): """Test that varlen (variable length sequences) works with score_mod. @@ -458,7 +468,11 @@ def test_varlen_with_score_mod(): num_heads = 4 dtype = torch.bfloat16 - cu_seqlens = torch.tensor([0] + list(torch.tensor(seqlens).cumsum(0).tolist()), device="cuda", dtype=torch.int32) + cu_seqlens = torch.tensor( + [0] + list(torch.tensor(seqlens).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) q = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) k = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) v = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype)