diff --git a/csrc/cutlass b/csrc/cutlass index afa17722036..62750a2b75c 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit 62750a2b75c802660e4894434dc55e839f322277 diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 77640c2b239..2c0a4f1b871 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -23,9 +23,9 @@ flash_attn_with_kvcache = None try: - from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear + from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear except ImportError: - FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None + ColumnParallelLinear, RowParallelLinear = None, None, None try: from flash_attn.layers.rotary import RotaryEmbedding @@ -341,13 +341,6 @@ def forward(self, q, kv, causal=None, key_padding_mask=None): return output -class LinearResidual(nn.Linear): - """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.""" - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return super().forward(input), input - - def _update_kv_cache(kv, inference_params, layer_idx): """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" # Pre-allocate memory for key-values for inference. @@ -452,13 +445,6 @@ def __init__( device=device, ) - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - linear_cls = nn.Linear if not fused_bias_fc else FusedDense - linear_resid_cls = ( - LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True) - ) - wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls inner_attn_cls = ( partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) if use_flash_attn @@ -470,10 +456,10 @@ def __init__( else CrossAttention ) if not self.cross_attn: - self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wqkv = nn.Linear(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) else: - self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) - self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wq = nn.Linear(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wkv = nn.Linear(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) if self.dwconv: if self.num_heads_kv == self.num_heads: self.dwconv_qkv = nn.Conv1d( @@ -492,7 +478,7 @@ def __init__( self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) - self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): dtype = self.out_proj.weight.dtype if dtype is None else dtype @@ -646,10 +632,7 @@ def forward( batch, seqlen = x.shape[:2] if not self.cross_attn and self.num_heads_kv == self.num_heads: assert x_kv is None and mixer_subset is None - if not self.return_residual: - qkv = self.Wqkv(x) - else: - qkv, x = self.Wqkv(x) + qkv = self.Wqkv(x) if self.dwconv: qkv = rearrange( self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" @@ -680,21 +663,11 @@ def forward( ) else: if self.cross_attn: - if not self.return_residual: - q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) - kv = self.Wkv(x_kv if x_kv is not None else x) - else: - if x_kv is not None: - kv, x_kv = self.Wkv(x_kv) - else: - kv, x = self.Wkv(x) - q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + kv = self.Wkv(x_kv if x_kv is not None else x) else: assert self.num_heads_kv != self.num_heads - if not self.return_residual: - qkv = self.Wqkv(x) - else: - qkv, x = self.Wqkv(x) + qkv = self.Wqkv(x) q = qkv[..., : self.num_heads * self.head_dim] kv = qkv[..., self.num_heads * self.head_dim :] q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py index 1e45b8e6098..6b4033d134e 100644 --- a/flash_attn/ops/fused_dense.py +++ b/flash_attn/ops/fused_dense.py @@ -11,9 +11,9 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup +from flash_attn.utils.torch import custom_fwd, custom_bwd from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd from flash_attn.utils.distributed import ( all_gather_raw, diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py index 0d122aa0883..0427e957e8e 100644 --- a/flash_attn/ops/triton/layer_norm.py +++ b/flash_attn/ops/triton/layer_norm.py @@ -10,11 +10,13 @@ import torch import torch.nn.functional as F -from torch.cuda.amp import custom_fwd, custom_bwd import triton import triton.language as tl +from flash_attn.utils.torch import custom_fwd, custom_bwd + + def triton_autotune_configs(): # Return configs with a valid warp count for the current device configs=[] @@ -635,7 +637,9 @@ def _layer_norm_bwd( BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the + # latency of the gmem reads/writes, but will increase the time of summing up dw / db. + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) _db = ( torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) @@ -1018,12 +1022,12 @@ def forward( norm_bias, eps, residual, - out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"), residual_dtype=residual_dtype, is_rms_norm=is_rms_norm, ) y = y.reshape(x_shape_og) - dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype linear_weight = linear_weight.to(dtype) linear_bias = linear_bias.to(dtype) if linear_bias is not None else None out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) diff --git a/flash_attn/ops/triton/mlp.py b/flash_attn/ops/triton/mlp.py index b795310f1c8..059f4f8a5e1 100644 --- a/flash_attn/ops/triton/mlp.py +++ b/flash_attn/ops/triton/mlp.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.cuda.amp import custom_bwd, custom_fwd +from flash_attn.utils.torch import custom_fwd, custom_bwd from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py index 0ee56d64773..560c75d002d 100644 --- a/flash_attn/ops/triton/rotary.py +++ b/flash_attn/ops/triton/rotary.py @@ -38,8 +38,8 @@ def rotary_kernel( BLOCK_M: tl.constexpr, ): pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) + pid_head = tl.program_id(axis=1) + pid_batch = tl.program_id(axis=2) rotary_dim_half = rotary_dim // 2 if not IS_VARLEN: @@ -193,7 +193,7 @@ def apply_rotary( if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) ) - grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), nheads, batch) # noqa BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4) # Need this, otherwise Triton tries to launch from cuda:0 and we get @@ -223,5 +223,6 @@ def apply_rotary( interleaved, conjugate, BLOCK_M, + num_warps=2 if rotary_dim <= 64 else 4, ) return output diff --git a/flash_attn/utils/torch.py b/flash_attn/utils/torch.py new file mode 100644 index 00000000000..98cbf9a274c --- /dev/null +++ b/flash_attn/utils/torch.py @@ -0,0 +1,21 @@ +import torch +from typing import Callable + + +def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) + return decorator + + +if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] + deprecated = True + from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] +else: + deprecated = False + from torch.cuda.amp import custom_fwd, custom_bwd + +custom_fwd = custom_amp_decorator(custom_fwd, deprecated) +custom_bwd = custom_amp_decorator(custom_bwd, deprecated) diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py index 9b7c0570844..99b1b7a3298 100644 --- a/hopper/benchmark_mla_decode.py +++ b/hopper/benchmark_mla_decode.py @@ -36,15 +36,15 @@ use_bench_cudagraph = False -attn_variants = ["mha", "gqa", "mqa", "mla"] -for attn_variant in attn_variants: -# for attn_variant in attn_variants[3:]: - nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else 1) - headdim = 64 if attn_variant == "mla" else 128 - headdim_v = 512 if attn_variant == "mla" else headdim - has_qv = headdim == 64 and headdim_v == 512 +attn_variants = ["mha", "gqa", "mqa", "mla", "gla"] +# for attn_variant in attn_variants: +for attn_variant in attn_variants[3:5]: + nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else (1 if attn_variant == "mla" else 2)) + headdim = 64 if attn_variant in ["mla", "gla"] else 128 + headdim_v = 512 if attn_variant == "mla" else (256 if attn_variant == "gla" else headdim) + has_qv = headdim == 64 and headdim_v > 64 # page_size = None - page_size = 64 if attn_variant == "mla" else 128 + page_size = 64 if attn_variant in ["mla", "gla"] else 128 should_run_flashmla = attn_variant == "mla" and page_size == 64 and flash_mla_with_kvcache is not None @@ -60,7 +60,7 @@ print(f"\n{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}") for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]: - # for seqlen in [s * 1024 for s in [1]]: + # for seqlen in [s * 1024 for s in [8]]: cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) num_splits = 0 q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device) @@ -84,6 +84,7 @@ cache_seqlens, q.dtype, headdim_v=headdim_v, page_size=page_size, causal=True ) # scheduler_metadata = None + # breakpoint() fn0 = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True, scheduler_metadata=scheduler_metadata) time.sleep(1) # to avoid power throttling # Time in ms @@ -109,7 +110,7 @@ t1 = do_bench_cudagraph(fn1, rep=10) total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item() - mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last time is for the output + mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last term is for the output flops = seqlen_q * total_seqlen * nheads_q * (headdim + headdim_v * (2 if has_qv else 1)) * 2 ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 ideal_h100_time_flop = flops / 989e12 * 1e6 diff --git a/hopper/flash.h b/hopper/flash.h index 69562d4881e..91fb5c81277 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -112,6 +112,7 @@ struct Flash_fwd_params : public Qkv_params { // The cos and sin matrices for rotary embedding. void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; + int *__restrict__ seqlens_rotary; // The indices to index into the KV cache. int * __restrict__ kv_batch_idx; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 543a60ea5c4..5a595840aa6 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -272,10 +272,11 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { - if (params.dv > 64 && Arch == 90) { + if (params.dv > 256 && Arch == 90) { return run_mha_fwd_(params, stream); - } - else { + } else if (params.dv > 64 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { return run_mha_fwd_(params, stream); } } @@ -302,10 +303,11 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { - if (params.dv > 64 && Arch == 90) { + if (params.dv > 256 && Arch == 90) { return run_mha_fwd_(params, stream); - } - else { + } else if (params.dv > 64 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { return run_mha_fwd_(params, stream); } } @@ -490,6 +492,15 @@ inline int round_up_headdim(int head_size) { return 256; } +inline int round_up_headdimv(int head_size) { + if (head_size <= 64) { return 64; } + if (head_size <= 96) { return 96; } + if (head_size <= 128) { return 128; } + if (head_size <= 192) { return 192; } + if (head_size <= 256) { return 256; } + return 512; +} + // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available at::Tensor mha_fwd_get_scheduler_metadata( @@ -534,7 +545,7 @@ mha_fwd_get_scheduler_metadata( params.d = headdim; params.dv = headdim_v; params.d_rounded = round_up_headdim(headdim); - params.dv_rounded = round_up_headdim(headdim_v); + params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v); params.seqlen_knew = max_seqlen_k_new; bool const is_varlen_q = cu_seqlens_q_.has_value(); @@ -640,6 +651,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq std::optional &leftpad_k_, // b std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional &seqlens_rotary_, // b std::optional &q_descale_, // (b, h_k), not (b, h) std::optional &k_descale_, // (b, h_k) std::optional &v_descale_, // (b, h_k) @@ -823,7 +835,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); - int const head_size_v_rounded = round_up_headdim(head_size_v); + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v); int const seqlen_q_rounded = round_multiple(seqlen_q, 128); int const seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -1001,6 +1013,13 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.rotary_cos_ptr = rotary_cos.data_ptr(); params.rotary_sin_ptr = rotary_sin.data_ptr(); params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + at::Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary); + TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = seqlens_rotary.data_ptr(); + } } else { params.rotary_dim = 0; } @@ -1104,7 +1123,11 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq // params.b = 1; // params.seqlen_q = total_q; // } + // This will zero out the semaphore if needed run_mha_fwd_combine(params, stream, true /*enable_pdl*/); + } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { + // need to zero out the semaphore in this case + tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_(); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. @@ -1492,7 +1515,6 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x const int seqlen = sizes[2]; const int num_heads = sizes[3]; const int head_size_og = sizes[4]; - TORCH_CHECK(head_size_og <= 512, "FlashAttention combine only supports head dimension at most 512"); TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp index f3f6a18b21b..a2006f3c4ef 100644 --- a/hopper/flash_api_torch_lib.cpp +++ b/hopper/flash_api_torch_lib.cpp @@ -38,6 +38,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq std::optional &leftpad_k_, // b std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional &seqlens_rotary_, // b std::optional &q_descale_, // (b, h_k), not (b, h) std::optional &k_descale_, // (b, h_k) std::optional &v_descale_, // (b, h_k) @@ -104,6 +105,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor? leftpad_k," " Tensor? rotary_cos," " Tensor? rotary_sin," + " Tensor? seqlens_rotary," " Tensor? q_descale," " Tensor? k_descale," " Tensor? v_descale," diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 92b84096f02..9e8d6908efe 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -36,6 +36,7 @@ def _flash_attn_forward( leftpad_k, rotary_cos, rotary_sin, + seqlens_rotary, q_descale, k_descale, v_descale, @@ -58,6 +59,7 @@ def _flash_attn_forward( maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k) ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] + seqlens_rotary = maybe_contiguous(seqlens_rotary) out, softmax_lse, *rest = flash_attn_3_cuda.fwd( q, k, @@ -78,6 +80,7 @@ def _flash_attn_forward( leftpad_k, rotary_cos, rotary_sin, + seqlens_rotary, q_descale, k_descale, v_descale, @@ -171,17 +174,26 @@ def forward( num_heads_k = (qkv.shape[2] - num_heads_q) // 2 assert num_heads_k * 2 + num_heads_q == qkv.shape[2] q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2) - out, q, k, v, out_padded, softmax_lse = _flash_attn_forward( + out, softmax_lse, *rest = _flash_attn_forward( q, k, v, + None, None, # k_new, v_new + None, # qv + None, # out + None, None, None, # cu_seqlens_q/k/k_new + None, None, # seqused_q/k + None, None, # max_seqlen_q/k + None, None, None, # page_table, kv_batch_idx, leftpad_k, + None, None, None, # rotary_cos/sin, seqlens_rotary + q_descale, k_descale, v_descale, softmax_scale, causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, softcap=softcap, ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + # ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + ctx.save_for_backward(q, k, v, out, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size @@ -211,6 +223,9 @@ def backward(ctx, dout, *args): v, out, softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, dq, dk, dv, @@ -257,7 +272,7 @@ def forward( None, None, # seqused_q/k None, None, # max_seqlen_q/k None, None, None, # page_table, kv_batch_idx, leftpad_k, - None, None, # rotary_cos/sin + None, None, None, # rotary_cos/sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal=causal, @@ -350,7 +365,7 @@ def forward( max_seqlen_q, max_seqlen_k, None, None, None, # page_table, kv_batch_idx, leftpad_k, - None, None, # rotary_cos/sin + None, None, None, # rotary_cos/sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal=causal, @@ -602,6 +617,7 @@ def flash_attn_with_kvcache( cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k_new: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, @@ -730,6 +746,7 @@ def flash_attn_with_kvcache( cache_leftpad, rotary_cos, rotary_sin, + rotary_seqlens, q_descale, k_descale, v_descale, softmax_scale, causal=causal, diff --git a/hopper/flash_fwd_kernel_sm80.h b/hopper/flash_fwd_kernel_sm80.h index 4c35da4f08a..b308d2d1b88 100644 --- a/hopper/flash_fwd_kernel_sm80.h +++ b/hopper/flash_fwd_kernel_sm80.h @@ -187,6 +187,7 @@ class FlashAttnFwdSm80 { get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { bool tile_new_valid = mainloop.store_kv_new( diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 962283fe279..47b3817cd28 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -337,6 +337,7 @@ class FlashAttnFwdSm90 { get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { bool tile_new_valid = mainloop.load_kv_new( @@ -385,6 +386,7 @@ class FlashAttnFwdSm90 { get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { bool tile_new_valid = mainloop.store_kv_new( diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 00692049366..e9297e1b7ca 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -126,7 +126,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.kv_batch_idx, params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, params.seqused_q, params.seqused_k, - params.leftpad_k, + params.leftpad_k, params.seqlens_rotary }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(params.o_ptr), @@ -208,7 +208,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { - static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV == 512; + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { diff --git a/hopper/generate_kernels.py b/hopper/generate_kernels.py index 19a6e90d345..b91a5b128f9 100644 --- a/hopper/generate_kernels.py +++ b/hopper/generate_kernels.py @@ -139,6 +139,7 @@ def get_all_kernels() -> List[Kernel]: if sm == 90 and head_dim == 192: yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") if sm == 90 and head_dim == 64 and dtype in ["bf16", "fp16"]: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=256, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM): yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu new file mode 100644 index 00000000000..4fb8f71d01e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<100, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu new file mode 100644 index 00000000000..8d037153cbb --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu new file mode 100644 index 00000000000..c62e0b8d822 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..5e22d67f700 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu new file mode 100644 index 00000000000..1e005b3f018 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..96c4f55afdb --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu new file mode 100644 index 00000000000..8a92fe291ee --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..f47cb326674 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu new file mode 100644 index 00000000000..1915feb0463 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu new file mode 100644 index 00000000000..fbc15776610 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu new file mode 100644 index 00000000000..88445691ffb --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu new file mode 100644 index 00000000000..f7d051a34d3 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu new file mode 100644 index 00000000000..c83c1741d4f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..2e06c89a8c7 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu new file mode 100644 index 00000000000..46479ec15e1 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..18681ec42b4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu new file mode 100644 index 00000000000..d2245aa136a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..022cdd39576 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu new file mode 100644 index 00000000000..67a324d52e8 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu new file mode 100644 index 00000000000..664f88dbfce --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu new file mode 100644 index 00000000000..6bd6b9ab38f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu index cc3a8a7c913..ddd8bf07c4a 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu index d6d6df0d4ee..c9494c4f1d2 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_paged_sm90.cu" #include "flash_fwd_hdim64_512_bf16_paged_sm90.cu" #include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu index bd85f7608f6..4b2ec583cfd 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu index 733511adb43..306722d4586 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu index c62ccf28d3c..e44b2d24654 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu index b7e51551a04..d52417daef3 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_sm90.cu" #include "flash_fwd_hdim64_512_bf16_sm90.cu" #include "flash_fwd_hdim192_128_bf16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu index 0dbd0045425..6428c461aa9 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu index 51a14371284..d0df6306e28 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_softcap_sm90.cu" #include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu" #include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu index 24a64e8e49e..e116d3ea7c7 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_split_sm90.cu" #include "flash_fwd_hdim64_512_bf16_split_sm90.cu" #include "flash_fwd_hdim192_128_bf16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu index 50c78f3d5d4..bededf4a7d8 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu index 453282a4f29..ea531027938 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu index 72736d8ef7a..10d86e5e99c 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_paged_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu index 97895aa708c..375197ef75e 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu index 423c42221e0..4fc4831cf58 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu index 98c89572117..a3d94a163a9 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu index 69108d025fa..9663103ae11 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_sm90.cu" #include "flash_fwd_hdim64_512_fp16_sm90.cu" #include "flash_fwd_hdim192_128_fp16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu index da39ba2731a..b7d2b07ca84 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu index be6496d1956..471b5abaafc 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu index a5a80909072..10f72182fa9 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_split_sm90.cu" #include "flash_fwd_hdim64_512_fp16_split_sm90.cu" #include "flash_fwd_hdim192_128_fp16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu index 62fe142562d..54db60c23b1 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index a642fc74f9c..905be872dd9 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -212,6 +212,7 @@ struct CollectiveMainloopFwdSm80 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; }; // Device side kernel params @@ -256,6 +257,7 @@ struct CollectiveMainloopFwdSm80 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; }; static Params @@ -295,7 +297,7 @@ struct CollectiveMainloopFwdSm80 { !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k}; + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; } template @@ -472,11 +474,11 @@ struct CollectiveMainloopFwdSm80 { flash::cp_async_wait(); } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_q, + seqlen_info.seqlen_rotary); int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; if (params.is_rotary_interleaved) { auto [tRrCos, tRrSin] = cute::conditional_return( @@ -689,12 +691,12 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; int const seqlen_k_new = seqlen_info.seqlen_k_new; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_k_new, + seqlen_info.seqlen_rotary); using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 6a21078f77a..68988862e58 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -395,6 +395,7 @@ struct CollectiveMainloopFwdSm90 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; }; // Device side kernel params @@ -450,6 +451,7 @@ struct CollectiveMainloopFwdSm90 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const *const seqlens_rotary = nullptr; }; static Params @@ -558,7 +560,7 @@ struct CollectiveMainloopFwdSm90 { !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k}; + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -778,7 +780,7 @@ struct CollectiveMainloopFwdSm90 { pipeline_v.producer_commit(smem_pipe_write); // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized // before calling. Without this we get race conditions. - cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::ProducerWG) /*id*/); + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, cutlass::arch::ReservedNamedBarriers::TransposeBarrier /*id*/); pipeline_vt.consumer_release(smem_pipe_read); }; @@ -1087,11 +1089,11 @@ struct CollectiveMainloopFwdSm90 { barrier_Q.wait(work_idx % 2); } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_q, + seqlen_info.seqlen_rotary); Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; if (params.is_rotary_interleaved) { @@ -1256,8 +1258,8 @@ struct CollectiveMainloopFwdSm90 { Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - warp_scheduler_barrier_arrive(); if constexpr (!HasQv) { + warp_scheduler_barrier_arrive(); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); // release K } else { @@ -1265,7 +1267,9 @@ struct CollectiveMainloopFwdSm90 { shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); } consumer_wait(pipeline_v, smem_pipe_read); - flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K warpgroup_wait<0>(); } @@ -1579,12 +1583,12 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; int const seqlen_k_new = seqlen_info.seqlen_k_new; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_k_new, + seqlen_info.seqlen_rotary); // This is used to index into the batch dimension of mK and mV int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; @@ -1654,7 +1658,7 @@ struct CollectiveMainloopFwdSm90 { rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); } } - // Without this sync I'm getting race condition when seqlen_k is large + // Without this fence I'm getting race condition when seqlen_k is large cutlass::arch::fence_view_async_shared(); // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized // before calling. diff --git a/hopper/named_barrier.hpp b/hopper/named_barrier.hpp index 8d07f6aa2fc..a7dfb6439a2 100644 --- a/hopper/named_barrier.hpp +++ b/hopper/named_barrier.hpp @@ -49,30 +49,24 @@ static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNa enum class FwdNamedBarriers { QueryEmpty = 0, - ProducerWG = 1, - TileCountSmemEmpty = 2, - TileCountSmemFull = 3, - WarpSchedulerWG1 = 4, - WarpSchedulerWG2 = 5, - WarpSchedulerWG3 = 6, - AppendKV = 7, - QueryRotated = 8, - PFull = 9, - PEmpty = 6, // HACK: PEmpty is only used when we don't have 3 WGs + WarpSchedulerWG1 = 1, + WarpSchedulerWG2 = 2, + WarpSchedulerWG3 = 3, + AppendKV = 4, + QueryRotated = 5, + PFull = 6, + PEmpty = 7, }; enum class BwdNamedBarriers { KVEmpty = 0, PdS = 1, - // This needs to match FwdNamedBarriers::TileCountSmemEmpty since TileScheduler uses it - TileCountSmemEmpty = 2, - TileCountSmemFull = 3, - dQEmptyWG1 = 4, - dQEmptyWG2 = 5, - dQEmptyWG3 = 6, - dQFullWG1 = 7, - dQFullWG2 = 8, - dQFullWG3 = 9, + dQEmptyWG1 = 2, + dQEmptyWG2 = 3, + dQEmptyWG3 = 4, + dQFullWG1 = 5, + dQFullWG2 = 6, + dQFullWG3 = 7, }; } // flash diff --git a/hopper/seqlen.h b/hopper/seqlen.h index 21a74712800..5547238b348 100644 --- a/hopper/seqlen.h +++ b/hopper/seqlen.h @@ -64,12 +64,13 @@ struct SeqlenInfoQKNewK { int const leftpad_k; int const offset_q, offset_k, offset_k_new; - int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k; + int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k, seqlen_rotary; CUTLASS_DEVICE SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0, int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, - int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k + int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k, + int const* const seqlens_rotary ) : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0) , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) @@ -85,6 +86,7 @@ struct SeqlenInfoQKNewK { ? 0 : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0)) , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new) + , seqlen_rotary(!AppendKV || !seqlens_rotary ? seqlen_k_og + leftpad_k : seqlens_rotary[bidb]) { } diff --git a/hopper/setup.py b/hopper/setup.py index f87d809ebd5..e12d98b7cff 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -518,7 +518,7 @@ def nvcc_threads_args(): # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage", # printing out number of registers "--resource-usage", # printing out number of registers # f"--split-compile={os.getenv('NVCC_THREADS', '4')}", # split-compile is faster - "-lineinfo", + "-lineinfo", # TODO: disable this for release to reduce binary size "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", # Necessary for the WGMMA shapes that we use "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index be27f14f624..4d20ff8af2b 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -117,7 +117,7 @@ def test_flash_attn_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] for dv in dv_vals: @@ -336,7 +336,7 @@ def test_flash_attn_varlen_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] for dv in dv_vals: @@ -564,12 +564,13 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) # @pytest.mark.parametrize("new_kv", [True]) -# @pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) # @pytest.mark.parametrize("causal,local", [(False, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) # @pytest.mark.parametrize("rotary_interleaved", [True]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) @@ -617,6 +618,7 @@ def test_flash_attn_kvcache( page_size, rotary_fraction, rotary_interleaved, + has_rotary_seqlens, seqlen_new_eq_seqlen_q, causal, local, @@ -630,6 +632,8 @@ def test_flash_attn_kvcache( pytest.skip() if not new_kv and rotary_fraction > 0.0: pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) @@ -643,11 +647,11 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] for dv in dv_vals: - has_qv = d == 64 and dv == 512 + has_qv = d == 64 and dv >= 256 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) @@ -733,6 +737,7 @@ def test_flash_attn_kvcache( key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) ) # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 if rotary_dim > 0: angle = ( torch.rand( @@ -747,7 +752,7 @@ def test_flash_attn_kvcache( sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) if causal or local: q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved ) else: q_ro = rearrange( @@ -755,7 +760,7 @@ def test_flash_attn_kvcache( rearrange(q, "b s h d -> b 1 (s h) d"), cos, sin, - seqlen_offsets=cache_seqlens, + seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved, ), "b 1 (s h) d -> b s h d", @@ -763,7 +768,7 @@ def test_flash_attn_kvcache( ) # q_ro = q k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved ) else: cos, sin = None, None @@ -828,12 +833,6 @@ def test_flash_attn_kvcache( num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] precompute_metadata_vals = [False, True] for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): - if page_size is None: - k_cache.copy_(k_cache_saved) - v_cache.copy_(v_cache_saved) - else: - k_cache_paged.copy_(k_cache_saved) - v_cache_paged.copy_(v_cache_saved) if precompute_metadata: scheduler_metadata = get_scheduler_metadata( batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, @@ -845,89 +844,98 @@ def test_flash_attn_kvcache( ) else: scheduler_metadata = None - out, lse, *rest = flash_attn_with_kvcache( - q if not varlen_q else q_unpad, - k_cache if page_size is None else k_cache_paged, - v_cache if page_size is None else v_cache_paged, - k if not new_kv or not varlen_q else k_unpad, - v if not new_kv or not varlen_q else v_unpad, - qv=qv if not varlen_q else qv_unpad, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - page_table=page_table, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - scheduler_metadata=scheduler_metadata, - num_splits=num_splits, - return_softmax_lse=True - ) - if varlen_q: - out = output_pad_fn(out) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): if page_size is None: - k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] - ) - v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] - ) + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) else: - k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: - assert torch.equal(v_cache_select, v_cache_ref) - else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() - # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - if rotary_dim == 0: - assert torch.equal(k_cache_select, k_cache_ref) - else: - # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) - mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 - mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index f713242721e..53651d5c848 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -107,9 +107,9 @@ class SingleTileScheduler { } if constexpr (Varlen && Split) { int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits; + is_valid_tile &= work_info.split_idx < num_splits_dynamic; // Use the top 16 bits to store num_splits work_info.split_idx |= (num_splits_dynamic << 16); - is_valid_tile &= work_info.split_idx < num_splits_dynamic; } work_info.bidb = is_valid_tile ? work_info.bidb : -1; return work_info; @@ -320,7 +320,7 @@ class DynamicPersistentTileScheduler { void init_consumer() const { if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) { - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty } } @@ -339,16 +339,16 @@ class DynamicPersistentTileScheduler { if constexpr (IsProducerWarp) { // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty if (threadIdx.x % NumProducerThreads == 0) { *tile_count_smem = current_work.tile_idx; } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return {new_tile_idx}; } else { - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull int tile_idx = *tile_count_smem; - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty return {tile_idx}; } } @@ -388,7 +388,7 @@ class VarlenDynamicPersistentTileScheduler { // If Split, for the purpose of scheduling, we pretend that instead there are // (args.num_splits * args.num_head) number of heads. assert(args.tile_count_semaphore != nullptr); - assert(num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx + assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, @@ -550,7 +550,7 @@ class VarlenDynamicPersistentTileScheduler { if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return work_info; } else { return get_next_work(params, {0, 0, 0, 0}); @@ -580,16 +580,16 @@ class VarlenDynamicPersistentTileScheduler { int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb}; work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info); - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return work_info; } else { - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull int4 work_info = *work_info_smem; - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w}; } } diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 2c440c6e210..4414b53ac2d 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -12,13 +12,18 @@ constexpr std::tuple tile_size_fwd_sm90( bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { - bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim}; // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 - // Switch to tile size 192 x 192 for now - bool const use_blockN_128 = is_causal || is_local; - return {same_hdim ? 192 : 64, same_hdim ? (use_blockN_128 ? 128 : 192) : 64, same_hdim && use_blockN_128, same_hdim}; + if (headdim_v == 512) { + return {64, 64, false, false}; + } else if (headdim_v == 256) { + return {128, 112, true, false}; + } else { + // Switch to tile size 192 x 192 for now + bool const use_blockN_128 = is_causal || is_local; + return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; + } // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py index 6c524f9ed3b..30a16078507 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -22,6 +22,7 @@ FA3_UNAVAILABLE_REASON = None FA3_AVAILABLE = True except ImportError as e: + raise e FA3_UNAVAILABLE_REASON = str(e) FA3_AVAILABLE = False @@ -262,7 +263,7 @@ def flash_attn_varlen_func( block_table, None, # kv_batch_idx None, # leftpad_k - None, None, # rotary_cos, rotary_sin + None, None, None, # rotary_cos, rotary_sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal, @@ -448,7 +449,7 @@ def flash_attn_with_kvcache( block_table, cache_batch_idx, # kv_batch_idx None, # leftpad_k - None, None, # rotary_cos, rotary_sin + None, None, None, # rotary_cos, rotary_sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal,