From aae8105bc5cbf0c88797587f1a345e7fc38b703f Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 2 Dec 2025 20:52:58 +0000 Subject: [PATCH] ruff all the smaller files --- .pre-commit-config.yaml | 9 -- flash_attn/cute/copy_utils.py | 6 +- flash_attn/cute/flash_fwd_combine.py | 154 +++++++++++++++++++-------- flash_attn/cute/hopper_helpers.py | 1 - flash_attn/cute/pack_gqa.py | 2 - flash_attn/cute/testing.py | 20 +++- flash_attn/cute/utils.py | 91 ++++++++++++---- 7 files changed, 193 insertions(+), 90 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 67dcf8ba868..6118dfa2283 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,19 +7,10 @@ repos: files: ^flash_attn/cute/.*\.py$ exclude: &cute_exclude | (?x)^flash_attn/cute/( - __init__| - copy_utils| - cute_dsl_utils| - fast_math| flash_bwd| flash_fwd| - flash_fwd_combine| flash_fwd_sm100| - hopper_helpers| interface| - pack_gqa| - testing| - utils )\.py$ - id: ruff-format files: ^flash_attn/cute/.*\.py$ diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index 45ec493aaa3..cfdcbdb80a0 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -1,11 +1,11 @@ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. import math -from typing import Optional, Type, Tuple, Callable +from typing import Optional, Type, Callable import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, Boolean, const_expr +from cutlass import Float32, Int32, const_expr from cutlass.cute.nvgpu import cpasync import cutlass.utils.blackwell_helpers as sm100_utils from cutlass.cutlass_dsl import T, dsl_user_op @@ -279,7 +279,7 @@ def copy_bulk(src_idx, dst_idx, **new_kwargs): dst[None, dst_idx].iterator, size=size, **new_kwargs, - **kwargs + **kwargs, ) def copy_bulk_single_stage(**new_kwargs): diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index 02672e319de..f97e127175d 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -55,8 +55,13 @@ def __init__( @staticmethod def can_implement( - dtype, dtype_partial, head_dim, m_block_size, k_block_size, - log_max_splits, num_threads, + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, + log_max_splits, + num_threads, ) -> bool: """Check if the kernel can be implemented with the given parameters.""" if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: @@ -83,8 +88,7 @@ def _setup_attributes(self): assert self.k_block_size % async_copy_elems == 0 k_block_gmem = ( - 128 if self.k_block_size % 128 == 0 else - (64 if self.k_block_size % 64 == 0 else 32) + 128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32) ) gmem_threads_per_row = k_block_gmem // async_copy_elems assert self.num_threads % gmem_threads_per_row == 0 @@ -111,16 +115,25 @@ def _setup_attributes(self): num_bits_per_copy=async_copy_elems * self.dtype.width, ) self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( - atom_universal_copy, tOpartial_layout, vOpartial_layout # 4 vals per store + atom_universal_copy, + tOpartial_layout, + vOpartial_layout, # 4 vals per store ) # LSE copy setup with async copy (alignment = 1) lse_copy_bits = Float32.width # 1 element per copy, width is in bits m_block_smem = ( - 128 if self.m_block_size % 128 == 0 else - (64 if self.m_block_size % 64 == 0 else - (32 if self.m_block_size % 32 == 0 else - (16 if self.m_block_size % 16 == 0 else 8))) + 128 + if self.m_block_size % 128 == 0 + else ( + 64 + if self.m_block_size % 64 == 0 + else ( + 32 + if self.m_block_size % 32 == 0 + else (16 if self.m_block_size % 16 == 0 else 8) + ) + ) ) gmem_threads_per_row_lse = m_block_smem assert self.num_threads % gmem_threads_per_row_lse == 0 @@ -167,9 +180,7 @@ def _setup_attributes(self): else: smem_lse_swizzle = cute.make_swizzle(3, 2, 3) smem_layout_atom_lse = cute.make_composed_layout( - smem_lse_swizzle, - 0, - cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) + smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) ) self.smem_layout_lse = cute.tile_to_shape( smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1) @@ -177,11 +188,9 @@ def _setup_attributes(self): # O partial shared memory layout (simple layout for pipeline stages) self.smem_layout_o = cute.make_ordered_layout( - (self.m_block_size, self.k_block_size, self.stages), - order=(1, 0, 2) + (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2) ) - @cute.jit def __call__( self, @@ -200,38 +209,63 @@ def __call__( raise TypeError("O partial tensor must match dtype_partial") if const_expr(not (mO.element_type == self.dtype)): raise TypeError("O tensor must match dtype") - if const_expr(not mLSE_partial.element_type in [Float32]): + if const_expr(mLSE_partial.element_type not in [Float32]): raise TypeError("LSE partial tensor must be Float32") - if const_expr(mLSE is not None and not mLSE.element_type in [Float32]): + if const_expr(mLSE is not None and mLSE.element_type not in [Float32]): raise TypeError("LSE tensor must be Float32") # Shape validation - input tensors are in user format, need to be converted to kernel format if const_expr(len(mO_partial.shape) not in [4, 5]): - raise ValueError("O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)") + raise ValueError( + "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)" + ) if const_expr(len(mLSE_partial.shape) not in [3, 4]): - raise ValueError("LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)") + raise ValueError( + "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) if const_expr(len(mO.shape) not in [3, 4]): - raise ValueError("O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)") + raise ValueError( + "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)" + ) if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): - raise ValueError("LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)") + raise ValueError( + "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)" + ) # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) - mO_partial, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mO_partial, mO)] + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mO_partial, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mO_partial, mO) + ] # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) - O_partial_layout_transpose = [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + O_partial_layout_transpose = ( + [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + ) # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) - mO_partial = cute.make_tensor(mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)) + mO_partial = cute.make_tensor( + mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose) + ) O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b) # or (num_splits, total_q, h) -> (total_q, num_splits, h) LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] - mLSE_partial = cute.make_tensor(mLSE_partial.iterator, cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose)) + mLSE_partial = cute.make_tensor( + mLSE_partial.iterator, + cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose), + ) # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if mLSE is not None + else None + ) # Determine if we have variable length sequences varlen = const_expr(cu_seqlens is not None or seqused is not None) @@ -243,9 +277,7 @@ class SharedStorage: sLSE: cute.struct.Align[ cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 ] - sMaxValidSplit: cute.struct.Align[ - cute.struct.MemRange[Int32, self.m_block_size], 128 - ] + sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128] sO: cute.struct.Align[ cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 ] @@ -255,7 +287,11 @@ class SharedStorage: # Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch) seqlen = mO_partial.shape[0] num_head = mO_partial.shape[3] - batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1) + batch_size = ( + mO_partial.shape[4] + if const_expr(cu_seqlens is None) + else Int32(cu_seqlens.shape[0] - 1) + ) # Create FastDivmodDivisor objects for efficient division seqlen_divmod = FastDivmodDivisor(seqlen) @@ -330,14 +366,18 @@ def kernel( # Handle semaphore reset if const_expr(semaphore_to_reset is not None): - if (tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and - k_block == cute.arch.grid_dim()[1] - 1 and - batch_idx == cute.arch.grid_dim()[2] - 1): + if ( + tidx == 0 + and m_block == cute.arch.grid_dim()[0] - 1 + and k_block == cute.arch.grid_dim()[1] - 1 + and batch_idx == cute.arch.grid_dim()[2] - 1 + ): semaphore_to_reset[0] = 0 # Get number of splits num_splits = ( - num_splits_dynamic_ptr[batch_idx] if const_expr(num_splits_dynamic_ptr is not None) + num_splits_dynamic_ptr[batch_idx] + if const_expr(num_splits_dynamic_ptr is not None) else mLSE_partial.shape[1] ) # Handle variable length sequences using SeqlenInfo @@ -345,7 +385,7 @@ def kernel( batch_idx=batch_idx, seqlen_static=mO_partial.shape[0], cu_seqlens=cu_seqlens, - seqused=seqused + seqused=seqused, ) seqlen, offset = seqlen_info.seqlen, seqlen_info.offset @@ -354,8 +394,9 @@ def kernel( max_idx = seqlen * num_head # Early exit for single split if dynamic - if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (const_expr(not varlen) or m_block * self.m_block_size < max_idx): - + if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and ( + const_expr(not varlen) or m_block * self.m_block_size < max_idx + ): # =============================== # Step 1: Load LSE_partial from gmem to shared memory # =============================== @@ -390,7 +431,11 @@ def kernel( for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): si = tLSEcLSE[0, s, 0][0] # Get split coordinate if si < num_splits: - cute.copy(gmem_thr_copy_LSE, mLSE_partial_cur_copy[None, si], tLSEsLSE[None, s, m]) + cute.copy( + gmem_thr_copy_LSE, + mLSE_partial_cur_copy[None, si], + tLSEsLSE[None, s, m], + ) else: tLSEsLSE[None, s, m].fill(-Float32.inf) # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem @@ -424,7 +469,9 @@ def kernel( else: tOhidx[m] = idx // seqlen tOmidx[m] = idx - tOhidx[m] * seqlen - tOrOptr[m] = utils.elem_pointer_i64(mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])).toint() + tOrOptr[m] = utils.elem_pointer_i64( + mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]) + ).toint() if idx >= max_idx: tOhidx[m] = -1 @@ -483,7 +530,9 @@ def kernel( # Find max LSE value across splits threads_per_col = const_expr(self.smem_threads_per_col_lse) lse_max = utils.warp_reduce( - ts2rrLSE[None, None, m].load().reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), op=cute.arch.fmax, width=threads_per_col, ) @@ -496,7 +545,9 @@ def kernel( # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col) # Compute exp scales and sum - lse_max_cur = 0.0 if lse_max == -Float32.inf else lse_max # In case all local LSEs are -inf + lse_max_cur = ( + 0.0 if lse_max == -Float32.inf else lse_max + ) # In case all local LSEs are -inf LOG2_E = math.log2(math.e) lse_sum_cur = 0.0 for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): @@ -506,7 +557,9 @@ def kernel( lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col) lse_sum[m] = utils.logf(lse_sum_cur) + lse_max # Normalize scales - inv_sum = 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur + inv_sum = ( + 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur + ) ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) # Store the scales exp(lse - lse_logsum) back to smem cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) @@ -584,7 +637,10 @@ def kernel( # Accumulate scaled partial results for m in cutlass.range(num_rows, unroll_full=True): if tOhidx[m] >= 0 and scale[m] > 0.0: - tOrO[None, m, None].store(tOrO[None, m, None].load() + scale[m] * tOrO_partial[None, m, None].load().to(Float32)) + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) # =============================== # Step 7: Write final O to gmem @@ -605,7 +661,9 @@ def kernel( # Write final results for m in cutlass.range(num_rows, unroll_full=True): if tOhidx[m] >= 0: - mO_cur_copy = cute.tiled_divide(mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)) + mO_cur_copy = cute.tiled_divide( + mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,) + ) for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): k_idx = tOcO[0, 0, k][1] // elems_per_store if const_expr(self.is_even_k) or tOpO[k]: @@ -631,7 +689,9 @@ def load_O_partial( o_gmem_ptr = cute.make_ptr( tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 ) - mO_partial_cur = cute.make_tensor(o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))) + mO_partial_cur = cute.make_tensor( + o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)) + ) mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): k_idx = tOcO[0, 0, k][1] // elems_per_load @@ -640,5 +700,5 @@ def load_O_partial( gmem_tiled_copy_O_partial, # mO_partial_cur_copy[None, k_idx, split], utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx], - tOsO_partial_cur[None, m, k] + tOsO_partial_cur[None, m, k], ) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index c98f85b568e..c6a1c301904 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -4,7 +4,6 @@ import cutlass.cute as cute from cutlass import Int32, Float32, Boolean, const_expr from cutlass.cute.nvgpu import warpgroup -from cutlass._mlir.dialects import llvm from cutlass.cutlass_dsl import Numeric, dsl_user_op from cutlass.utils import LayoutEnum import cutlass.utils.hopper_helpers as sm90_utils_og diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py index 46d8dd38798..765e71307ad 100644 --- a/flash_attn/cute/pack_gqa.py +++ b/flash_attn/cute/pack_gqa.py @@ -1,7 +1,5 @@ # Copyright (c) 2025, Tri Dao. -import math -import operator import cutlass import cutlass.cute as cute diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py index 690d0145479..214ed09bc9e 100644 --- a/flash_attn/cute/testing.py +++ b/flash_attn/cute/testing.py @@ -99,7 +99,9 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", if i % 5 == 0: lengths[i] = 0 lengths[-1] = 0 - padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + ) return padding_mask @@ -129,7 +131,9 @@ def generate_qkv( q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( q, query_padding_mask, query_unused_mask ) - output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None else: q_unpad = rearrange(q, "b s h d -> (b s) h d") @@ -138,7 +142,9 @@ def generate_qkv( ) seqused_q = None max_seqlen_q = seqlen_q - output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None if key_padding_mask is not None: @@ -256,7 +262,9 @@ def construct_local_mask( sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk return torch.logical_or( col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), - torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length), + torch.logical_and( + col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length + ), ) @@ -368,7 +376,9 @@ def attention_ref( key_leftpad=key_leftpad, device=q.device, ) - local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + local_mask = ( + torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + ) if local_mask is not None: scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index eb8b86cbe0b..f73f66cfccf 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -10,7 +10,7 @@ import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, const_expr +from cutlass import Float32, const_expr from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import nvvm, llvm from cutlass.cute.runtime import from_dlpack @@ -24,9 +24,10 @@ cute.arch.calc_packed_f32x2_op, src_c=None, calc_func=nvvm.sub_packed_f32x2, - rnd=nvvm.RoundingModeKind.RN + rnd=nvvm.RoundingModeKind.RN, ) + def hash_callable(func: Callable) -> str: """Hash a callable based on the source code or bytecode and closure values.""" if hasattr(func, "__wrapped__"): @@ -62,6 +63,7 @@ def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, buffers): return scoremod_premask_fn + def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: return ( from_dlpack(x, assumed_align=alignment) @@ -71,7 +73,10 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te ) ) -def convert_from_dlpack_leading_static(x, leading_dim, alignment=16, static_modes=None, stride_order=None) -> cute.Tensor: + +def convert_from_dlpack_leading_static( + x, leading_dim, alignment=16, static_modes=None, stride_order=None +) -> cute.Tensor: if stride_order is None: stride_order = x.dim_order() x_ = from_dlpack(x, assumed_align=alignment) @@ -80,6 +85,7 @@ def convert_from_dlpack_leading_static(x, leading_dim, alignment=16, static_mode x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order) return x_ + def make_tiled_copy_A( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.TiledCopy: @@ -258,7 +264,7 @@ def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle: # the string here. swizzle_str = str(ptr.type.swizzle_type) # Extract the inner part "S" - match = re.search(r'S<(\d+),(\d+),(\d+)>', swizzle_str) + match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) if match: b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) return cute.make_swizzle(b, m, s) @@ -298,6 +304,7 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32: ) ) + @dsl_user_op def logf(a: float | Float32, *, loc=None, ip=None) -> Float32: return log2f(a, loc=loc, ip=ip) * math.log(2.0) @@ -350,7 +357,11 @@ def fmax_reduce( # We instead force the 3-input max. res = cute.make_fragment(x.shape, Float32) res.store(x) - local_max_0 = fmax(init_val, res[0], res[1]) if const_expr(init_val is not None) else fmax(res[0], res[1]) + local_max_0 = ( + fmax(init_val, res[0], res[1]) + if const_expr(init_val is not None) + else fmax(res[0], res[1]) + ) local_max = [ local_max_0, fmax(res[2], res[3]), @@ -438,7 +449,9 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) flat_stride = cute.flatten_to_tuple(x.stride) - assert len(flat_coord_i64) == len(flat_stride), "Coordinate and stride must have the same length" + assert len(flat_coord_i64) == len(flat_stride), ( + "Coordinate and stride must have the same length" + ) offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) # HACK: we assume that applying the offset does not change the pointer alignment byte_offset = offset * x.element_type.width // 8 @@ -517,7 +530,10 @@ def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> return cutlass.Uint32( llvm.inline_asm( T.i32(), - [cutlass.Uint32(val).ir_value(loc=loc, ip=ip), cutlass.Uint32(shift).ir_value(loc=loc, ip=ip)], + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], "shr.s32 $0, $1, $2;", "=r,r,r", has_side_effects=False, @@ -543,7 +559,9 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> @dsl_user_op -def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None) -> cutlass.Int32: +def cvt_f16x2_f32( + a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None +) -> cutlass.Int32: assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16" return cutlass.Int32( llvm.inline_asm( @@ -561,9 +579,11 @@ def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc @overload def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ... + @overload def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ... + @cute.jit def cvt_f16(src: cute.Tensor, dst_or_dtype): """Convert Float32 tensor to Float16/BFloat16. @@ -586,7 +606,9 @@ def cvt_f16(src: cute.Tensor, dst_or_dtype): dst = dst_or_dtype assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" - assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], "dst must be BFloat16 or Float16" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], ( + "dst must be BFloat16 or Float16" + ) assert src.element_type is Float32, "src must be Float32" dst_i32 = cute.recast_tensor(dst, cutlass.Int32) assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) @@ -606,7 +628,9 @@ def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=N @dsl_user_op @cute.jit -def evaluate_polynomial_2(x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Tuple[Float32, Float32]: +def evaluate_polynomial_2( + x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: deg = len(poly) - 1 out = (poly[deg], poly[deg]) for i in cutlass.range_constexpr(deg - 1, -1, -1): @@ -621,7 +645,7 @@ def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) llvm.inline_asm( T.f32(), [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)], - f"add.rm.ftz.f32 $0, $1, $2;", + "add.rm.ftz.f32 $0, $1, $2;", "=f,f,f", has_side_effects=False, is_align_stack=False, @@ -635,7 +659,10 @@ def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip= return cutlass.Float32( llvm.inline_asm( T.f32(), - [Float32(x_rounded).ir_value(loc=loc, ip=ip), Float32(frac_ex2).ir_value(loc=loc, ip=ip)], + [ + Float32(x_rounded).ir_value(loc=loc, ip=ip), + Float32(frac_ex2).ir_value(loc=loc, ip=ip), + ], "{\n\t" ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t" "mov.b32 x_rounded_i, $1;\n\t" @@ -657,7 +684,12 @@ def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip= @dsl_user_op def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32: # We assume x <= 127.0 - poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625) + poly_ex2_deg3 = ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ) fp32_round_int = float(2**23 + 2**22) x_clamped = cute.arch.fmax(x, -127.0) # We want to round down here, so that the fractional part is in [0, 1) @@ -674,11 +706,18 @@ def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32: @dsl_user_op def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: # We assume x <= 127.0 and y <= 127.0 - poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625) + poly_ex2_deg3 = ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ) fp32_round_int = float(2**23 + 2**22) xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) # We want to round down here, so that the fractional part is in [0, 1) - xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM) + xy_rounded = cute.arch.add_packed_f32x2( + xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM + ) # The integer floor of x & y are now in the last 8 bits of xy_rounded # We want the next 2 ops to round to nearest even. The rounding mode is important. xy_rounded_back = sub_packed_f32x2(xy_rounded, (fp32_round_int, fp32_round_int)) @@ -734,8 +773,12 @@ def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Flo out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) return out0, out1 + + @dsl_user_op -def domain_offset_aligned(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: +def domain_offset_aligned( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: assert isinstance(tensor.iterator, cute.Pointer) # We assume that applying the offset does not change the pointer alignment new_ptr = cute.make_ptr( @@ -751,9 +794,9 @@ def domain_offset_aligned(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, i def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) flat_stride = cute.flatten_to_tuple(tensor.stride) - assert len(flat_coord_i64) == len( - flat_stride - ), "Coordinate and stride must have the same length" + assert len(flat_coord_i64) == len(flat_stride), ( + "Coordinate and stride must have the same length" + ) offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) assert isinstance(tensor.iterator, cute.Pointer) # HACK: we assume that applying the offset does not change the pointer alignment @@ -779,18 +822,20 @@ def coord_offset_i64( tensor.memspace, assumed_align=tensor.iterator.max_alignment, ) - new_layout = cute.slice_(tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1))) + new_layout = cute.slice_( + tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1)) + ) return cute.make_tensor(new_ptr, new_layout) @cute.jit def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: - """ Convert a scalar to a cute TensorSSA of shape (1,) and given dtype """ + """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype""" vec = cute.make_fragment(1, dtype) vec[0] = a return vec.load() def ssa_to_scalar(val): - """ Could inline but nice for reflecting the above api """ - return val[0] \ No newline at end of file + """Could inline but nice for reflecting the above api""" + return val[0]