diff --git a/aiter/ops/flydsl/__init__.py b/aiter/ops/flydsl/__init__.py index 6da6ecdd68..b73cbf5481 100644 --- a/aiter/ops/flydsl/__init__.py +++ b/aiter/ops/flydsl/__init__.py @@ -37,6 +37,7 @@ from .gemm_kernels import flydsl_hgemm, flydsl_preshuffle_gemm_a8 from .moe_kernels import flydsl_moe_stage1, flydsl_moe_stage2 + from .fmha_kernels import flydsl_flash_attn_func # from .linear_attention_kernels import flydsl_gdr_decode @@ -45,5 +46,6 @@ "flydsl_moe_stage1", "flydsl_moe_stage2", "flydsl_hgemm", + "flydsl_flash_attn_func", # "flydsl_gdr_decode", ] diff --git a/aiter/ops/flydsl/fmha_kernels.py b/aiter/ops/flydsl/fmha_kernels.py new file mode 100644 index 0000000000..9b1f1ba527 --- /dev/null +++ b/aiter/ops/flydsl/fmha_kernels.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""High-level FlyDSL Flash Attention APIs (gfx1201 / RDNA4). + +Wraps the FlyDSL `flash_attn_func_gfx1201` kernel with: + - Build cache keyed by (num_heads, head_dim, causal, dtype, waves_per_eu, daz). + - Automatic seq_len padding to the kernel's tile size (multiple of 128). + - BSHD ([B, S, H, D]) input/output convention to match upstream + flash-attention layout. + - Non-causal padding-ratio safety guard: padded K/V tokens contribute to + the softmax denominator and would scale outputs. Calls with + ``n_pad / seq_len_pad > 0.005`` (0.5%) and ``causal=False`` are rejected + with a ``ValueError``. The 0.5% threshold is the bf16 mantissa precision + floor plus 1 bit of margin; production Wan2.1 (S_real=32760, S_pad=32768, + ratio=0.024%) clears it by 20x. See option (d) in + ``2969_padded_softmax_rca.md``. + +The kernel implements self-attention only (Lq == Lk). Cross-attention +(Lq != Lk) is rejected; callers should fall back to PyTorch SDPA. +""" + +from __future__ import annotations + +from functools import lru_cache + +import torch +import torch.nn.functional as F + +from .kernels.flash_attn_func_gfx1201 import build_flash_attn_func_module + +__all__ = [ + "flydsl_flash_attn_func", +] + + +# Tile size baked into the gfx1201 kernel. Seq_len must be a multiple of this. +# Picked to match BLOCK_M=128 in the kernel; padding is invisible to callers. +_KERNEL_BLOCK_M = 128 + +# Maximum tolerated ratio of padded tokens for non-causal attention. +# Padded K/V keys produce QK^T = 0, but exp(0) = 1 leaks into the softmax +# denominator and silently scales the output. 0.5% is the bf16 mantissa +# precision floor (~0.4%) plus 1 bit of margin. Above this the relative +# error grows quickly (50% pad -> 37% rel_err per RCA in +# 2969_padded_softmax_rca.md). Causal mode masks future tokens including +# the padded ones, so it is unaffected. +_MAX_NONCAUSAL_PAD_RATIO = 0.005 + + +def _torch_dtype_to_str(dtype: torch.dtype) -> str: + if dtype == torch.bfloat16: + return "bf16" + if dtype == torch.float16: + return "f16" + raise ValueError(f"flydsl_flash_attn_func only supports bf16/f16, got {dtype!r}") + + +@lru_cache(maxsize=32) +def _get_kernel( + num_heads: int, + head_dim: int, + causal: bool, + dtype_str: str, + waves_per_eu: int, + daz: bool, +): + return build_flash_attn_func_module( + num_heads=num_heads, + head_dim=head_dim, + causal=causal, + dtype_str=dtype_str, + waves_per_eu=waves_per_eu, + daz=daz, + ) + + +def flydsl_flash_attn_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + causal: bool = False, + waves_per_eu: int = 2, + daz: bool = True, + stream: torch.cuda.Stream | None = None, +) -> torch.Tensor: + """Run FlyDSL Flash Attention on RDNA4 (gfx1201). + + Args: + q, k, v: tensors with shape ``[batch, seq_len, num_heads, head_dim]`` + (BSHD). All three must share dtype, batch, num_heads, head_dim, + and seq_len. Must reside on a CUDA/HIP device. + causal: apply causal masking when ``True``. + waves_per_eu: kernel occupancy hint passed to the FlyDSL builder. + daz: enable denormals-are-zero on the kernel. + stream: optional CUDA/HIP stream to launch on. Defaults to the current + stream for ``q.device``. + + Returns: + Output tensor with the same shape and dtype as ``q``. + + Raises: + ValueError: if shapes/dtypes/devices are incompatible, the kernel's + ``head_dim`` constraints are not met, or the non-causal padding + ratio ``n_pad / seq_len_pad`` exceeds 0.5% (see module docstring + for rationale). + """ + if not (q.is_cuda and k.is_cuda and v.is_cuda): + raise ValueError("flydsl_flash_attn_func requires CUDA/HIP tensors") + if not (q.device == k.device == v.device): + raise ValueError( + "q/k/v must reside on the same device, got " + f"q={q.device} k={k.device} v={v.device}" + ) + try: + arch = torch.cuda.get_device_properties(q.device.index).gcnArchName + except Exception: + arch = "" + arch_base = arch.lower().split(":")[0] if arch else "" + if not arch_base.startswith("gfx1201"): + raise ValueError(f"flydsl_flash_attn_func requires gfx1201, got {arch!r}") + if not (q.shape == k.shape == v.shape): + raise ValueError( + "flydsl_flash_attn_func is self-attention; q/k/v must share " + f"shape, got q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}" + ) + if not (q.dtype == k.dtype == v.dtype): + raise ValueError(f"q/k/v dtype must match: {q.dtype}/{k.dtype}/{v.dtype}") + if q.dim() != 4: + raise ValueError( + f"expected 4D BSHD tensor, got rank {q.dim()} ({tuple(q.shape)})" + ) + + batch, seq_len_real, num_heads, head_dim = q.shape + if head_dim < 64 or head_dim % 32 != 0: + raise ValueError( + f"kernel requires head_dim >= 64 and head_dim % 32 == 0, got {head_dim}" + ) + + dtype_str = _torch_dtype_to_str(q.dtype) + + # Pad seq_len up to the kernel's tile size. Tight padding (<= 0.5% of + # S_pad) is empirically below the bf16 noise floor on production shapes + # (Wan2.1 cos_sim >= 0.999992). Higher ratios are rejected upstream: + # padded K/V tokens produce QK^T = 0 but exp(0) = 1 still contributes + # to the softmax denominator and would scale the output. Padded queries + # produce garbage rows that we slice off before returning. + seq_len_pad = ( + (seq_len_real + _KERNEL_BLOCK_M - 1) // _KERNEL_BLOCK_M + ) * _KERNEL_BLOCK_M + n_pad = seq_len_pad - seq_len_real + if not causal and n_pad > 0 and n_pad / seq_len_pad > _MAX_NONCAUSAL_PAD_RATIO: + raise ValueError( + "flydsl_flash_attn_func: non-causal path with padding ratio " + f"{n_pad}/{seq_len_pad}={n_pad / seq_len_pad:.4f} exceeds 0.5% " + "safety threshold; padded K/V tokens contribute to softmax " + "denominator and would scale outputs. Either set causal=True, " + "pad seq_len to a multiple of 128 before calling, or use a " + "self-attn kernel with explicit attention masking." + ) + if seq_len_pad != seq_len_real: + pad = n_pad + # F.pad pads from the last dim; for BSHD (last=head_dim) the seq dim + # is dim 1, so we pad (D_left, D_right, H_left, H_right, S_left, S_right). + q_p = F.pad(q.contiguous(), (0, 0, 0, 0, 0, pad)) + k_p = F.pad(k.contiguous(), (0, 0, 0, 0, 0, pad)) + v_p = F.pad(v.contiguous(), (0, 0, 0, 0, 0, pad)) + else: + q_p = q.contiguous() + k_p = k.contiguous() + v_p = v.contiguous() + + o_p = torch.empty_like(q_p) + + # Wrap kernel build + launch in q.device context so multi-GPU callers + # whose current device differs from q.device get the kernel compiled + # and launched on the right device/stream. + with torch.cuda.device(q.device.index): + launch_stream = ( + torch.cuda.current_stream(q.device) if stream is None else stream + ) + if launch_stream.device != q.device: + raise ValueError( + f"`stream` must be on {q.device}, got {launch_stream.device}" + ) + exe = _get_kernel( + num_heads=num_heads, + head_dim=head_dim, + causal=causal, + dtype_str=dtype_str, + waves_per_eu=waves_per_eu, + daz=daz, + ) + exe( + q_p.reshape(-1), + k_p.reshape(-1), + v_p.reshape(-1), + o_p.reshape(-1), + batch, + seq_len_pad, + stream=launch_stream, + ) + + if seq_len_pad != seq_len_real: + return o_p[:, :seq_len_real, :, :].contiguous() + return o_p diff --git a/aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py b/aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py new file mode 100644 index 0000000000..3d0bb48dee --- /dev/null +++ b/aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py @@ -0,0 +1,780 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Combined Flash Attention kernel for gfx1201 with optimizations: + +1. BLOCK_N=32 (reduced tile, fewer iterations, better occupancy; 121->100ms) +2. rocdl.exp2 (native ISA exp2 intrinsic, bypasses arith lowering) +3. Software-pipelined GEMM2: preload next V pack while current WMMA executes, + hiding LDS read latency behind matrix compute (100->96ms). +4. Overlapped V global load: pre-issue next iteration's V global loads at end + of current iteration, so V data is in flight during loop back-edge, barrier, + and K cooperative load of the next iteration (96->91ms). + +Note: V interleaved storage (ds_read_b32) was tested but the element-wise +scatter store overhead negates read savings at BN=32. Row-major V with +software-pipelined scalar reads is faster. + +Note: V pre-transpose (scatter store to col-major LDS, vec8 GEMM2 read) was +tested but the 16 scalar stores per thread during coop_store_v add +8.8% +regression vs baseline (102.7ms vs 94.3ms). + +WMMA 16x16x16 register layout (wave32): + - A/B operand: v8bf16 per lane (lane16 = row/col, klane*8 = K-offset) + - C/D result: v8f32 per lane, element si = C[klane*8+si][lane16] + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) +Block: (256,) -- 8 waves x 32 threads/wave. + +Requires: head_dim % 32 == 0, head_dim >= 64. +""" + +import math as host_math +import os + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.expr import ( + arith, + buffer_ops, + const_expr, + gpu, + range_constexpr, + rocdl, +) +from flydsl.expr import math as fmath +from flydsl.expr.typing import T, Vector as Vec +from flydsl.expr.utils.arith import ArithValue, _to_raw as _raw +from .kernels_common import dtype_to_elem_type +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr +from flydsl._mlir import ir +from flydsl._mlir.dialects import ( + fly as _fly, + llvm as _llvm, + memref as _memref, +) + +KERNEL_NAME = "flash_attn_func_gfx1201_c_exp_a_k_noswizzle_kernel" +_LOG2E = host_math.log2(host_math.e) + + +def _llvm_value(value): + """Unwrap FlyDSL scalar/vector wrappers for LLVM pointer load ops.""" + if hasattr(value, "ir_value") and not isinstance(value, ir.Value): + return value.ir_value() + return value + + +def _llvm_ptr_ty(): + return ir.Type.parse("!llvm.ptr") + + +def _extract_aligned_pointer(tensor) -> ir.Value: + """Extract the aligned LLVM pointer from a FlyDSL tensor/memref.""" + return _fly.extract_aligned_pointer_as_index(_llvm_ptr_ty(), _llvm_value(tensor)) + + +def _pointer_load(result_type: ir.Type, ptr: ir.Value) -> ir.Value: + return _llvm.LoadOp(result_type, _llvm_value(ptr)).result + + +def _pointer_store(value: ir.Value, ptr: ir.Value): + return _llvm.StoreOp(_llvm_value(value), _llvm_value(ptr)) + + +def build_flash_attn_func_module_primary( + num_heads, + head_dim, + causal=True, + dtype_str="bf16", + sm_scale=None, + waves_per_eu=2, + flat_work_group_size=None, + block_m=None, + block_n=None, + unsafe_fp_math=True, + fast_fp_math=True, + daz=True, + path_tag="auto", +): + """Build gfx1201 flash_attn_func (BN=32 + rocdl.exp2 + pipelined GEMM2 + overlapped V load).""" + gpu_arch = get_hip_arch() + + # ---- WMMA / wave32 constants ---- + WARP_SIZE = 32 + WMMA_M = 16 + WMMA_N = 16 + WMMA_K = 16 + K_SUB_N = 32 + ROWS_PER_WAVE = WMMA_M + + BLOCK_M = block_m if block_m is not None else 128 + BLOCK_N = block_n if block_n is not None else 32 + + assert ( + BLOCK_N % K_SUB_N == 0 + ), f"BLOCK_N ({BLOCK_N}) must be a multiple of K_SUB_N ({K_SUB_N})" + assert ( + BLOCK_M % ROWS_PER_WAVE == 0 + ), f"BLOCK_M ({BLOCK_M}) must be a multiple of {ROWS_PER_WAVE}" + + N_SUB_TILES = BLOCK_N // K_SUB_N + NUM_S_ACCS = N_SUB_TILES * 2 + NUM_S_VALS = NUM_S_ACCS * 8 + + NUM_WAVES = BLOCK_M // ROWS_PER_WAVE + if flat_work_group_size is None: + flat_work_group_size = NUM_WAVES * WARP_SIZE + BLOCK_SIZE = flat_work_group_size + + PATH_TAG = f"M{BLOCK_M}N{BLOCK_N}_combined" + BLOCK_N_OUT = BLOCK_N + + NUM_PREFETCH_K = 1 + NUM_PREFETCH_V = 1 + + K_STEP_QK = WMMA_K + K_STEPS_QK = head_dim // K_STEP_QK + WMMA_LANE_K = 8 + + D_CHUNK = WMMA_N + D_CHUNKS = head_dim // D_CHUNK + + PV_K_STEP = WMMA_K + PV_K_STEPS = K_SUB_N // PV_K_STEP + + assert BLOCK_M % NUM_WAVES == 0 + assert head_dim % 32 == 0 + assert head_dim >= 64 + assert dtype_str in ("f16", "bf16") + + if sm_scale is None: + sm_scale = 1.0 / host_math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # LDS layout -- K uses padding instead of XOR swizzle; V row-major with padding + K_STRIDE = HEAD_DIM + 4 # padding to reduce bank conflicts (no swizzle) + V_STRIDE = HEAD_DIM + 4 # padding to reduce bank conflicts + + ENABLE_LDS_VEC16 = os.getenv("FLYDSL_FLASH_ATTN_FUNC_ENABLE_LDS_VEC16", "1") == "1" + VEC_WIDTH = 16 if ENABLE_LDS_VEC16 else 8 + THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH + ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD + + if ROWS_PER_BATCH_LOAD >= BLOCK_N: + NUM_BATCHES_KV = 1 + KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N + else: + NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD + KV_NEEDS_GUARD = False + + LDS_K_TILE_SIZE = BLOCK_N * K_STRIDE + LDS_V_TILE_SIZE = BLOCK_N * V_STRIDE + LDS_K_TOTAL_SIZE = NUM_PREFETCH_K * LDS_K_TILE_SIZE + LDS_V_BASE = LDS_K_TOTAL_SIZE + LDS_V_TOTAL_SIZE = NUM_PREFETCH_V * LDS_V_TILE_SIZE + LDS_KV_TOTAL_SIZE = LDS_K_TOTAL_SIZE + LDS_V_TOTAL_SIZE + + allocator = SmemAllocator( + None, + arch=gpu_arch, + global_sym_name=f"flash_attn_func_gfx1201c_exp_a_smem_{PATH_TAG}", + ) + lds_kv_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_kv_offset + LDS_KV_TOTAL_SIZE * 2 + + # Map dtype string to a FlyDSL Numeric class (for Vec.make_type and `.to(...)`). + # aiter's `dtype_to_elem_type` returns a raw MLIR `ir.Type`; the FlyDSL Vector + # API requires a Numeric subclass instead. Both forms are kept available. + _NUMERIC_MAP = { + "f32": fx.Float32, + "f16": fx.Float16, + "bf16": fx.BFloat16, + } + elem_numeric_cls = _NUMERIC_MAP[dtype_str] + + @flyc.kernel(known_block_size=[BLOCK_SIZE, 1, 1]) + def flash_attn_func_kernel( + Q: fx.Tensor, + K: fx.Tensor, + V: fx.Tensor, + O: fx.Tensor, # noqa: E741 + seq_len: fx.Int32, + ): + elem_type = dtype_to_elem_type(dtype_str) + elem_dtype = elem_numeric_cls + q_ptr = _extract_aligned_pointer(Q) + k_ptr = _extract_aligned_pointer(K) + v_ptr = _extract_aligned_pointer(V) + o_ptr = _extract_aligned_pointer(O) + fm_fast = arith.FastMathFlags.fast + + # Local fast-math arithmetic helpers — preserve fastmath flag while using + # the lowercase op names that accept _raw() unwrapping (PR #462 pattern). + def _fadd(a, b): + return arith.addf(_raw(a), _raw(b), fastmath=fm_fast) + + def _fsub(a, b): + return arith.subf(_raw(a), _raw(b), fastmath=fm_fast) + + def _fmul(a, b): + return arith.mulf(_raw(a), _raw(b), fastmath=fm_fast) + + def _fmax(a, b): + return arith.MaxNumFOp(_raw(a), _raw(b), fastmath=fm_fast).result + + v8f32_type = Vec.make_type(8, fx.Float32) + v8f16_type = Vec.make_type(8, elem_dtype) + vxf16_type = Vec.make_type(VEC_WIDTH, elem_dtype) + + def wmma_acc(a_v8, b_v8, c_v8): + if const_expr(dtype_str == "bf16"): + a_i16 = Vec(a_v8).bitcast(fx.Int16) + b_i16 = Vec(b_v8).bitcast(fx.Int16) + return rocdl.wmma_f32_16x16x16_bf16( + v8f32_type, _raw(a_i16), _raw(b_i16), c_v8 + ).result + return rocdl.wmma_f32_16x16x16_f16(v8f32_type, a_v8, b_v8, c_v8).result + + seq_len_v = fx.Index(seq_len) + + base_ptr = allocator.get_base() + lds_kv = SmemPtr( + base_ptr, + lds_kv_offset, + elem_type, + shape=(LDS_KV_TOTAL_SIZE,), + ).get() + + block_id = fx.Index(gpu.block_idx.x) + tid = fx.Index(gpu.thread_idx.x) + + wave_id = tid // WARP_SIZE + lane = tid % WARP_SIZE + lane16 = lane % 16 + klane = lane // 16 + + wave_q_offset = wave_id * ROWS_PER_WAVE + + head_idx = block_id % NUM_HEADS + batch_q_tile_id = block_id // NUM_HEADS + num_q_tiles = (seq_len_v + BLOCK_M - 1) // BLOCK_M + q_tile_idx = batch_q_tile_id % num_q_tiles + batch_idx = batch_q_tile_id // num_q_tiles + q_start = q_tile_idx * BLOCK_M + + load_row_in_batch = tid // THREADS_PER_ROW_LOAD + load_lane_in_row = tid % THREADS_PER_ROW_LOAD + load_col_base = load_lane_in_row * VEC_WIDTH + + def global_idx(token_idx, col): + token = batch_idx * seq_len_v + token_idx + return token * STRIDE_TOKEN + head_idx * HEAD_DIM + col + + def _load_global_half_vec(ptr, base_idx, vec_type): + gep = buffer_ops.get_element_ptr( + ptr, fx.Int64(base_idx), elem_type=elem_type + ) + return _pointer_load(vec_type, gep) + + def _store_global_half(ptr, base_idx, val): + gep = buffer_ops.get_element_ptr( + ptr, fx.Int64(base_idx), elem_type=elem_type + ) + _pointer_store(val, gep) + + def load_global_f16xN(base_ptr, base_idx): + return _load_global_half_vec(base_ptr, base_idx, vxf16_type) + + def load_global_v8f16(base_ptr, base_idx): + return _load_global_half_vec(base_ptr, base_idx, v8f16_type) + + def _bitcast_i32(value): + return fx.Int32(ArithValue(value).bitcast(fx.Int32.ir_type)) + + def _pack_bf16_pair(lo, hi, shift, mask): + lo_i32 = _bitcast_i32(lo) + hi_i32 = _bitcast_i32(hi) + return (hi_i32 & mask) | lo_i32.shrui(shift) + + def bf16_trunc_pack_v8(f32_vals): + """Pack 8 f32 values into v8bf16 via bitwise truncation (upper 16 bits).""" + _c16 = fx.Int32(16) + _cmask = fx.Int32(0xFFFF0000) + pairs = [] + for j in range_constexpr(4): + pairs.append( + _pack_bf16_pair(f32_vals[j * 2], f32_vals[j * 2 + 1], _c16, _cmask) + ) + return Vec.from_elements(pairs, fx.Int32).bitcast(elem_dtype).ir_value() + + def k_buf_base(buf_id): + if const_expr(isinstance(buf_id, int)): + return fx.Index(buf_id * LDS_K_TILE_SIZE) + return buf_id * fx.Index(LDS_K_TILE_SIZE) + + def v_buf_base(buf_id): + return fx.Index(LDS_V_BASE + buf_id * LDS_V_TILE_SIZE) + + def coop_load_k(tile_start, buf_id=0): + k_base = k_buf_base(buf_id) + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = tile_start + load_row_in_batch + row_offset + if const_expr(KV_NEEDS_GUARD): + row_valid = load_row_in_batch < fx.Index(BLOCK_N) + if row_valid: + g_idx = global_idx(row_idx, load_col_base) + lds_row = load_row_in_batch + row_offset + lds_idx = k_base + lds_row * K_STRIDE + load_col_base + vec = load_global_f16xN(k_ptr, g_idx) + Vec(vec).store(lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + lds_row = load_row_in_batch + row_offset + lds_idx = k_base + lds_row * K_STRIDE + load_col_base + vec = load_global_f16xN(k_ptr, g_idx) + Vec(vec).store(lds_kv, [lds_idx]) + + def _v_store_row_major(v_base, lds_row, vec): + lds_idx = v_base + lds_row * V_STRIDE + load_col_base + Vec(vec).store(lds_kv, [lds_idx]) + + def coop_load_v_global(tile_start): + vecs = [] + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = tile_start + load_row_in_batch + row_offset + g_idx = global_idx(row_idx, load_col_base) + vecs.append(load_global_f16xN(v_ptr, g_idx)) + return vecs + + def coop_store_v_lds(vecs, buf_id=0): + v_base = v_buf_base(buf_id) + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + if const_expr(KV_NEEDS_GUARD): + row_valid = load_row_in_batch < fx.Index(BLOCK_N) + if row_valid: + lds_row = load_row_in_batch + row_offset + _v_store_row_major(v_base, lds_row, vecs[batch]) + else: + lds_row = load_row_in_batch + row_offset + _v_store_row_major(v_base, lds_row, vecs[batch]) + + # ---- Q preload ---- + q_row = q_start + wave_q_offset + lane16 + q_row_i32 = fx.Int32(q_row) + # Use explicit signed-less-than predicate to match baseline ISA + # (`v_cmp_gt_i64_e64`). fx.Index defaults to unsigned which would lower + # to `v_cmp_gt_u64_e64` and cause an ISA hash drift even though both + # variants are semantically equivalent for non-negative offsets. + q_in_bounds = arith.cmpi(arith.CmpIPredicate.slt, _raw(q_row), _raw(seq_len_v)) + q_row_safe = fx.Index(ArithValue(q_in_bounds).select(q_row, fx.Index(0))) + c_zero_v8f16 = Vec.filled(8, 0.0, elem_dtype).ir_value() + q_b_packs = [] + for ks in range_constexpr(K_STEPS_QK): + q_col = fx.Index(ks * K_STEP_QK) + klane * WMMA_LANE_K + g_idx = global_idx(q_row_safe, q_col) + raw = load_global_v8f16(q_ptr, g_idx) + q_b_packs.append(ArithValue(q_in_bounds).select(raw, c_zero_v8f16)) + + # ---- Constants ---- + c_neg_inf = fx.Float32(float("-inf")) + c_zero_f = fx.Float32(0.0) + c_one_f = fx.Float32(1.0) + c_sm_scale_log2e = fx.Float32(sm_scale * _LOG2E) + c_zero_v8f32 = Vec.filled(8, 0.0, fx.Float32) + width_i32 = fx.Int32(WARP_SIZE) + shuf_16_i32 = fx.Int32(16) + + def reduction_peer(v_f32): + return fx.Float32(v_f32).shuffle_xor(shuf_16_i32, width_i32) + + _q_end = q_start + BLOCK_M + if const_expr(CAUSAL): + kv_upper = fx.Index( + ArithValue(_q_end < seq_len_v).select(_q_end, seq_len_v) + ) + else: + kv_upper = seq_len_v + + # ---- Opt4: Pre-issue first V global load before loop ---- + _v_vecs_init = coop_load_v_global(fx.Index(0)) + + init_args = [_raw(c_neg_inf), _raw(c_zero_f)] + for _ in range_constexpr(D_CHUNKS): + init_args.append(_raw(c_zero_v8f32)) + # Carry V prefetch vecs as loop-carried values + for batch in range_constexpr(NUM_BATCHES_KV): + init_args.append(_v_vecs_init[batch]) + + loop_results = init_args + for kv_block_start, inner_iter_args in range( + 0, kv_upper, BLOCK_N_OUT, init=init_args + ): + m_running = inner_iter_args[0] + l_running = inner_iter_args[1] + o_accs = [inner_iter_args[2 + i] for i in range_constexpr(D_CHUNKS)] + _v_vecs_prefetch = [ + inner_iter_args[2 + D_CHUNKS + b] + for b in range_constexpr(NUM_BATCHES_KV) + ] + + coop_load_k(kv_block_start, 0) + gpu.barrier() + k_base = k_buf_base(0) + + # ==== GEMM1: S = K @ Q^T (no swizzle, padding-based) ==== + s_accs = [_raw(c_zero_v8f32) for _ in range(NUM_S_ACCS)] + + for ks in range_constexpr(K_STEPS_QK): + k_col = fx.Index(ks * K_STEP_QK) + klane * WMMA_LANE_K + + for st_idx in range_constexpr(N_SUB_TILES): + st_base_row = st_idx * K_SUB_N + + k_row_a = lane16 + fx.Index(st_base_row) + k_lds_a = k_base + k_row_a * K_STRIDE + k_col + k_pack_a = Vec.load(v8f16_type, lds_kv, [k_lds_a]) + + k_row_b = lane16 + fx.Index(st_base_row + 16) + k_lds_b = k_base + k_row_b * K_STRIDE + k_col + k_pack_b = Vec.load(v8f16_type, lds_kv, [k_lds_b]) + + acc_idx_a = st_idx * 2 + acc_idx_b = st_idx * 2 + 1 + s_accs[acc_idx_a] = wmma_acc( + k_pack_a, q_b_packs[ks], s_accs[acc_idx_a] + ) + s_accs[acc_idx_b] = wmma_acc( + k_pack_b, q_b_packs[ks], s_accs[acc_idx_b] + ) + + # ==== Online softmax ==== + s_raw = [] + for st in range_constexpr(NUM_S_ACCS): + for r in range_constexpr(8): + s_raw.append(Vec(s_accs[st])[r]) + + if const_expr(CAUSAL): + kv_start_i32 = fx.Int32(kv_block_start) + klane_i32 = fx.Int32(klane) + q_start_i32 = fx.Int32(q_start) + max_kv_col_i32 = kv_start_i32 + fx.Int32(BLOCK_N - 1) + tile_needs_mask = max_kv_col_i32 > q_start_i32 + + # SSA-style restructure (PR #462 pattern, lines 700-870): + # FlyDSL's `if` rewriter requires each loop-carried/conditional + # state variable to be a single MLIR Value, not a list. Unfold + # `s_raw[0..NUM_S_VALS-1]` into NUM_S_VALS named scalars, then + # reassign each one inside the `if tile_needs_mask:` branch. + # NUM_S_VALS == NUM_S_ACCS * 8 == 16 for BLOCK_N=32. + s_v0 = s_raw[0] + s_v1 = s_raw[1] + s_v2 = s_raw[2] + s_v3 = s_raw[3] + s_v4 = s_raw[4] + s_v5 = s_raw[5] + s_v6 = s_raw[6] + s_v7 = s_raw[7] + s_v8 = s_raw[8] + s_v9 = s_raw[9] + s_v10 = s_raw[10] + s_v11 = s_raw[11] + s_v12 = s_raw[12] + s_v13 = s_raw[13] + s_v14 = s_raw[14] + s_v15 = s_raw[15] + if tile_needs_mask: + klane_off_i32 = klane_i32 * fx.Int32(8) + # st=0 + _b0 = kv_start_i32 + fx.Int32(0) + klane_off_i32 + s_v0 = ArithValue(_b0 > q_row_i32).select(c_neg_inf, s_v0) + _b1 = kv_start_i32 + fx.Int32(1) + klane_off_i32 + s_v1 = ArithValue(_b1 > q_row_i32).select(c_neg_inf, s_v1) + _b2 = kv_start_i32 + fx.Int32(2) + klane_off_i32 + s_v2 = ArithValue(_b2 > q_row_i32).select(c_neg_inf, s_v2) + _b3 = kv_start_i32 + fx.Int32(3) + klane_off_i32 + s_v3 = ArithValue(_b3 > q_row_i32).select(c_neg_inf, s_v3) + _b4 = kv_start_i32 + fx.Int32(4) + klane_off_i32 + s_v4 = ArithValue(_b4 > q_row_i32).select(c_neg_inf, s_v4) + _b5 = kv_start_i32 + fx.Int32(5) + klane_off_i32 + s_v5 = ArithValue(_b5 > q_row_i32).select(c_neg_inf, s_v5) + _b6 = kv_start_i32 + fx.Int32(6) + klane_off_i32 + s_v6 = ArithValue(_b6 > q_row_i32).select(c_neg_inf, s_v6) + _b7 = kv_start_i32 + fx.Int32(7) + klane_off_i32 + s_v7 = ArithValue(_b7 > q_row_i32).select(c_neg_inf, s_v7) + # st=1 (st_base=16) + _b8 = kv_start_i32 + fx.Int32(16) + klane_off_i32 + s_v8 = ArithValue(_b8 > q_row_i32).select(c_neg_inf, s_v8) + _b9 = kv_start_i32 + fx.Int32(17) + klane_off_i32 + s_v9 = ArithValue(_b9 > q_row_i32).select(c_neg_inf, s_v9) + _b10 = kv_start_i32 + fx.Int32(18) + klane_off_i32 + s_v10 = ArithValue(_b10 > q_row_i32).select(c_neg_inf, s_v10) + _b11 = kv_start_i32 + fx.Int32(19) + klane_off_i32 + s_v11 = ArithValue(_b11 > q_row_i32).select(c_neg_inf, s_v11) + _b12 = kv_start_i32 + fx.Int32(20) + klane_off_i32 + s_v12 = ArithValue(_b12 > q_row_i32).select(c_neg_inf, s_v12) + _b13 = kv_start_i32 + fx.Int32(21) + klane_off_i32 + s_v13 = ArithValue(_b13 > q_row_i32).select(c_neg_inf, s_v13) + _b14 = kv_start_i32 + fx.Int32(22) + klane_off_i32 + s_v14 = ArithValue(_b14 > q_row_i32).select(c_neg_inf, s_v14) + _b15 = kv_start_i32 + fx.Int32(23) + klane_off_i32 + s_v15 = ArithValue(_b15 > q_row_i32).select(c_neg_inf, s_v15) + s_raw = [ + s_v0, + s_v1, + s_v2, + s_v3, + s_v4, + s_v5, + s_v6, + s_v7, + s_v8, + s_v9, + s_v10, + s_v11, + s_v12, + s_v13, + s_v14, + s_v15, + ] + + local_max = s_raw[0] + for r in range_constexpr(NUM_S_VALS - 1): + local_max = _fmax(local_max, s_raw[r + 1]) + peer_max = reduction_peer(local_max) + row_max = _fmax(local_max, peer_max) + m_new_raw = _fmax(m_running, row_max) + + # ---- Opt2: rocdl.exp2 ---- + diff_m_raw = _fsub(m_running, m_new_raw) + diff_m_scaled = _fmul(diff_m_raw, c_sm_scale_log2e) + corr = rocdl.exp2(ir.F32Type.get(), _raw(diff_m_scaled)) + + scaled_max = _fmul(c_sm_scale_log2e, m_new_raw) + neg_scaled_max = _fsub(c_zero_f, scaled_max) + + p_vals = [] + local_sum = _raw(c_zero_f) + for r in range_constexpr(NUM_S_VALS): + diff = fmath.fma(s_raw[r], _raw(c_sm_scale_log2e), neg_scaled_max) + p = rocdl.exp2(ir.F32Type.get(), _raw(diff)) + p_vals.append(p) + local_sum = _fadd(local_sum, p) + + peer_sum = reduction_peer(local_sum) + tile_sum = _fadd(local_sum, peer_sum) + l_corr = _fmul(corr, l_running) + l_new = _fadd(l_corr, tile_sum) + + corr_vec = Vec.from_elements([corr], fx.Float32).broadcast_to(8).ir_value() + for dc in range_constexpr(D_CHUNKS): + o_accs[dc] = _fmul(o_accs[dc], corr_vec) + + # Store V to LDS (row-major, fast vector store) + coop_store_v_lds(_v_vecs_prefetch, 0) + gpu.barrier() + + # ==== Build P packs ==== + p_packs_all = [] + for st_idx in range_constexpr(N_SUB_TILES): + p_packs_st = [] + for pks in range_constexpr(PV_K_STEPS): + acc_idx = st_idx * 2 + pks + p_base = acc_idx * 8 + p_slice = [p_vals[p_base + j] for j in range(8)] + + if const_expr(dtype_str == "bf16"): + p_packs_st.append(bf16_trunc_pack_v8(p_slice)) + else: + elem_list = [] + for j in range_constexpr(8): + elem_list.append(fx.Float32(p_slice[j]).to(elem_dtype)) + p_packs_st.append( + Vec.from_elements(elem_list, elem_dtype).ir_value() + ) + p_packs_all.append(p_packs_st) + + # ==== GEMM2: O += V^T @ P (software pipelined, row-major V) ==== + # Opt3: Prefetch next V pack while current WMMA executes + v_base = v_buf_base(0) + + def _load_v_rowmajor(st_kv_base_val, pks_val, dc_val): + d_pos = fx.Index(dc_val * D_CHUNK) + lane16 + v_elems = [] + for k_sub in range_constexpr(8): + kv_row = ( + fx.Index(st_kv_base_val + pks_val * PV_K_STEP) + + klane * WMMA_LANE_K + + fx.Index(k_sub) + ) + v_lds_idx = v_base + kv_row * V_STRIDE + d_pos + # Kept as raw memref.load: scalar element load with no + # direct Vec equivalent — Vec is for SIMD vectors. + v_elems.append(_memref.load(lds_kv, [_raw(v_lds_idx)])) + return Vec.from_elements(v_elems, elem_dtype).ir_value() + + # Software pipeline: preload first V pack + cur_v_packs = [] + for st_idx in range_constexpr(N_SUB_TILES): + cur_v_packs.append(_load_v_rowmajor(st_idx * K_SUB_N, 0, 0)) + + for pks in range_constexpr(PV_K_STEPS): + for dc in range_constexpr(D_CHUNKS): + next_dc = dc + 1 + next_pks = pks + if const_expr(next_dc >= D_CHUNKS): + next_dc = 0 + next_pks = pks + 1 + has_next = const_expr(next_pks < PV_K_STEPS) + + # Prefetch next V while current WMMA runs + next_v_packs = [] + if const_expr(has_next): + for st_idx in range_constexpr(N_SUB_TILES): + next_v_packs.append( + _load_v_rowmajor(st_idx * K_SUB_N, next_pks, next_dc) + ) + + for st_idx in range_constexpr(N_SUB_TILES): + o_accs[dc] = wmma_acc( + cur_v_packs[st_idx], p_packs_all[st_idx][pks], o_accs[dc] + ) + + if const_expr(has_next): + cur_v_packs = next_v_packs + + m_running = m_new_raw + l_running = l_new + + # ---- Opt4: Issue NEXT iteration's V global load ---- + next_kv_start = kv_block_start + fx.Index(BLOCK_N_OUT) + _v_vecs_next = coop_load_v_global(next_kv_start) + + _yield_args = [m_running, l_running] + o_accs + for batch in range_constexpr(NUM_BATCHES_KV): + _yield_args.append(_v_vecs_next[batch]) + loop_results = yield _yield_args + + # ---- Normalize and store O ---- + l_final = loop_results[1] + o_finals = [loop_results[2 + dc] for dc in range_constexpr(D_CHUNKS)] + + inv_l = arith.divf(_raw(c_one_f), _raw(l_final), fastmath=fm_fast) + inv_l_vec = Vec.from_elements([inv_l], fx.Float32).broadcast_to(8).ir_value() + + if q_in_bounds: + for dc in range_constexpr(D_CHUNKS): + o_norm_vec = _fmul(o_finals[dc], inv_l_vec) + o_trunc = Vec(o_norm_vec).to(elem_dtype).ir_value() + d_col = fx.Index(dc * D_CHUNK) + klane * 8 + o_global = global_idx(q_row, d_col) + _store_global_half(o_ptr, o_global, o_trunc) + + @flyc.jit + def launch_flash_attn_func( + Q: fx.Tensor, + K: fx.Tensor, + V: fx.Tensor, + O: fx.Tensor, # noqa: E741 + batch_size: fx.Int32, + seq_len: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + bs_idx = fx.Index(batch_size) + sl_idx = fx.Index(seq_len) + num_q_tiles = (sl_idx + BLOCK_M - 1) // BLOCK_M + grid_x = bs_idx * num_q_tiles * NUM_HEADS + + launcher = flash_attn_func_kernel(Q, K, V, O, seq_len) + + if const_expr(waves_per_eu is not None): + _wpe = int(waves_per_eu) + if const_expr(_wpe >= 1): + for op in ctx.gpu_module_body.operations: + if const_expr(getattr(op, "OPERATION_NAME", None) == "gpu.func"): + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( + T.i32, _wpe + ) + if const_expr(flat_work_group_size is not None): + _fwgs = int(flat_work_group_size) + if const_expr(_fwgs >= 1): + flat_wg_attr = ir.StringAttr.get(f"{_fwgs},{_fwgs}") + for op in ctx.gpu_module_body.operations: + if const_expr(getattr(op, "OPERATION_NAME", None) == "gpu.func"): + op.attributes["rocdl.flat_work_group_size"] = flat_wg_attr + + passthrough_entries = [] + if const_expr(daz): + passthrough_entries.append( + ir.ArrayAttr.get( + [ + ir.StringAttr.get("denormal-fp-math-f32"), + ir.StringAttr.get("preserve-sign,preserve-sign"), + ] + ) + ) + passthrough_entries.append( + ir.ArrayAttr.get( + [ + ir.StringAttr.get("no-nans-fp-math"), + ir.StringAttr.get("true"), + ] + ) + ) + passthrough_entries.append( + ir.ArrayAttr.get( + [ + ir.StringAttr.get("unsafe-fp-math"), + ir.StringAttr.get("true"), + ] + ) + ) + for op in ctx.gpu_module_body.operations: + if const_expr(getattr(op, "OPERATION_NAME", None) == "gpu.func"): + op.attributes["passthrough"] = ir.ArrayAttr.get(passthrough_entries) + + launcher.launch(grid=(grid_x, 1, 1), block=(BLOCK_SIZE, 1, 1), stream=stream) + + _fmha_compile_hints = { + "fast_fp_math": fast_fp_math, + "unsafe_fp_math": unsafe_fp_math, + "llvm_options": {"enable-post-misched": False, "lsr-drop-solution": True}, + } + + def _launch(*args, **kwargs): + with CompilationContext.compile_hints(_fmha_compile_hints): + return launch_flash_attn_func(*args, **kwargs) + + def _compile(Q, K, V, O, batch_size, seq_len, stream=None): # noqa: E741 + with CompilationContext.compile_hints(_fmha_compile_hints): + return flyc.compile( + launch_flash_attn_func, + Q, + K, + V, + O, + batch_size, + seq_len, + fx.Stream(stream), + ) + + _launch.compile = _compile + return _launch + + +build_flash_attn_func_module = build_flash_attn_func_module_primary diff --git a/aiter/ops/flydsl/kernels/kernels_common.py b/aiter/ops/flydsl/kernels/kernels_common.py index 993d9f39cd..eacec80141 100644 --- a/aiter/ops/flydsl/kernels/kernels_common.py +++ b/aiter/ops/flydsl/kernels/kernels_common.py @@ -26,6 +26,22 @@ def get_warp_size(arch=None): return 32 if is_rdna_arch(arch) else 64 +def dtype_to_elem_type(dtype_str: str): + """Map a dtype string to its MLIR scalar type. + + Supported: ``'f32'``, ``'f16'``, ``'bf16'``. + """ + if dtype_str == "f32": + return T.f32 + if dtype_str == "f16": + return T.f16 + if dtype_str == "bf16": + return T.bf16 + raise ValueError( + f"unsupported dtype: {dtype_str!r} (expected 'f32', 'f16', or 'bf16')" + ) + + def _create_llvm_ptr(value, address_space: int = 1): value = buffer_ops._unwrap_value(value) if isinstance(value.type, ir.IndexType): diff --git a/op_tests/flydsl_tests/test_flydsl_fmha.py b/op_tests/flydsl_tests/test_flydsl_fmha.py new file mode 100644 index 0000000000..cce2c24ff1 --- /dev/null +++ b/op_tests/flydsl_tests/test_flydsl_fmha.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""Tests for ``flydsl_flash_attn_func`` (gfx1201 / RDNA4).""" + +from __future__ import annotations + +from typing import Tuple + +import pytest +import torch +import torch.nn.functional as F + +pytest.importorskip("flydsl") +from aiter.ops.flydsl import is_flydsl_available, flydsl_flash_attn_func # noqa: E402 + +if not is_flydsl_available(): + pytest.skip("flydsl is not available", allow_module_level=True) + + +def _is_gfx1201() -> bool: + if not torch.cuda.is_available(): + return False + try: + arch = torch.cuda.get_device_properties(0).gcnArchName + except Exception: + return False + return arch.lower().split(":")[0].startswith("gfx1201") + + +pytestmark = pytest.mark.skipif( + not _is_gfx1201(), + reason="flydsl_flash_attn_func is gfx1201/RDNA4 only", +) + + +def _ref_sdpa_bshd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + causal: bool = False, +) -> torch.Tensor: + """SDPA reference with BSHD inputs/outputs.""" + out_bhsd = F.scaled_dot_product_attention( + q.transpose(1, 2).contiguous(), + k.transpose(1, 2).contiguous(), + v.transpose(1, 2).contiguous(), + is_causal=causal, + ) + return out_bhsd.transpose(1, 2).contiguous() + + +def _make_qkv( + batch: int, + seq_len: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + seed: int = 0, + device: str = "cuda", +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + g = torch.Generator(device=device).manual_seed(seed) + shape = (batch, seq_len, num_heads, head_dim) + q = torch.randn(shape, generator=g, dtype=dtype, device=device) + k = torch.randn(shape, generator=g, dtype=dtype, device=device) + v = torch.randn(shape, generator=g, dtype=dtype, device=device) + return q, k, v + + +@pytest.mark.parametrize( + "batch,seq_len,num_heads,head_dim", + [ + # Aligned production-like Wan2.1 1.3B shape, padded to multiple of 128. + (1, 32768, 12, 128), + # Smaller aligned shape (sanity). + (2, 1024, 8, 128), + # Unaligned shape — exercises the auto-padding path. 32760 → 32768. + (1, 32760, 12, 128), + ], +) +def test_flydsl_fmha_correctness_bf16(batch, seq_len, num_heads, head_dim): + q, k, v = _make_qkv(batch, seq_len, num_heads, head_dim, torch.bfloat16) + out = flydsl_flash_attn_func(q, k, v, causal=False) + ref = _ref_sdpa_bshd(q, k, v) + + assert out.shape == ref.shape == (batch, seq_len, num_heads, head_dim) + assert out.dtype == ref.dtype == torch.bfloat16 + + cos = F.cosine_similarity( + out.float().reshape(-1, head_dim), + ref.float().reshape(-1, head_dim), + dim=1, + ) + # bf16 attention is noisy; cosine is the right correctness signal. + assert cos.min().item() > 0.99, f"min_cos={cos.min().item():.6f}" + assert cos.mean().item() > 0.999, f"mean_cos={cos.mean().item():.6f}" + + +def test_flydsl_fmha_rejects_cross_attention(): + q = torch.randn(1, 1024, 12, 128, dtype=torch.bfloat16, device="cuda") + k = torch.randn(1, 512, 12, 128, dtype=torch.bfloat16, device="cuda") + v = torch.randn(1, 512, 12, 128, dtype=torch.bfloat16, device="cuda") + with pytest.raises(ValueError, match="self-attention"): + flydsl_flash_attn_func(q, k, v) + + +def test_flydsl_fmha_rejects_unsupported_head_dim(): + q = torch.randn(1, 256, 8, 48, dtype=torch.bfloat16, device="cuda") + with pytest.raises(ValueError, match="head_dim"): + flydsl_flash_attn_func(q, q.clone(), q.clone()) + + +def test_flydsl_fmha_rejects_dtype_mismatch(): + q = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda") + k = torch.randn(1, 1024, 8, 128, dtype=torch.float16, device="cuda") + v = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda") + with pytest.raises(ValueError, match="dtype"): + flydsl_flash_attn_func(q, k, v) + + +def test_flydsl_fmha_correctness_f16(): + """f16 dtype coverage — Wan2.1 1.3B-style shape, non-causal.""" + batch, seq_len, num_heads, head_dim = 1, 32768, 12, 128 + q, k, v = _make_qkv(batch, seq_len, num_heads, head_dim, torch.float16) + out = flydsl_flash_attn_func(q, k, v, causal=False) + ref = _ref_sdpa_bshd(q, k, v, causal=False) + + assert out.shape == ref.shape == (batch, seq_len, num_heads, head_dim) + assert out.dtype == ref.dtype == torch.float16 + + cos = F.cosine_similarity( + out.float().reshape(-1, head_dim), + ref.float().reshape(-1, head_dim), + dim=1, + ) + assert cos.min().item() > 0.99, f"min_cos={cos.min().item():.6f}" + assert cos.mean().item() > 0.999, f"mean_cos={cos.mean().item():.6f}" + + +def test_flydsl_fmha_correctness_causal_small(): + """Causal masking coverage — small bf16 shape.""" + batch, seq_len, num_heads, head_dim = 2, 4096, 8, 128 + q, k, v = _make_qkv(batch, seq_len, num_heads, head_dim, torch.bfloat16) + out = flydsl_flash_attn_func(q, k, v, causal=True) + ref = _ref_sdpa_bshd(q, k, v, causal=True) + + assert out.shape == ref.shape == (batch, seq_len, num_heads, head_dim) + assert out.dtype == ref.dtype == torch.bfloat16 + + cos = F.cosine_similarity( + out.float().reshape(-1, head_dim), + ref.float().reshape(-1, head_dim), + dim=1, + ) + assert cos.min().item() > 0.99, f"min_cos={cos.min().item():.6f}" + assert cos.mean().item() > 0.999, f"mean_cos={cos.mean().item():.6f}" + + +def test_flydsl_fmha_correctness_multi_device(): + """Multi-GPU device-context wrapping (#1) and same-device check (#6). + + Runs the kernel on device 1 while the default current device is 0 in a + subprocess (so a HIP context-pollution failure cannot leak into the rest + of the test session). Validates the ``with torch.cuda.device(...)`` wrap + in ``flydsl_flash_attn_func`` when q.device != current device. + + If the underlying FlyDSL runtime pins to device 0 internally (a runtime + limitation, not a wrapper bug), the subprocess will raise + ``hipErrorInvalidDevice`` and the test is marked xfail — the wrapper code + path is still correct and the same-device guard test below still + validates Copilot #6 directly. + """ + if torch.cuda.device_count() < 2: + pytest.skip("requires >=2 visible GPUs") + + import subprocess + import textwrap + + script = textwrap.dedent(""" + import sys + sys.path.insert(0, "/workspace/FlyDSL/python") + import flydsl + flydsl.__version__ = "0.1.5.dev999" + + import torch + import torch.nn.functional as F + from aiter.ops.flydsl import flydsl_flash_attn_func + + torch.cuda.set_device(0) + dev1 = torch.device("cuda", 1) + B, S, H, D = 1, 1024, 8, 128 + g = torch.Generator(device=dev1).manual_seed(0) + shape = (B, S, H, D) + q = torch.randn(shape, generator=g, dtype=torch.bfloat16, device=dev1) + k = torch.randn(shape, generator=g, dtype=torch.bfloat16, device=dev1) + v = torch.randn(shape, generator=g, dtype=torch.bfloat16, device=dev1) + + out = flydsl_flash_attn_func(q, k, v, causal=False) + torch.cuda.synchronize(dev1) + assert out.device == dev1, f"expected cuda:1 got {out.device}" + + with torch.cuda.device(dev1): + ref_bhsd = F.scaled_dot_product_attention( + q.transpose(1, 2).contiguous(), + k.transpose(1, 2).contiguous(), + v.transpose(1, 2).contiguous(), + is_causal=False, + ) + ref = ref_bhsd.transpose(1, 2).contiguous() + cos = F.cosine_similarity( + out.float().reshape(-1, D), + ref.float().reshape(-1, D), + dim=1, + ) + cm = cos.min().item() + assert cm > 0.99, f"min_cos={cm:.6f}" + print("MULTI_DEVICE_OK", flush=True) + """) + + proc = subprocess.run( + ["python", "-c", script], + capture_output=True, + text=True, + timeout=120, + ) + combined = (proc.stdout or "") + "\n" + (proc.stderr or "") + if "MULTI_DEVICE_OK" in proc.stdout: + return + if "hipErrorInvalidDevice" in combined or "invalid device ordinal" in combined: + pytest.xfail( + "FlyDSL runtime pins to device 0; wrapper-level device-context " + "switch is in place but underlying runtime does not honor it" + ) + raise AssertionError( + f"multi-device subprocess failed unexpectedly:\n" + f"stdout:\n{proc.stdout}\nstderr:\n{proc.stderr}" + ) + + +def test_flydsl_fmha_rejects_excessive_padding(): + """Non-causal path must reject padding ratio > 0.5% (option (d) guard). + + S_real=129 -> S_pad=256, pad ratio 127/256 = 49.6%. Padded K/V keys + would contribute to the softmax denominator and silently scale outputs + (rel_err ~37% per RCA in 2969_padded_softmax_rca.md). Wrapper must + raise before launching the kernel. + """ + batch, seq_len, num_heads, head_dim = 1, 129, 8, 128 + q, k, v = _make_qkv(batch, seq_len, num_heads, head_dim, torch.bfloat16) + with pytest.raises(ValueError, match="0.5% safety threshold"): + flydsl_flash_attn_func(q, k, v, causal=False) + + +def test_flydsl_fmha_allows_tight_padding(): + """Wan2.1 production case (S_real=32760 -> S_pad=32768, ratio 0.024%) + must pass the 0.5% threshold and produce SDPA-equivalent output. + + Regression guard for option (d) — protects the production hot path + from a future, stricter threshold accidentally rejecting it. + """ + batch, seq_len, num_heads, head_dim = 1, 32760, 12, 128 + q, k, v = _make_qkv(batch, seq_len, num_heads, head_dim, torch.bfloat16) + out = flydsl_flash_attn_func(q, k, v, causal=False) + ref = _ref_sdpa_bshd(q, k, v, causal=False) + + assert out.shape == ref.shape == (batch, seq_len, num_heads, head_dim) + cos = F.cosine_similarity( + out.float().reshape(-1, head_dim), + ref.float().reshape(-1, head_dim), + dim=1, + ) + # Wan2.1 production cos_min was empirically 0.999992 in the RCA; + # 0.9999 is the conservative regression bound (bf16 noise floor). + assert cos.min().item() > 0.9999, f"min_cos={cos.min().item():.6f}" + + +def test_flydsl_fmha_rejects_device_mismatch(): + """Same-device check (#6) — q on device 0, k/v on device 1 must raise.""" + if torch.cuda.device_count() < 2: + pytest.skip("requires >=2 visible GPUs") + + q = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda:0") + k = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda:1") + v = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda:1") + with pytest.raises(ValueError, match="same device"): + flydsl_flash_attn_func(q, k, v)