diff --git a/.gitignore b/.gitignore index 060470d3c6f..dc508654045 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,10 @@ var/ # IDE-related .idea/ +.vscode/ # Dev venv + +# compile-time generated file +flash_attn_config.py \ No newline at end of file diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 0233da799f2..7ab4352984e 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -1264,7 +1264,7 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // h: num_heads // h_k: num_heads_k // d: head_size -std::tuple mha_bwd( +std::tuple mha_bwd( at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k @@ -1563,7 +1563,7 @@ std::tuple @@ -1727,7 +1727,7 @@ TORCH_LIBRARY(flash_attn_3, m) { "int window_size_right = -1," "float softcap = 0.0," "bool deterministic = False," - "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"); + "int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); m.def("fwd_combine(" "Tensor out_partial," "Tensor lse_partial," diff --git a/hopper/flash_api_stable.cpp b/hopper/flash_api_stable.cpp index 6de5c5ac380..5ae58bdd129 100644 --- a/hopper/flash_api_stable.cpp +++ b/hopper/flash_api_stable.cpp @@ -1335,7 +1335,7 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // h: num_heads // h_k: num_heads_k // d: head_size -std::tuple mha_bwd( +std::tuple mha_bwd( Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k @@ -1641,7 +1641,7 @@ std::tuple mha_b torch::stable::zero_(softmax_d); } - return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; + return { softmax_d, 
softmax_lse_log2, dq_accum, dk_accum, dv_accum }; } std::tuple @@ -1828,16 +1828,13 @@ void boxed_mha_bwd( auto deterministic = to(stack[20]); auto sm_margin = to(stack[21]); - auto [dq_, dk_, dv_, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum] = mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin); + auto [softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum] = mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin); - stack[0] = from(dq_); - stack[1] = from(dk_); - stack[2] = from(dv_); - stack[3] = from(softmax_d); - stack[4] = from(softmax_lse_log2); - stack[5] = from(dq_accum); - stack[6] = from(dk_accum); - stack[7] = from(dv_accum); + stack[0] = from(softmax_d); + stack[1] = from(softmax_lse_log2); + stack[2] = from(dq_accum); + stack[3] = from(dk_accum); + stack[4] = from(dv_accum); } void boxed_mha_combine( @@ -1949,7 +1946,7 @@ STABLE_TORCH_LIBRARY(flash_attn_3, m) { "int window_size_right = -1," "float softcap = 0.0," "bool deterministic = False," - "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"); + "int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); m.def("fwd_combine(" "Tensor out_partial," "Tensor lse_partial," diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 1158ee02ad2..44d1f027cb0 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -1,6 +1,6 @@ # Copyright (c) 2023, Tri Dao. 
-from typing import Optional, Union +from typing import Optional, Union, List, Tuple import torch import torch.nn as nn @@ -17,41 +17,68 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x +def round_multiple(x, m): + return (x + m - 1) // m * m + + +def round_up_headdim(head_size: int) -> int: + from flash_attn_config import CONFIG + + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]: + if head_size <= 64: + return 64 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]: + if head_size <= 96: + return 96 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]: + if head_size <= 128: + return 128 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]: + if head_size <= 192: + return 192 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]: + if head_size <= 256: + return 256 + return 256 + + +@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") def _flash_attn_forward( - q, - k, - v, - k_new, - v_new, - qv, - out, - cu_seqlens_q, - cu_seqlens_k, - cu_seqlens_k_new, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - page_table, - kv_batch_idx, - leftpad_k, - rotary_cos, - rotary_sin, - seqlens_rotary, - q_descale, - k_descale, - v_descale, - softmax_scale, - causal, - window_size=(-1, -1), - attention_chunk=0, - softcap=0.0, - rotary_interleaved=True, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - ): + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor] = None, + v_new: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + out_: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + 
max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + kv_batch_idx: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + seqlens_rotary: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ @@ -63,14 +90,14 @@ def _flash_attn_forward( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] seqlens_rotary = maybe_contiguous(seqlens_rotary) - out, softmax_lse, *rest = flash_attn_3_cuda.fwd( + out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd( q, k, v, k_new, v_new, qv, - out, + out_, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, @@ -89,8 +116,8 @@ def _flash_attn_forward( v_descale, softmax_scale, causal, - window_size[0], - window_size[1], + window_size_left, + window_size_right, attention_chunk, softcap, rotary_interleaved, @@ -99,59 +126,314 @@ def _flash_attn_forward( pack_gqa, sm_margin, ) - return out, softmax_lse, *rest + if out_accum is None: + out_accum = torch.tensor([], device=out.device) + + if softmax_lse_accum is None: + softmax_lse_accum = torch.tensor([], device=out.device) + + return out, softmax_lse, out_accum, softmax_lse_accum + + 
+@torch.library.register_fake("flash_attn_3::_flash_attn_forward") +def _flash_attn_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor] = None, + v_new: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + out_: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + kv_batch_idx: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + seqlens_rotary: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Symbolic fake implementation of flash attention forward. + Returns tensors with the correct shapes and dtypes without actual computation. 
+ """ + + # Determine if we're in varlen mode + is_varlen_q = cu_seqlens_q is not None + + # Get dimensions from query tensor + if is_varlen_q: + # varlen mode: q is (total_q, num_heads, head_size) + total_q, num_heads, head_size = q.shape + batch_size = cu_seqlens_q.shape[0] - 1 + + if max_seqlen_q is None: + raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided") + seqlen_q = max_seqlen_q + else: + # batch mode: q is (batch_size, seqlen_q, num_heads, head_size) + batch_size, seqlen_q, num_heads, head_size = q.shape + total_q = batch_size * q.shape[1] + # Get value head dimension + head_size_v = v.shape[-1] + + # Determine output dtype (FP8 inputs produce BF16 outputs) + q_type = q.dtype + if q_type == torch.float8_e4m3fn: + out_dtype = torch.bfloat16 + else: + out_dtype = q_type + + # Create output tensor + if out_ is not None: + # If out_ is provided, _flash_attn_forward becomes non-functional + raise TypeError("Tracing (torch.compile/torch.export) with pre-allocated output tensor is not supported.") + + if is_varlen_q: + out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) + else: + out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) + + # Create softmax_lse tensor + if is_varlen_q: + softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device) + else: + softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) + + # TODO(guilhermeleobas): Implement "get_num_splits" + # There's an heuristic to compute num_splits when "num_splits <= 0" + # assert that num_splits is > 0 for now + if num_splits <= 0: + raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. 
Got {num_splits=}") + if num_splits > 1: + if is_varlen_q: + out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device) + softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device) + else: + out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device) + softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) + else: + # Tensors are not set when num_splits == 1 + out_accum = torch.tensor([], device=out.device) + softmax_lse_accum = torch.tensor([], device=out.device) + + return out, softmax_lse, out_accum, softmax_lse_accum + + +@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_attn_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + sequed_q: Optional[torch.Tensor] = None, + sequed_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + is_causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, +) -> torch.Tensor: + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + softmax_d, *rest = flash_attn_3_cuda.bwd( dout, q, k, v, out, softmax_lse, + dq, + dk, + dv, cu_seqlens_q, cu_seqlens_k, sequed_q, sequed_k, max_seqlen_q, max_seqlen_k, - dq, - dk, - dv, softmax_scale, - causal, - window_size=(-1, -1), 
- softcap=0.0, - deterministic=False, - sm_margin=0, -): - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( + is_causal, + window_size_left, + window_size_right, + softcap, + deterministic, + sm_margin, + ) + return softmax_d + + +@torch.library.register_fake("flash_attn_3::_flash_attn_backward") +def _flash_attn_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + sequed_q: Optional[torch.Tensor] = None, + sequed_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + is_causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, +) -> torch.Tensor: + + is_varlen_q = cu_seqlens_q is not None + is_varlen_k = cu_seqlens_k is not None + is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None + + if not is_varlen_q: + batch_size = q.size(0) + seqlen_q = q.size(1) + seqlen_k = k.size(1) + total_q = batch_size * q.size(1) + else: + batch_size = cu_seqlens_q.size(0) - 1 + total_q = q.size(0) + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if window_size_left >= seqlen_k - 1: + window_size_left = -1 + + if window_size_right >= seqlen_q - 1: + window_size_right = -1 + + if is_causal: + window_size_right = 0 + + is_causal = window_size_left < 0 and window_size_right == 0 + + head_size = q.size(-1) + head_size_v = v.size(-1) + head_size_rounded = round_up_headdim(max(head_size, head_size_v)) + + # 
Hopper gpus uses cuda compute capabilities 9.0 + cap = torch.cuda.get_device_capability(q.device) + arch = cap[0] * 10 + cap[1] + + is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal + + if head_size_rounded <= 64: + kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128 + elif head_size_rounded <= 96: + kBlockM_sm90 = 64 + elif head_size_rounded <= 128: + kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80 + else: + kBlockM_sm90 = 64 + + kBlockM_sm80 = 128 if head_size_rounded <= 64 else 64 + kBlockM_sm86 = 64 if head_size_rounded <= 192 else 32 + + if arch >= 90: + kBlockM = kBlockM_sm90 + elif arch == 86 or arch == 89: + kBlockM = kBlockM_sm86 + else: + kBlockM = kBlockM_sm80 + + num_heads = q.shape[-2] + seqlen_q_rounded = round_multiple(seqlen_q, kBlockM) + + total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM) + + dq = torch.empty_like(q) if dq is None else dq + dk = torch.empty_like(k) if dk is None else dk + dv = torch.empty_like(v) if dv is None else dv + + if not is_varlen: + softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device) + else: + softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device) + + return softmax_d + + +def setup_context(ctx, inputs, output): + q, k, v = inputs[:3] + out, softmax_lse, _, _ = output + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.softmax_scale = inputs[-11] + ctx.causal = inputs[-10] + ctx.window_size = [inputs[-9], inputs[-8]] + ctx.attention_chunk = inputs[-7] + ctx.softcap = inputs[-6] + ctx.sm_margin = inputs[-1] + + +def _backward(ctx, dout, *grads): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( dout, q, k, v, out, softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, 
max_seqlen_k, dq, dk, dv, - cu_seqlens_q, - cu_seqlens_k, - sequed_q, - sequed_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - causal, - window_size[0], - window_size[1], - softcap, - deterministic, - sm_margin, + ctx.softmax_scale, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + False, # deterministic + ctx.sm_margin, ) - return dq, dk, dv, softmax_d + return dq, dk, dv, *((None,) * 21) + + +_flash_attn_forward.register_autograd(_backward, setup_context=setup_context) + class FlashAttnQKVPackedFunc(torch.autograd.Function): @@ -196,7 +478,8 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, sm_margin=sm_margin, @@ -242,7 +525,8 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -290,7 +574,8 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, @@ -328,7 +613,8 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -388,7 +674,8 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, @@ -431,7 +718,8 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -787,7 +1075,8 @@ def 
flash_attn_with_kvcache( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, rotary_interleaved=rotary_interleaved, diff --git a/hopper/setup.py b/hopper/setup.py index 519d1c04f42..95729edabe2 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -82,6 +82,42 @@ _maybe_write, ) +def create_build_config_file(): + CONFIG = { + "build_flags": { + "FLASHATTENTION_DISABLE_BACKWARD": DISABLE_BACKWARD, + "FLASHATTENTION_DISABLE_SPLIT": DISABLE_SPLIT, + "FLASHATTENTION_DISABLE_PAGEDKV": DISABLE_PAGEDKV, + "FLASHATTENTION_DISABLE_APPENDKV": DISABLE_APPENDKV, + "FLASHATTENTION_DISABLE_LOCAL": DISABLE_LOCAL, + "FLASHATTENTION_DISABLE_SOFTCAP": DISABLE_SOFTCAP, + "FLASHATTENTION_DISABLE_PACKGQA": DISABLE_PACKGQA, + "FLASHATTENTION_DISABLE_FP16": DISABLE_FP16, + "FLASHATTENTION_DISABLE_FP8": DISABLE_FP8, + "FLASHATTENTION_DISABLE_VARLEN": DISABLE_VARLEN, + "FLASHATTENTION_DISABLE_CLUSTER": DISABLE_CLUSTER, + "FLASHATTENTION_DISABLE_HDIM64": DISABLE_HDIM64, + "FLASHATTENTION_DISABLE_HDIM96": DISABLE_HDIM96, + "FLASHATTENTION_DISABLE_HDIM128": DISABLE_HDIM128, + "FLASHATTENTION_DISABLE_HDIM192": DISABLE_HDIM192, + "FLASHATTENTION_DISABLE_HDIM256": DISABLE_HDIM256, + "FLASHATTENTION_DISABLE_SM8x": DISABLE_SM8x, + "FLASHATTENTION_ENABLE_VCOLMAJOR": ENABLE_VCOLMAJOR, + "FLASH_ATTENTION_DISABLE_HDIMDIFF64": DISABLE_HDIMDIFF64, + "FLASH_ATTENTION_DISABLE_HDIMDIFF192": DISABLE_HDIMDIFF192, + } + } + + with open("flash_attn_config.py", "w") as f: + f.write("# Auto-generated by flash attention 3 setup.py\n") + f.write(f"CONFIG = {repr(CONFIG)}\n") + f.write("\n") + + f.write("def show():\n") + f.write(" from pprint import pprint\n") + f.write(" pprint(CONFIG)\n") + f.write("\n") + def _write_ninja_file(path, cflags, post_cflags, @@ -395,6 +431,7 @@ def nvcc_threads_args(): TORCH_MAJOR = int(torch.__version__.split(".")[0]) 
TORCH_MINOR = int(torch.__version__.split(".")[1]) + create_build_config_file() check_if_cuda_home_none(PACKAGE_NAME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("12.3"): @@ -676,7 +713,7 @@ def run(self): "benchmarks", ) ), - py_modules=["flash_attn_interface"], + py_modules=["flash_attn_interface", "flash_attn_config"], description="FlashAttention-3", long_description=long_description, long_description_content_type="text/markdown", diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 0b5a0e2af98..78a8e7c2cc4 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -6,6 +6,11 @@ import torch import torch.nn.functional as F from torch._C import parse_schema +from torch.testing._internal.optests.generate_tests import ( + safe_fake_check, + safe_schema_check, + safe_aot_autograd_check, +) from einops import rearrange, repeat try: @@ -38,6 +43,8 @@ DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" +ENABLE_OPCHECK = os.getenv("FLASH_ATTENTION_ENABLE_OPCHECK", "FALSE") == "TRUE" +ENABLE_AUTOGRAD_CHECK = os.getenv("FLASH_ATTENTION_ENABLE_AUTOGRAD_CHECK", "FALSE") == "TRUE" COMPILED_HDIMS = ( [] @@ -48,6 +55,61 @@ + ([256] if not DISABLE_HDIM256 else []) ) +def should_test_backward(args, kwargs): + v = args[2] + num_splits = kwargs.get("num_splits", 1) + dtype = v.dtype + has_qv = V_colmajor = False # no test runs this with V_colmajor or has_qv == True + attention_chunk = kwargs.get("attention_chunk", 0) + dv = v.size(-1) + + if ( + ENABLE_AUTOGRAD_CHECK + and not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and num_splits > 0 # we don't support num_split == 0 on torch.compile yet + ): + return True + 
return False + + +def should_run_schema_check(args, kwargs): + v = args[2] + if v.dtype == torch.float8_e4m3fn: + return False + return True + + +def should_run_fake_check(args, kwargs): + if 'num_splits' in kwargs: + return kwargs['num_splits'] > 0 + return True + + +def run_opcheck(fn): + def wrapper(*args, **kwargs): + if should_run_schema_check(args, kwargs): + safe_schema_check(fn, args, kwargs) + + if should_run_fake_check(args, kwargs): + safe_fake_check(fn, args, kwargs) + + if should_test_backward(args, kwargs): + # Expensive check + safe_aot_autograd_check(fn, args, kwargs, dynamic=False) + safe_aot_autograd_check(fn, args, kwargs, dynamic=True) + return fn(*args, **kwargs) + return wrapper + + +if ENABLE_OPCHECK: + flash_attn_func = run_opcheck(flash_attn_func) + flash_attn_varlen_func = run_opcheck(flash_attn_varlen_func) + # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) diff --git a/hopper/test_torch_compile_and_export.py b/hopper/test_torch_compile_and_export.py new file mode 100644 index 00000000000..53beef46340 --- /dev/null +++ b/hopper/test_torch_compile_and_export.py @@ -0,0 +1,73 @@ +import torch +from flash_attn_interface import flash_attn_func +from torch import nn + + +class EfficienctMultiHeadAttention(nn.Module): + def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=True): + super().__init__() + assert embed_size % num_heads == 0, f"{embed_size=} {num_heads=}" + + self.embed_size = embed_size + self.num_heads = num_heads + self.head_dim = embed_size // num_heads + self.use_flash_attn = use_flash_attn and (flash_attn_func is not None) + + self.qkv_proj = nn.Linear(embed_size, 3 * embed_size) + self.out_proj = nn.Linear(embed_size, embed_size) + self.dropout = dropout + + def forward(self, x, attention_mask=None): + N, seq_length, _ 
= x.shape + + qkv = self.qkv_proj(x) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(N, seq_length, self.num_heads, self.head_dim) + k = k.view(N, seq_length, self.num_heads, self.head_dim) + v = v.view(N, seq_length, self.num_heads, self.head_dim) + + if self.use_flash_attn and attention_mask is None: + out = flash_attn_func( + q, k, v + ) + out = out.reshape(N, seq_length, self.embed_size) + out = self.out_proj(out) + return out + + +def create_model(batch_size=16, sequence_length=256, embedding_dim=2048, num_heads=16): + model = EfficienctMultiHeadAttention(embedding_dim, num_heads).cuda().bfloat16() + input_tensor = torch.randn(batch_size, sequence_length, embedding_dim).cuda().bfloat16() + return model, input_tensor + + +def test_export_model(): + model, input_tensor = create_model() + expected = torch.compile(model, backend="aot_eager")(input_tensor) + loss = expected.sum() + loss.backward() + + ep = torch.export.export(model, (input_tensor,)) + got = ep.module()(input_tensor,) + assert torch.equal(expected, got) + + loss_2 = got.sum() + loss_2.backward() + + assert torch.equal(loss, loss_2) + + +def test_compile_and_package_model(): + model, input_tensor = create_model() + expected = torch.compile(model, backend="aot_eager")(input_tensor) + + exported = torch.export.export(model, (input_tensor,)) + torch._inductor.aoti_compile_and_package( + exported, + package_path="model.pt2", + ) + + compiled_model = torch._inductor.package.load_package("model.pt2") + out = compiled_model(input_tensor,) + assert torch.equal(expected, out)