diff --git a/CMakeLists.txt b/CMakeLists.txt index b5d589bbf83..e4423efcc68 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.26) -project(vllm_flash_attn LANGUAGES CXX) +project(vllm_flash_attn LANGUAGES CXX CUDA) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_EXTENSIONS OFF) @@ -213,7 +213,9 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0) SRCS "${FA3_GEN_SRCS}" CUDA_ARCHS "${FA3_ARCHS}") set_gencode_flags_for_srcs( - SRCS "hopper/flash_fwd_combine.cu" + SRCS + hopper/flash_fwd_combine.cu + hopper/flash_prepare_scheduler.cu CUDA_ARCHS "${FA3_ARCHS}") endif() @@ -223,6 +225,7 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0) LANGUAGE ${VLLM_GPU_LANG} SOURCES hopper/flash_fwd_combine.cu + hopper/flash_prepare_scheduler.cu hopper/flash_api.cpp hopper/flash_api_torch_lib.cpp ${FA3_GEN_SRCS} diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 8f42f0ae100..50af5f63073 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -118,7 +118,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded + + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride); const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM; diff --git a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h index 016a010709f..e4875fe3a11 100644 --- a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h @@ -79,7 +79,7 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM; @@ -205,7 +205,7 @@ inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), Shape, Int>{}, diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 51e207a79ec..75128281541 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -362,7 +362,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); // if (cute::thread0()) { print(tOrP); } FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } @@ -424,7 +424,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } @@ -942,7 +942,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); @@ -1002,7 +1002,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py index addffe1f185..0d122aa0883 100644 --- a/flash_attn/ops/triton/layer_norm.py +++ b/flash_attn/ops/triton/layer_norm.py @@ -15,6 +15,19 @@ import triton import triton.language as tl +def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs=[] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block=1024 + # Default to warp size 32 if not defined by device + warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + warp_count=1 + while warp_count*warp_size <= max_threads_per_block: + configs.append(triton.Config({}, num_warps=warp_count)) + warp_count*=2 + return configs def layer_norm_ref( x, @@ -126,14 +139,7 @@ def rms_norm_ref( @triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], + configs=triton_autotune_configs(), key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) @@ -393,14 +399,7 @@ def _layer_norm_fwd( @triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], + configs=triton_autotune_configs(), key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 36f0bf6d036..33e5d282716 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -1,6 +1,7 @@ from collections import namedtuple from functools import partial import math +import os from typing import NamedTuple import torch import torch.nn as nn @@ -34,6 +35,8 @@ triton_attention = None triton_attention = None +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" + def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): # # Warmup @@ -53,7 +56,7 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) # # return time_f[1].mean # return time_f[1] - return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) + return Timing(do_bench(lambda: func(*args, **kwargs), warmup=3, rep=repeats) * 1e-3) def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): @@ -250,6 +253,7 @@ def run(*args, **kwargs): # for headdim in [64, 96, 128, 192, 256]: for headdim in [128]: nheads = dim // headdim + # nheads = 128 # headdim = 64 # batch_size = 64 # seqlen = 512 @@ -257,14 +261,16 @@ def run(*args, **kwargs): # headdim = 128 nheads_kv = nheads # nheads_kv = nheads // 4 + # nheads_kv = 1 headdim_v = headdim - # headdim_v = 128 + # headdim_v = 512 + has_qv = headdim == 64 and headdim_v == 512 + # has_qv = False for batch_size, seqlen in bs_seqlen_vals: num_splits = 0 window_size = (-1, -1) # window_size = (seqlen // 2 - 1, 0) - sink_token_length = 0 pack_gqa = None # seqlen_q = 64 seqlen_q = seqlen @@ -276,6 +282,7 @@ def run(*args, **kwargs): q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype) @@ -303,7 +310,7 @@ def run(*args, **kwargs): for causal in [False, True]: # for causal in [True]: print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") - nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size) + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: # if False: if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: @@ -351,17 +358,17 @@ def run(*args, **kwargs): time.sleep(1) if not varlen: - # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') - m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, qv=qv, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) else: m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean - if dtype != torch.float8_e4m3fn and headdim == headdim_v: + if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD: time.sleep(1) if not varlen: - _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic, + _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav3') else: _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, @@ -387,7 +394,7 @@ def run(*args, **kwargs): print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') - if dtype != torch.float8_e4m3fn and headdim == headdim_v: + if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD: print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') # benchmark_forward(torch.square, k) # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') @@ -397,7 +404,8 @@ def run(*args, **kwargs): # import pickle # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_cudnn_triton_20241208.plk', 'wb') as fp: - # with open(f'flash3_attn_time_h100_fa3_20241208.plk', 'wb') as fp: + # with open(f'flash3_attn_time_h100_fa3_20250313.plk', 'wb') as fp: + # # with open(f'flash3_attn_time_h100_fa3_fp8_20250313.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_fp8_hdim{headdim}.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_hdim{headdim}_1031.plk', 'wb') as fp: # pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py new file mode 100644 index 00000000000..9b7c0570844 --- /dev/null +++ b/hopper/benchmark_mla_decode.py @@ -0,0 +1,128 @@ +# Copyright (c) 2025, Ted Zadouri, Tri Dao. + +# We recommend locking GPU clocks before running the benchmark to ensure consistent results. +# This can be done using the following commands (1830 MHz is the clock for H100): +# sudo nvidia-smi -i 0 -pm 1 +# sudo nvidia-smi -i 0 --lock-gpu-clocks 1830,1830 +# See more here: https://github.com/triton-lang/triton/blob/d9f10ebdc5da53f73eb852fde73d8d7d80b679d1/python/triton/testing.py#L487 + +import time +import torch +import torch.nn.functional as F + +from triton.testing import do_bench, do_bench_cudagraph + +from einops import rearrange + +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata + +try: + from flash_mla import flash_mla_with_kvcache, get_mla_metadata +except ImportError: + flash_mla_with_kvcache, get_mla_metadata = None, None + +try: + from flash_attn.utils.benchmark import pytorch_profiler +except ImportError: + pytorch_profiler = None + + +device = "cuda" +dtype = torch.bfloat16 +seqlen = 8192 +seqlen_q = 1 +# nheads_q = 16 +nheads_q = 128 + +use_bench_cudagraph = False + +attn_variants = ["mha", "gqa", "mqa", "mla"] +for attn_variant in attn_variants: +# for attn_variant in attn_variants[3:]: + nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else 1) + headdim = 64 if attn_variant == "mla" else 128 + headdim_v = 512 if attn_variant == "mla" else headdim + has_qv = headdim == 64 and headdim_v == 512 + # page_size = None + page_size = 64 if attn_variant == "mla" else 128 + + should_run_flashmla = attn_variant == "mla" and page_size == 64 and flash_mla_with_kvcache is not None + + torch.manual_seed(0) + + batch_size = 128 + cache_seqlens = None + # cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) + # cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int32) + # cache_seqlens = torch.tensor([1024] * batch_size, device=device, dtype=torch.int) + # cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device) + + print(f"\n{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}") + + for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]: + # for seqlen in [s * 1024 for s in [1]]: + cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) + num_splits = 0 + q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device) + try: + v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) + k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) + if page_size is not None: + assert seqlen % page_size == 0 + k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]] + page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), + "(b s) -> b s", s=seqlen // page_size) + else: + page_table = None + except torch.OutOfMemoryError: + continue + qv = torch.randn(batch_size, seqlen_q, nheads_q, headdim_v, dtype=dtype, device=device) if has_qv else None + + # Precomputing this saves ~2us + scheduler_metadata = get_scheduler_metadata( + batch_size, seqlen_q, seqlen, nheads_q, nheads_kv, headdim, + cache_seqlens, q.dtype, headdim_v=headdim_v, page_size=page_size, causal=True + ) + # scheduler_metadata = None + fn0 = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True, scheduler_metadata=scheduler_metadata) + time.sleep(1) # to avoid power throttling + # Time in ms + if not use_bench_cudagraph: + t0 = do_bench(fn0, warmup=1, rep=10) + else: + torch.cuda.synchronize() # Gotta wait, otherwise e.g. k_cache might not be ready + with torch.cuda.stream(torch.cuda.Stream()): + t0 = do_bench_cudagraph(fn0, rep=10) + # exit(0) + if should_run_flashmla: + # Separate out the preprocessing since this can be done once and reused for all layers + mla_metadata = get_mla_metadata(cache_seqlens, seqlen_q * nheads_q // nheads_kv, nheads_kv) + q_concat = torch.concat([q, qv], dim=-1) if has_qv else q + kv_cache_concat = torch.concat([v_cache, k_cache], dim=-1) + fn1 = lambda: flash_mla_with_kvcache(q_concat, kv_cache_concat, page_table, cache_seqlens, headdim_v, *mla_metadata, causal=True) + time.sleep(1) # to avoid power throttling + if not use_bench_cudagraph: + t1 = do_bench(fn1, warmup=1, rep=10) + else: + torch.cuda.synchronize() # Gotta wait, otherwise e.g. k_cache might not be ready + with torch.cuda.stream(torch.cuda.Stream()): + t1 = do_bench_cudagraph(fn1, rep=10) + + total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item() + mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last time is for the output + flops = seqlen_q * total_seqlen * nheads_q * (headdim + headdim_v * (2 if has_qv else 1)) * 2 + ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 + ideal_h100_time_flop = flops / 989e12 * 1e6 + ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop) + print(f"Seqlen = {seqlen}, FA3 time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.1f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") + if should_run_flashmla: + print(f"Seqlen = {seqlen}, FlashMLA time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t1 * 1e3:.1f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS/s") + print(f"Arithmetic intensity: {flops / mem_io:.1f}") + print(f"Ideal time: {ideal_h100_time:.0f} us") + + # if pytorch_profiler is not None: + # time.sleep(1) # to avoid power throttling + # pytorch_profiler(fn0) + # if should_run_flashmla: + # time.sleep(1) # to avoid power throttling + # pytorch_profiler(fn1) diff --git a/hopper/block.h b/hopper/block.h index d06744c3b32..eda7eaa1c40 100644 --- a/hopper/block.h +++ b/hopper/block.h @@ -35,9 +35,14 @@ struct BlockMN { } // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } if constexpr (Split) { - int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits); - n_block_min = n_block_min + split_idx * num_n_blocks_per_split; + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + int split_idx_actual = split_idx & 0x0000FFFF; + int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual); + n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split; n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); } } // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } return {n_block_min, n_block_max}; diff --git a/hopper/cuda_check.h b/hopper/cuda_check.h new file mode 100644 index 00000000000..b5e63aef79d --- /dev/null +++ b/hopper/cuda_check.h @@ -0,0 +1,19 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) diff --git a/hopper/epilogue_bwd.hpp b/hopper/epilogue_bwd.hpp index 811d0d1f16e..9362b040453 100644 --- a/hopper/epilogue_bwd.hpp +++ b/hopper/epilogue_bwd.hpp @@ -4,8 +4,8 @@ #pragma once -#include -#include +#include "cutlass/cutlass.h" +#include "cutlass/barrier.h" #include "cute/tensor.hpp" #include "cutlass/gemm/collective/builders/sm90_common.inl" diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index 1c13988ebd7..69102e8c4e6 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -21,21 +21,24 @@ namespace flash { using namespace cute; template + int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false> struct CollectiveEpilogueFwd { using TileShape_MNK_PV = TileShape_MNK_PV_; using ClusterShape = ClusterShape_; using Element = Element_; + using ElementPartial = float; using ArchTag = ArchTag_; static constexpr int NumEpilogueThreads = NumEpilogueThreads_; static constexpr bool Varlen = Varlen_; static constexpr bool PackGQA = PackGQA_; - static constexpr bool Use_smem = sizeof(Element) <= 2; - static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && Use_smem && !PackGQA; + static constexpr bool Split = Split_; + static constexpr bool Use_smem = !(Split && !Varlen); + static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA; static_assert(ArchTag::kMinComputeCapability >= 80); static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); + static_assert(sizeof(Element) <= 2); static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); @@ -52,8 +55,6 @@ struct CollectiveEpilogueFwd { // we need to call divmod. static constexpr int kBytePerRow = kHeadDimV * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - // static constexpr int kBlockKGmem = kHeadDimV % 128 == 0 ? 128 : (kHeadDimV % 64 == 0 ? 64 : 32); - // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDimV / kGmemElemsPerStore, NumEpilogueThreads); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); @@ -121,8 +122,12 @@ struct CollectiveEpilogueFwd { Element* ptr_O; ShapeO const shape_O; StrideO const stride_O; + ElementPartial* ptr_O_partial; + StrideO const stride_O_partial; float* ptr_LSE; StrideLSE const stride_LSE; + float* ptr_LSE_partial; + StrideLSE const stride_LSE_partial; int32_t const nheads_kv; int const* cu_seqlens = nullptr; int const* seqused = nullptr; @@ -135,10 +140,16 @@ struct CollectiveEpilogueFwd { StrideO const stride_O; ShapeOPacked const shape_O_packed; StrideOPacked const stride_O_packed; + ElementPartial* ptr_O_partial; + StrideO const stride_O_partial; + StrideOPacked const stride_O_partial_packed; float* ptr_LSE; StrideLSE const stride_LSE; ShapeLSEPacked const shape_LSE_packed; StrideLSEPacked const stride_LSE_packed; + float* ptr_LSE_partial; + StrideLSE const stride_LSE_partial; + StrideLSEPacked const stride_LSE_partial_packed; cutlass::FastDivmod qhead_per_khead_divmod; TMA_O tma_store_O; int const* cu_seqlens = nullptr; @@ -165,6 +176,10 @@ struct CollectiveEpilogueFwd { args.stride_O, make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O)) ); + auto const stride_O_partial_packed = cute::conditional_return( + args.stride_O_partial, + make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial)) + ); // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits) auto const shape_LSE_packed = cute::conditional_return( select<0, 2, 3, 4>(args.shape_O), @@ -174,8 +189,14 @@ struct CollectiveEpilogueFwd { args.stride_LSE, make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE)) ); + auto const stride_LSE_partial_packed = cute::conditional_return( + args.stride_LSE_partial, + make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial)) + ); return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed, + args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed, args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed, + args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed, cutlass::FastDivmod(qhead_per_khead), tma_store_O, args.cu_seqlens, args.seqused}; } @@ -191,7 +212,7 @@ struct CollectiveEpilogueFwd { template CUTLASS_DEVICE void store(Params const& params, - FrgTensorO const& tOrO, + FrgTensorO& tOrO, FrgTensorLSE const& lse, SharedStorage& shared_storage, TiledMma tiled_mma, @@ -200,12 +221,25 @@ struct CollectiveEpilogueFwd { ) { auto [m_block, bidh, bidb, split_idx] = block_coord; + int num_splits = get<4>(params.shape_O_packed); + if constexpr (Split && Varlen) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + } + bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{}); // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO); + static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4); + // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion. + // Otherwise we can permute after conversion. + if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); } Tensor tOrO_out = make_tensor_like(tOrO); flash::convert_type_out(tOrO, tOrO_out); - if constexpr (FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4)) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } + if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } // Make sure all WGs have finished reading V // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that @@ -253,9 +287,12 @@ struct CollectiveEpilogueFwd { Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - using PackGQAt = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + using PackGQApartial_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>; - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), + params.shape_LSE_packed, + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } if (!LargeHeadDimV || warp_group_idx == 0) { if constexpr (!PackGQA) { @@ -265,7 +302,7 @@ struct CollectiveEpilogueFwd { if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } } } else { - PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } @@ -291,10 +328,10 @@ struct CollectiveEpilogueFwd { } } } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) - // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } - if constexpr (Use_smem) { + if (!is_split) { + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{}); + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) @@ -321,17 +358,27 @@ struct CollectiveEpilogueFwd { ); } else { // If PackGQA, we split the work of compute O_ptr among threads in the same row - PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } else { - // We already arrived on barrier_O earlier + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + // We already arrived on barrier_O earlier if !Use_smem + if constexpr (Use_smem) { + if constexpr (ArchTag::kMinComputeCapability >= 90) { + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } if constexpr (!PackGQA) { static constexpr int kGmemElemsPerStoreDirect = 2; - cute::Copy_Atom, Element> gmem_copy_direct; + cute::Copy_Atom, ElementPartial> gmem_copy_direct; // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor tOrO_rowcol = make_tensor(tOrO_out.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); - Tensor tOgO = thread_mma.partition_C(gO); + Tensor tOgO = thread_mma.partition_C(gOpartial); Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout())); Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int>{}); Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); @@ -347,7 +394,7 @@ struct CollectiveEpilogueFwd { } } } else { - PackGQAt::store_O_direct(mO, tOrO_out, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } } @@ -359,7 +406,6 @@ struct CollectiveEpilogueFwd { } // Write 0 to output and -inf to LSE - template CUTLASS_DEVICE void store_zero( Params const& params, @@ -368,13 +414,23 @@ struct CollectiveEpilogueFwd { ) { static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); auto [m_block, bidh, bidb, split_idx] = block_coord; + int num_splits = get<4>(params.shape_O_packed); + if constexpr (Split && Varlen) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + } + bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; int offset_o = seqlen_info.offset; int seqlen_o = seqlen_info.seqlen; int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), + params.shape_LSE_packed, + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); Tensor gLSE = local_tile(mLSE, Shape>{}, make_coord(m_block)); static_assert(kBlockM <= NumEpilogueThreads); @@ -386,35 +442,39 @@ struct CollectiveEpilogueFwd { if (row < seqlen_o * qhead_per_khead) { int m_idx, h_idx; m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row); - // mLSE shape shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" + // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY; } } } - if constexpr (!Clear_O) { return; } + // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used, + // since it will not use the value of O if LSE is -inf. + if (!is_split) { + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{}); - GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); - if constexpr (!PackGQA) { - Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } - Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - Tensor tOrO = make_fragment_like(tOgO); - cute::clear(tOrO); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM - ); - } else { - // If PackGQA, we split the work of compute O_ptr among threads in the same row - using PackGQAt = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; - Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); - cute::clear(tOrO); - PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + if constexpr (!PackGQA) { + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_fragment_like(tOgO); + cute::clear(tOrO); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + } else { + // If PackGQA, we split the work of compute O_ptr among threads in the same row + using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); + cute::clear(tOrO); + PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } } } diff --git a/hopper/flash.h b/hopper/flash.h index 8e95f5ff75c..69562d4881e 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -121,6 +121,7 @@ struct Flash_fwd_params : public Qkv_params { index_t page_table_batch_stride; int page_size; int num_pages; + bool pagedkv_tma; // The dropout probability (probability of keeping an activation). float p_dropout; @@ -133,7 +134,6 @@ struct Flash_fwd_params : public Qkv_params { // Local window size int window_size_left, window_size_right; - int sink_token_length; // Pointer to the RNG seed (idx 0) and offset (idx 1). uint64_t * rng_state; @@ -150,6 +150,10 @@ struct Flash_fwd_params : public Qkv_params { bool pack_gqa; int * __restrict__ tile_count_semaphore; + // int * __restrict__ num_m_blocks_ptr; + // int * __restrict__ num_n_blocks_ptr; + int * __restrict__ num_splits_dynamic_ptr; + bool skip_scheduler_metadata_computation; int arch; int num_sm; @@ -203,9 +207,10 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); template -void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index b7d56ead48b..543a60ea5c4 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -14,6 +14,7 @@ #include "static_switch.h" #include "tile_size.h" #include "heuristics.h" +#include "cuda_check.h" // Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909 // This is so that we can pass in torch.dtype as a parameter to the function. @@ -262,70 +263,70 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { TORCH_CHECK(params.num_splits >= 1); ARCH_SWITCH(params.arch, Arch, [&] { SPLIT_SWITCH(params.num_splits > 1, Split, [&] { - PAGEDKV_SWITCH(params.page_table, PagedKV, [&] { + PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] { - // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation - static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKV || Split; + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation + static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { if (!params.is_e4m3) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { if (params.dv > 64 && Arch == 90) { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } else { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } + if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } + if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } else { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } + if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif } else { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { - if (params.dv > 64 && Arch == 90) { - return run_mha_fwd_(params, stream); + if (params.d <= 64) { + if (params.dv > 64 && Arch == 90) { + return run_mha_fwd_(params, stream); } else { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } + if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } + if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } else { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } + if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP16."); @@ -334,25 +335,25 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { #ifndef FLASHATTENTION_DISABLE_FP8 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } else { - return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP8."); @@ -365,27 +366,27 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) { #ifndef FLASHATTENTION_DISABLE_SPLIT // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively // so that kBlockM is smaller and we have more parallelism. if (params.is_fp32) { if (params.dv <= 64) { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } } else if (params.is_bf16) { if (params.dv <= 64) { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } } else { if (params.dv <= 64) { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } } #else @@ -393,17 +394,28 @@ void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif } +inline bool get_pagedkv_tma(Flash_fwd_params const& params) { + if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); + // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, + // at least for MLA. + return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; +} + inline bool get_pack_gqa(Flash_fwd_params const& params) { - // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation and binary size. + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. // Has little effect on speed. - if (params.arch < 90 || params.page_table || params.num_splits > 1) { return true; } + if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } #ifdef FLASHATTENTION_DISABLE_PACKGQA return false; #else // params.page_table must already be set if (params.h == params.h_k) { return false; } // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); #endif @@ -417,7 +429,7 @@ inline int get_num_splits(Flash_fwd_params const& params) { // params.page_table must already be set // This needs to match the kernel configs bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits // has not been set here. It's OK though because we might just underestimate kBlockN a bit auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); @@ -432,9 +444,11 @@ inline int get_num_splits(Flash_fwd_params const& params) { int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); // Always enable PackGQA for Split - return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); - // return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.b * params.h_k, - // params.num_sm, num_n_blocks, 128, params.d_rounded); + // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. + // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending + // that batch = 1. + int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks; + return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); #endif } @@ -476,6 +490,127 @@ inline int round_up_headdim(int head_size) { return 256; } +// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available +at::Tensor +mha_fwd_get_scheduler_metadata( + int batch_size, + int max_seqlen_q, + int max_seqlen_k, + int num_heads, + int num_heads_k, + int headdim, + int headdim_v, + at::ScalarType qkv_dtype, + const at::Tensor &seqused_k, // b + std::optional &cu_seqlens_q_, // b+1 + std::optional &cu_seqlens_k_, // b+1 + std::optional &cu_seqlens_k_new_, // b+1 + std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional &leftpad_k_, // b + std::optional page_size, + int max_seqlen_k_new, // 0 means we're not appending new KV + bool is_causal, + int window_size_left, + int window_size_right, + bool has_softcap, + int num_splits, + std::optional pack_gqa_, + int const sm_margin + ) { + + TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // Reset the parameters + Flash_fwd_params params{}; + params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16; + params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn; + params.b = batch_size; + params.seqlen_q = max_seqlen_q; + params.seqlen_k = max_seqlen_k; + params.h = num_heads; + params.h_k = num_heads_k; + params.d = headdim; + params.dv = headdim_v; + params.d_rounded = round_up_headdim(headdim); + params.dv_rounded = round_up_headdim(headdim_v); + params.seqlen_knew = max_seqlen_k_new; + + bool const is_varlen_q = cu_seqlens_q_.has_value(); + params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr() : nullptr; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr() : nullptr; + params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr() : nullptr; + params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr; + params.seqused_k = seqused_k.data_ptr(); + params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr() : nullptr; + params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast(1) : nullptr; + if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + params.is_causal = window_size_left < 0 && window_size_right == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; + if (window_size_left < 0 && window_size_right >= 0) { window_size_left = max_seqlen_k - 1; } + if (window_size_left >= 0 && window_size_right < 0) { window_size_right = max_seqlen_q - 1; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; + params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; + params.softcap = has_softcap ? 1.0f : 0.0f; + + params.page_size = page_size.has_value() ? page_size.value() : 1; + params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); + + bool const use_dynamic_split = params.b <= 992; + params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + bool is_varlen = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()}; + + auto opts = seqused_k.options(); + // This needs to be set after get_num_splits + at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; + if (scheduler_needs_semaphore || use_dynamic_split) { + tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32)); + if (scheduler_needs_semaphore) { + if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + } else { + params.tile_count_semaphore = nullptr; + } + params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; + } + + if (params.num_splits_dynamic_ptr) { + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + return tile_count_semaphore; +} + // b: batch_size // b_k: batch_size_k // s_q: seqlen_q @@ -512,9 +647,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq bool is_causal, int window_size_left, int window_size_right, - int sink_token_length, float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional &scheduler_metadata_, // (b + 1) int num_splits, std::optional pack_gqa_, int const sm_margin @@ -567,11 +702,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); } - // This is what we will template on - bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); - #ifdef FLASHATTENTION_DISABLE_VARLEN - TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); - #endif auto const sizes = q.sizes(); const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; @@ -594,10 +724,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (head_size_v != head_size) { - TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || + TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || (head_size <= 64 && head_size_v <= 512), "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " - "or (Q/K <= 64 and V <= 512)."); + "or (Q/K <= 64 and V <= 512)."); TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); if (head_size_v > 256) { TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, @@ -609,6 +739,13 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq // TODO: check this if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((head_size <= 64 || head_size > 128) || !paged_KV) { + is_causal = false; + } + } if (is_causal) { window_size_right = 0; } // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true. // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM. @@ -648,6 +785,19 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq CHECK_SHAPE(seqused_k, batch_size); } + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); @@ -709,11 +859,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq sm_margin); params.total_q = total_q; params.total_k = total_k; - params.sink_token_length = sink_token_length; params.b_k = batch_size_k; params.dv = head_size_v; params.dv_rounded = head_size_v_rounded; - + if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma + params.leftpad_k = static_cast(leftpad_k_.value().data_ptr()); + } if (paged_KV) { params.page_table = page_table.data_ptr(); params.page_table_batch_stride = page_table.stride(0); @@ -721,11 +872,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.page_size = page_size; params.num_pages = num_pages; - params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; - // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - - if (k_new_.has_value()) { + if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma at::Tensor k_new, v_new; TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); @@ -773,6 +920,42 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } } + // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel + bool const use_dynamic_split = is_varlen && params.b <= 992; + // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it + params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + // This needs to be set after get_num_splits + at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + // We don't use the persistent scheduler if Split and not Varlen + bool const scheduler_needs_semaphore = params.arch >= 90 + ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) + : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); + if (scheduler_needs_semaphore || use_dynamic_split) { + int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b; + params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); + if (scheduler_metadata_.has_value()) { + at::Tensor scheduler_metadata = scheduler_metadata_.value(); + CHECK_DEVICE(scheduler_metadata); + CHECK_SHAPE(scheduler_metadata, metadata_size); + CHECK_CONTIGUOUS(scheduler_metadata); + TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32"); + tile_count_semaphore = scheduler_metadata; + } else { + tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); + } + if (scheduler_needs_semaphore && !use_dynamic_split) { + tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing + } + params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() : nullptr; + params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; + } + if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, @@ -796,14 +979,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } } - if (leftpad_k_.has_value()) { - auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); - CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - params.leftpad_k = static_cast(leftpad_k.data_ptr()); - } - if (rotary_cos_.has_value()) { TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); auto rotary_cos = rotary_cos_.value(); @@ -860,18 +1035,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.lseaccum_head_stride = softmax_lse_accum.stride(-2); } - at::Tensor tile_count_semaphore; - // We don't use the persistent scheduler if Split and not Varlen - bool const persistent_scheduler = params.arch >= 90 - ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) - : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - if (persistent_scheduler) { - tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32)); - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); - } else { - params.tile_count_semaphore = nullptr; - } - if (q_type == at::ScalarType::Float8_e4m3fn) { if (q_descale_.has_value()) { auto q_descale = q_descale_.value(); @@ -915,10 +1078,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); #endif #ifdef FLASHATTENTION_DISABLE_PACKGQA - TORCH_CHECK(!params.pack_gqa || params.arch < 90 || params.page_table || params.num_splits > 1, "This flash attention build does not support pack_gqa."); + TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); #endif #ifdef FLASHATTENTION_DISABLE_PAGEDKV - TORCH_CHECK(!paged_KV, "This flash attention build does not support paged KV."); + TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); #endif #ifdef FLASHATTENTION_DISABLE_APPENDKV TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); @@ -934,12 +1097,14 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 // and seqlen = total_q, and don't need to dispatch to Varlen there. + // However, with dynamic split, each row needs to know which batch it belongs to + // to read the number of splits, so we just use the varlen version of combine kernel. // if (is_varlen_q && !seqused_q_.has_value()) { - if (is_varlen_q) { - params.b = 1; - params.seqlen_q = total_q; - } - run_mha_fwd_combine(params, stream); + // if (is_varlen_q) { + // params.b = 1; + // params.seqlen_q = total_q; + // } + run_mha_fwd_combine(params, stream, true /*enable_pdl*/); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. @@ -1029,7 +1194,6 @@ std::vector mha_bwd( bool is_causal, int window_size_left, int window_size_right, - int const sink_token_length, float const softcap, bool const deterministic, int const sm_margin) { @@ -1263,7 +1427,7 @@ std::vector mha_bwd( params.total_q = total_q; params.total_k = total_k; params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); - params.sink_token_length = sink_token_length; + params.dv = head_size; // We don't support hdim_v being different from hdim_qk for now // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); // params.tile_count_semaphore = tile_count_semaphore.data_ptr(); @@ -1389,10 +1553,11 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x params.o_row_stride = out.stride(1); params.o_head_stride = out.stride(2); params.o_batch_stride = out.stride(0); + params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; if (seqlen > 0 && batch_size > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd_combine(params, stream); + run_mha_fwd_combine(params, stream, false /*enable_pdl*/); } at::Tensor out_padded = out; @@ -1413,6 +1578,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fwd", &mha_fwd, "Forward pass"); m.def("bwd", &mha_bwd, "Backward pass"); m.def("fwd_combine", &mha_combine, "Combine partial attention outputs"); + m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass"); } #endif \ No newline at end of file diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp index 2406d1a5076..f3f6a18b21b 100644 --- a/hopper/flash_api_torch_lib.cpp +++ b/hopper/flash_api_torch_lib.cpp @@ -9,6 +9,14 @@ * Externs for the flash_attn ops to be exposed as a pytorch library */ +// b: batch_size +// b_k: batch_size_k +// s_q: seqlen_q +// s_k: seqlen_k +// s_k_new: seqlen_k_new +// h: num_heads +// h_k: num_heads_k +// d: head_size std::vector mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. @@ -37,12 +45,41 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq bool is_causal, int window_size_left, int window_size_right, - int sink_token_length, float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional &scheduler_metadata_, // (b + 1) int num_splits, std::optional pack_gqa_, - int const sm_margin); + int const sm_margin +); + +// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available +at::Tensor +mha_fwd_get_scheduler_metadata( + int batch_size, + int max_seqlen_q, + int max_seqlen_k, + int num_heads, + int num_heads_k, + int headdim, + int headdim_v, + at::ScalarType qkv_dtype, + const at::Tensor &seqused_k, // b + std::optional &cu_seqlens_q_, // b+1 + std::optional &cu_seqlens_k_, // b+1 + std::optional &cu_seqlens_k_new_, // b+1 + std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional &leftpad_k_, // b + std::optional page_size, + int max_seqlen_k_new, // 0 means we're not appending new KV + bool is_causal, + int window_size_left, + int window_size_right, + bool has_softcap, + int num_splits, + std::optional pack_gqa_, + int const sm_margin +); /** * Torch Library Registration @@ -74,13 +111,40 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " bool is_causal," " int window_size_left," " int window_size_right," - " int sink_token_length," " float softcap," " bool is_rotary_interleaved," + " Tensor? scheduler_metadata," " int num_splits," " bool? pack_gqa," " int sm_margin) -> Tensor[]"); ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd)); + + ops.def("get_scheduler_metadata(" + " int batch_size," + " int max_seqlen_q," + " int max_seqlen_k," + " int num_heads," + " int num_heads_k," + " int headdim," + " int headdim_v," + " ScalarType qkv_dtype," + " Tensor seqused_k," + " Tensor? cu_seqlens_q," + " Tensor? cu_seqlens_k," + " Tensor? cu_seqlens_k_new," + " Tensor? seqused_q," + " Tensor? leftpad_k," + " int? page_size," + " int max_seqlen_k_new," // 0 means we're not appending new KV + " bool is_causal," + " int window_size_left," + " int window_size_right," + " bool has_softcap," + " int num_splits," + " bool? pack_gqa," + " int sm_margin) -> Tensor"); + ops.impl("get_scheduler_metadata", torch::kCUDA, + make_pytorch_shim(&mha_fwd_get_scheduler_metadata)); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME); \ No newline at end of file diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 78cfe1cb906..92b84096f02 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -42,13 +42,12 @@ def _flash_attn_forward( softmax_scale, causal, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, rotary_interleaved=True, + scheduler_metadata=None, num_splits=1, pack_gqa=None, sm_margin=0): - assert sink_token_length == 0, "sink_token_length not supported yet" q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ @@ -86,14 +85,14 @@ def _flash_attn_forward( causal, window_size[0], window_size[1], - sink_token_length, softcap, rotary_interleaved, + scheduler_metadata, num_splits, pack_gqa, sm_margin, ) - return (out, softmax_lse, *rest) + return out, softmax_lse, *rest def _flash_attn_backward( @@ -115,12 +114,10 @@ def _flash_attn_backward( softmax_scale, causal, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, deterministic=False, sm_margin=0, ): - assert sink_token_length == 0, "sink_token_length not supported yet" # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( @@ -143,7 +140,6 @@ def _flash_attn_backward( causal, window_size[0], window_size[1], - sink_token_length, softcap, deterministic, sm_margin, @@ -160,7 +156,6 @@ def forward( causal, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, deterministic=False, num_heads_q=None, @@ -183,14 +178,13 @@ def forward( softmax_scale, causal=causal, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, sink_token_length=sink_token_length, + window_size=window_size, softcap=softcap, ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size - ctx.sink_token_length = sink_token_length ctx.softcap = softcap ctx.deterministic = deterministic ctx.ndim = qkv.dim() @@ -223,7 +217,6 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, - ctx.sink_token_length, ctx.softcap, ctx.deterministic, ) @@ -244,7 +237,6 @@ def forward( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, num_splits=1, pack_gqa=None, @@ -270,7 +262,6 @@ def forward( softmax_scale, causal=causal, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, @@ -281,7 +272,6 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size - ctx.sink_token_length = sink_token_length ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin @@ -307,7 +297,6 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, - ctx.sink_token_length, ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -337,7 +326,6 @@ def forward( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, num_splits=1, pack_gqa=None, @@ -367,7 +355,6 @@ def forward( softmax_scale, causal=causal, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, @@ -380,7 +367,6 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size - ctx.sink_token_length = sink_token_length ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin @@ -409,7 +395,6 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, - ctx.sink_token_length, ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -426,7 +411,6 @@ def flash_attn_qkvpacked_func( causal=False, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, deterministic=False, num_heads_q=None, @@ -471,7 +455,6 @@ def flash_attn_qkvpacked_func( causal, q_descale, k_descale, v_descale, window_size, - sink_token_length, softcap, deterministic, num_heads_q, @@ -487,7 +470,6 @@ def flash_attn_func( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, num_splits=1, pack_gqa=None, @@ -548,7 +530,6 @@ def flash_attn_func( qv, q_descale, k_descale, v_descale, window_size, - sink_token_length, softcap, num_splits, pack_gqa, @@ -572,7 +553,6 @@ def flash_attn_varlen_func( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, num_splits=1, pack_gqa=None, @@ -594,7 +574,6 @@ def flash_attn_varlen_func( qv, q_descale, k_descale, v_descale, window_size, - sink_token_length, softcap, num_splits, pack_gqa, @@ -629,9 +608,9 @@ def flash_attn_with_kvcache( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window - sink_token_length=0, softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, + scheduler_metadata=None, num_splits=0, # Can be tuned for speed pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication @@ -722,7 +701,6 @@ def flash_attn_with_kvcache( logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). """ - assert sink_token_length == 0 assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" if softmax_scale is None: @@ -756,12 +734,53 @@ def flash_attn_with_kvcache( softmax_scale, causal=causal, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap, rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, num_splits=num_splits, pack_gqa=pack_gqa, sm_margin=sm_margin, ) # return (out, softmax_lse) if return_softmax_lse else out return (out, softmax_lse, *rest) if return_softmax_lse else out + + +def get_scheduler_metadata( + batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication +): + cache_seqlens = maybe_contiguous(cache_seqlens) + if headdim_v is None: + headdim_v = headdim + scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata( + batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, + qkv_dtype, + cache_seqlens, + cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + cache_leftpad, + page_size, + max_seqlen_k_new, + causal, + window_size[0], window_size[1], + has_softcap, + num_splits, + pack_gqa, + sm_margin, + ) + return scheduler_metadata diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index 635228eebcf..76ded0407ec 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -120,7 +120,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { static_cast(params.dsoftmax_sum), {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum params.scale_softmax, - params.window_size_left, params.window_size_right, params.sink_token_length, + params.window_size_left, params.window_size_right, params.softcap, params.b, params.dq_semaphore, @@ -165,7 +165,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { num_blocks_n, params.h, params.b, 1 /*num_splits*/, params.h / params.h_k, params.seqlen_k, - params.seqlen_q, params.d, sizeof(Element), + params.seqlen_q, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k }; diff --git a/hopper/flash_fwd_combine.cu b/hopper/flash_fwd_combine.cu index a1725cf2a82..3e85a0a212c 100644 --- a/hopper/flash_fwd_combine.cu +++ b/hopper/flash_fwd_combine.cu @@ -3,11 +3,11 @@ #include "flash_fwd_combine_launch_template.h" -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index 8957ae41a42..a22e05969d9 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -12,6 +12,8 @@ #include #include +#include "cutlass/arch/grid_dependency_control.h" + #include "seqlen.h" #include "utils.h" @@ -128,38 +130,41 @@ class FlashAttnFwdCombine { static constexpr int SharedStorageSize = sizeof(SharedStorage); - // Device side arguments struct Arguments { - ElementPartial const* ptr_O_partial; + ElementPartial const* const ptr_O_partial; ShapeOPartial const shape_O_partial; StrideOPartial const stride_O_partial; - float const* ptr_LSE_partial; + float const* const ptr_LSE_partial; ShapeLSEPartial const shape_LSE_partial; StrideLSEPartial const stride_LSE_partial; - Element* ptr_O; + Element* const ptr_O; StrideO const stride_O; - float* ptr_LSE; + float* const ptr_LSE; StrideLSE const stride_LSE; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int* const semaphore_to_reset = nullptr; }; // Kernel entry point API struct Params { - ElementPartial const* ptr_O_partial; + ElementPartial const* const ptr_O_partial; ShapeOPartial const shape_O_partial; StrideOPartial const stride_O_partial; - float const* ptr_LSE_partial; + float const* const ptr_LSE_partial; ShapeLSEPartial const shape_LSE_partial; StrideLSEPartial const stride_LSE_partial; - Element* ptr_O; + Element* const ptr_O; StrideO const stride_O; - float* ptr_LSE; + float* const ptr_LSE; StrideLSE const stride_LSE; cutlass::FastDivmod seqlen_divmod, head_divmod; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int* const semaphore_to_reset = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -180,7 +185,9 @@ class FlashAttnFwdCombine { args.stride_LSE, cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)), args.cu_seqlens, - args.seqused + args.seqused, + args.num_splits_dynamic_ptr, + args.semaphore_to_reset }; } @@ -196,17 +203,28 @@ class FlashAttnFwdCombine { int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; int const k_block = blockIdx.y; - int const batch = !Varlen ? 0 : blockIdx.y; - int const num_splits = get<1>(params.shape_LSE_partial); + int const batch = blockIdx.z; + int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); + + if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { + cutlass::arch::wait_on_dependent_grids(); + *params.semaphore_to_reset = 0; + } + if (num_splits <= 1) { return; } flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; int const offset = seqlen_info.offset; int const seqlen = seqlen_info.seqlen; - int max_idx = seqlen * get<2>(params.shape_LSE_partial) * get<3>(params.shape_LSE_partial); + int max_idx = seqlen * get<2>(params.shape_LSE_partial); + if constexpr (Varlen) { + if (m_block * kBlockM >= max_idx) { return; } + } cutlass::FastDivmod seqlen_divmod_dynamic(seqlen); // Step 1: load LSE_partial from gmem -> smem - Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)), select<1, 0, 2, 3>(params.shape_LSE_partial), select<1, 0, 2, 3>(params.stride_LSE_partial)); // (num_splits, seqlen, head, batch) + Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)), + select<1, 0, 2, 3>(params.shape_LSE_partial), + select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head) Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int>{}); GmemTiledCopyLSE gmem_tiled_copy_LSE; auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx); @@ -217,19 +235,20 @@ class FlashAttnFwdCombine { // Repeat the partitioning with identity layouts Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE); + cutlass::arch::wait_on_dependent_grids(); + #pragma unroll for (int m = 0; m < size<2>(tLSEcLSE); ++m) { int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m))); int idx = m_block * kBlockM + mi; if (idx < max_idx) { - int m_idx, bidh, bidb; + int m_idx, bidh; if constexpr (!Varlen) { - bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); + bidh = params.seqlen_divmod.divmod(m_idx, idx); } else { bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - bidb = 0; } - Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh, bidb); + Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh); #pragma unroll for (int s = 0; s < size<1>(tLSEcLSE); ++s) { int si = get<0>(tLSEcLSE(_0{}, s, _0{})); @@ -256,26 +275,24 @@ class FlashAttnFwdCombine { // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO); Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), - params.shape_O_partial, params.stride_O_partial); // (seqlen, d, num_splits, head, batch) + params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0); // (seqlen, d, num_splits, head) // Precompute these values to avoid recomputing them in the loop Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO))); Tensor tObidh = make_tensor(make_shape(size<1>(tOcO))); - Tensor tObidb = make_tensor(make_shape(size<1>(tOcO))); Tensor tOrOptr = make_tensor(make_shape(size<1>(tOcO))); #pragma unroll for (int m = 0; m < size<1>(tOcO); ++m) { int mi = get<0>(tOcO(_0{}, m, _0{})); int idx = m_block * kBlockM + mi; if constexpr (!Varlen) { - tObidb[m] = params.head_divmod.divmod(tObidh(m), params.seqlen_divmod.divmod(tOmidx(m), idx)); + tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx); } else { tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx); - tObidb[m] = 0; } - tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m), tObidb(m)); + tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m)); if (idx >= max_idx) { - tObidb[m] = -1; + tObidh[m] = -1; } } @@ -291,8 +308,8 @@ class FlashAttnFwdCombine { Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage); #pragma unroll for (int m = 0; m < size<1>(tOcO); ++m) { - if (tObidb(m) >= 0) { - Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}, _0{}).layout()); + if (tObidh(m) >= 0) { + Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout()); Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape>{}); #pragma unroll for (int k = 0; k < size<2>(tOcO); ++k) { @@ -372,22 +389,21 @@ class FlashAttnFwdCombine { // Step 5: store final LSE back to gmem if (k_block == 0) { auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0); #pragma unroll for (int m = 0; m < size<2>(ts2rrLSE); ++m) { if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); int idx = m_block * kBlockM + mi; if (idx < max_idx) { - int m_idx, bidh, bidb; + int m_idx, bidh; if constexpr (!Varlen) { - bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); + bidh = params.seqlen_divmod.divmod(m_idx, idx); } else { bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - bidb = 0; } // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); - mLSE(m_idx, bidh, bidb) = lse_sum(m); + mLSE(m_idx, bidh) = lse_sum(m); } } } @@ -420,7 +436,7 @@ class FlashAttnFwdCombine { #pragma unroll for (int m = 0; m < size<1>(tOrOpartial); ++m) { - if (tObidb(m) >= 0 && scale(m) > 0.f) { + if (tObidh(m) >= 0 && scale(m) > 0.f) { #pragma unroll for (int k = 0; k < size<2>(tOrOpartial); ++k) { if (Is_even_K || tOpO(k)) { @@ -441,19 +457,19 @@ class FlashAttnFwdCombine { flash::convert_type_out(tOrO, rO); auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial)); Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)), - shape_O, params.stride_O); + shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0); Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{}); GmemTiledCopy gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); #pragma unroll for (int m = 0; m < size<1>(tOcO); ++m) { - if (tObidb(m) >= 0) { + if (tObidh(m) >= 0) { #pragma unroll for (int k = 0; k < size<2>(tOcO); ++k) { int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; if (Is_even_K || tOpO(k)) { - cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m), tObidb(m))); + cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m))); } } } diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index eb7dd404c07..11d422924b4 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -9,6 +9,7 @@ #include "cutlass/cutlass.h" #include "cutlass/arch/arch.h" // For cutlass::arch::Sm80 #include "cutlass/device_kernel.h" // For device_kernel +#include "cutlass/kernel_launch.h" // For kernel_launch #include "static_switch.h" #include "flash.h" @@ -16,11 +17,12 @@ using namespace cute; -template -void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { +template +void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { + using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; using TileShape_MK = cute::Shape, Int>; using CombineKernel = flash::FlashAttnFwdCombine; + IsEvenK, Varlen, Element, ElementPartial, ArchTag>; typename CombineKernel::Arguments args { static_cast(params.oaccum_ptr), @@ -33,43 +35,46 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O static_cast(params.softmax_lse_ptr), {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE - params.cu_seqlens_q, params.seqused_q + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); int num_blocks_k = cute::ceil_div(params.dv, kBlockK); - int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h * (!Varlen ? params.b : 1), kBlockM); - dim3 grid_m(num_blocks_m, num_blocks_k, !Varlen ? 1 : params.b); + int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM); + dim3 grid_m(num_blocks_m, num_blocks_k, params.b); auto kernel = cutlass::device_kernel; int smem_size = CombineKernel::SharedStorageSize; if (smem_size >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } - kernel<<>>(kernel_params); + // kernel<<>>(kernel_params); + cutlass::kernel_launch(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } template -void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { // We want kBlockM to be as small as possible to maximize parallelism. // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32"); static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); - BOOL_SWITCH(params.seqused_q != nullptr, Varlen, [&] { - if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. - if (params.num_splits <= 16) { - run_flash_fwd_combine(params, stream); - return; + ARCH_SWITCH(params.arch, Arch, [&] { + BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] { + if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. + if (params.num_splits <= 16) { + run_flash_fwd_combine(params, stream, enable_pdl); + return; + } } - } - if (params.num_splits <= 32) { - run_flash_fwd_combine(params, stream); - } else if (params.num_splits <= 64) { - run_flash_fwd_combine(params, stream); - } else if (params.num_splits <= 128) { - run_flash_fwd_combine(params, stream); - } else { - run_flash_fwd_combine(params, stream); - } + if (params.num_splits <= 32) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 64) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 128) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else { + run_flash_fwd_combine(params, stream, enable_pdl); + } + }); }); } diff --git a/hopper/flash_fwd_kernel_sm80.h b/hopper/flash_fwd_kernel_sm80.h index 71071d72218..4c35da4f08a 100644 --- a/hopper/flash_fwd_kernel_sm80.h +++ b/hopper/flash_fwd_kernel_sm80.h @@ -203,9 +203,7 @@ class FlashAttnFwdSm80 { threadIdx.x, block_coord); } else { // Write 0 to gO and -inf to gLSE. - // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will - // not use the value of O if LSE is -inf. - epilogue.template store_zero(params.epilogue, threadIdx.x, block_coord); + epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); } } diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 9cfb2d9e5d3..962283fe279 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -14,6 +14,8 @@ #include #include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/arch/grid_dependency_control.h" + #include "seqlen.h" #include "utils.h" #include "softmax.h" @@ -35,7 +37,6 @@ class FlashAttnFwdSm90 { static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; static constexpr bool Varlen = CollectiveMainloop::Varlen; - static constexpr bool PagedKV = CollectiveMainloop::PagedKV; static constexpr bool Split = CollectiveMainloop::Split; static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; @@ -308,7 +309,7 @@ class FlashAttnFwdSm90 { cutlass::arch::warpgroup_reg_dealloc(); // The pipelines for AppendKV and main attention are different, since e.g. main attention - // might use cp.async to load KV (if PagedKV) while AppendKV always uses TMA to load + // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load // KV_new. Since the pipeline states are different, we have to manually sync to make // sure the two pipelines don't race when accessing smem_k and smem_v. PipelineState smem_pipe_write = cutlass::make_producer_start_state(); @@ -321,6 +322,8 @@ class FlashAttnFwdSm90 { } if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); } + cutlass::arch::wait_on_dependent_grids(); + // Load Q, K, V for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); @@ -330,7 +333,7 @@ class FlashAttnFwdSm90 { SeqlenInfo_t seqlen_info{ get<2>(block_coord) /*bidb*/, get<0>(params.mainloop.shape_Q), - !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, @@ -371,26 +374,14 @@ class FlashAttnFwdSm90 { CUTLASS_PRAGMA_NO_UNROLL for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); - work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { - // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); - float softmax_scale_log2 = params.mainloop.softmax_scale_log2; - // If there's tanh softcap, the scaling will be done before tanh. + // get_next_work will be called before the epilogue + ) { auto block_coord = work_tile_info.get_block_coord(params.scheduler); int const bidb = get<2>(block_coord); - if constexpr (Is_FP8 && !Has_softcap) { - int const bidh = get<1>(block_coord); - int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; - float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; - float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; - softmax_scale_log2 *= q_descale * k_descale; - } - flash::Softmax softmax(softmax_scale_log2); - SeqlenInfo_t seqlen_info{ bidb, get<0>(params.mainloop.shape_Q), - !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, @@ -411,6 +402,18 @@ class FlashAttnFwdSm90 { // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); } } } + // If there's tanh softcap, the scaling will be done before tanh. + float softmax_scale_log2 = params.mainloop.softmax_scale_log2; + if constexpr (Is_FP8 && !Has_softcap) { + int const bidh = get<1>(block_coord); + int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; + float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; + float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; + softmax_scale_log2 *= q_descale * k_descale; + } + flash::Softmax softmax(softmax_scale_log2); + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); bool tile_valid; if constexpr (!LargeHeadDimV) { tile_valid = mainloop.mma( @@ -427,16 +430,20 @@ class FlashAttnFwdSm90 { tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); } } + // Do this here before the epilogue so that the next tile is ready to go. + work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info); + if constexpr (Split && Varlen) { + if (!work_tile_info.is_valid(params.scheduler)) { // Last tile + cutlass::arch::launch_dependent_grids(); + } + } if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, threadIdx.x - MmaThreadOffset, block_coord); } else { // Write 0 to gO and -inf to gLSE. - // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will - // not use the value of O if LSE is -inf. - epilogue.template store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); - // epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); + epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); } } epilogue.store_tail(); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 15f43929627..00692049366 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -10,6 +10,7 @@ #include "cutlass/device_kernel.h" // For device_kernel #include #include "cutlass/cluster_launch.hpp" +#include "cutlass/kernel_launch.h" #include "static_switch.h" #include "flash.h" @@ -24,7 +25,7 @@ using namespace cute; template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); @@ -35,8 +36,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; // Can't use structured binding since it's not compatible with constexpr - static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap); - static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap); + static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); @@ -50,10 +51,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, - flash::CollectiveMainloopFwdSm80 + flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm80 >; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t= 90 ? (Split && !Varlen) : !((Is_causal && !Varlen) || (Varlen && Split)), SchedulerSingleTile, SchedulerPersistent>; + static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split)); + using Scheduler = std::conditional_t; using AttnKernel = std::conditional_t< Arch >= 90, flash::enable_sm90_or_later>, @@ -90,8 +92,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { {seqlen_q, params.d, params.h, batch_q}, // shape_Q {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q static_cast(params.k_ptr), - {!PagedKV ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, - params.d, params.h_k, !PagedKV ? batch_k : params.num_pages}, // shape_K + {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, + params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K static_cast(params.v_ptr), params.dv, // headdim_v @@ -111,14 +113,14 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.is_rotary_interleaved, params.page_table, // if page_size is not set, avoid dividing by zero - {params.kv_batch_idx ? params.b_k : params.b, !PagedKV ? 0 : params.seqlen_k / params.page_size}, // shape_page_table + {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table {params.page_table_batch_stride, _1{}}, // stride_page_table params.scale_softmax, params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr, {params.q_descale_batch_stride, params.q_descale_head_stride}, {params.k_descale_batch_stride, params.k_descale_head_stride}, {params.v_descale_batch_stride, params.v_descale_head_stride}, - params.window_size_left, params.window_size_right, params.sink_token_length, + params.window_size_left, params.window_size_right, params.softcap, params.num_splits, params.kv_batch_idx, @@ -127,15 +129,15 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.leftpad_k, }; typename CollectiveEpilogue::Arguments epilogue_args { - static_cast(!Split ? params.o_ptr : params.oaccum_ptr), + static_cast(params.o_ptr), {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O - {!Split ? params.o_row_stride : params.oaccum_row_stride, - _1{}, - !Split ? params.o_head_stride : params.oaccum_head_stride, - !is_varlen_q ? (!Split ? params.o_batch_stride : params.oaccum_batch_stride) : 0, - !Split ? 0 : params.oaccum_split_stride}, // stride_O - static_cast(!Split ? params.softmax_lse_ptr : params.softmax_lseaccum_ptr), - {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, !Split ? 0 : params.h * seqlen_q * batch_q}, // stride_LSE + {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O + static_cast(params.oaccum_ptr), + {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial + static_cast(params.softmax_lse_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE + static_cast(params.softmax_lseaccum_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial params.h_k, params.cu_seqlens_q, params.seqused_q }; @@ -147,10 +149,17 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits, params.h / params.h_k, params.seqlen_q, - params.seqlen_k, params.d, sizeof(Element), - params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q + params.seqlen_k, params.d, params.dv, sizeof(Element), + params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, + // params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, }; + if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + int device; CHECK_CUDA(cudaGetDevice(&device)); typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ @@ -178,32 +187,33 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (smem_size >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } - kernel<<>>(kernel_params); + // kernel<<>>(kernel_params); + cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, + Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); } CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; - using T_out = std::conditional_t, float>; + using T_out = std::conditional_t; CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; - // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV == 512; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + run_flash_fwd(params, stream); }); }); }); diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu new file mode 100644 index 00000000000..7093fff32b6 --- /dev/null +++ b/hopper/flash_prepare_scheduler.cu @@ -0,0 +1,124 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#include "cutlass/fast_math.h" +#include "cutlass/barrier.h" +#include "cutlass/arch/barrier.h" + +#include "cutlass/arch/grid_dependency_control.h" + +#include "flash.h" + +namespace flash { + +__global__ void prepare_varlen_num_blocks_kernel( + int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, + int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr, + int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, + cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, + int* const tile_count_semaphore, + // int* const num_m_blocks_ptr, + int* const num_splits_dynamic_ptr, + bool enable_pdl) { + + static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; + static constexpr int kSmemSize = 1; + // Assume that there's only one block in the grid + __shared__ int total_blocks_smem[kSmemSize]; + + // There's only 1 block in the grid, so might as well start launching the main attn kernel + if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } + + if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } + __syncthreads(); + + if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; } + + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + + auto get_num_m_blocks = [&](int bidb_start) { + int batch_idx = lane + bidb_start; + int seqlen; + if (seqused_q) { + seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; + } else if (cu_seqlens_q) { + int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = seqlen_q_static; + } + seqlen *= qhead_per_khead; + return batch_idx < num_batch && lane < kNumBatchPerWarp + ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; + }; + + auto get_num_n_blocks = [&](int bidb_start) { + int batch_idx = lane + bidb_start; + int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0; + int seqlen; + if (seqused_k) { + seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0; + } else if (cu_seqlens_k) { + int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = seqlen_k_static; + } + int seqlen_new; + if (cu_seqlens_k_new) { + int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0; + int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1); + seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new; + } else { + seqlen_new = seqlen_k_new_static; + } + // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); } + seqlen = seqlen - leftpad_k + seqlen_new; + return batch_idx < num_batch && lane < kNumBatchPerWarp + ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0; + }; + + int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; + int bidb_start = kNumBatchPerWarp * warp_idx; + int num_m_blocks = get_num_m_blocks(bidb_start); + int num_n_blocks = get_num_n_blocks(bidb_start); + + int total_blocks = num_m_blocks * num_n_blocks; + // Warp sum + #pragma unroll + for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { + total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + } + if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } + __syncthreads(); + total_blocks = total_blocks_smem[0]; + // 10% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); + // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM + int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { + num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + } +} + +} // flash + +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, + int blockM, int blockN, bool enable_pdl) { + // Only support batch <= 992 (32 warps, each with 31 batches) + int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); + flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( + params.seqlen_q, params.seqlen_k, params.seqlen_knew, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, + cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), + params.tile_count_semaphore, + // params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, enable_pdl); +} diff --git a/hopper/heuristics.h b/hopper/heuristics.h index 03fd391ff79..031ea44a0b3 100644 --- a/hopper/heuristics.h +++ b/hopper/heuristics.h @@ -22,11 +22,11 @@ inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, in // splits as that would incur more HBM reads/writes. // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. -inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) { +inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split // However, in the case of super long seqlen where each head of KV doesn't even fit into - // L2 (we assume conservatively that L2 size is 50MB), we want to split. - if (batch_nheads_mblocks >= 0.8f * num_SMs) { + // L2 (we assume that L2 size is 50MB), we want to split. + if (total_mblocks >= 0.8f * num_SMs) { int const size_l2 = 50 * 1024 * 1024; // Only split if there are enough queries to go over the KV at least twice // Don't split if causal @@ -43,7 +43,7 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n std::vector efficiency; efficiency.reserve(max_splits); for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float n_waves = float(total_mblocks * num_splits) / num_SMs; float eff = n_waves / ceil(n_waves); // printf("num_splits = %d, eff = %f\n", num_splits, eff); if (eff > max_efficiency) { max_efficiency = eff; } diff --git a/hopper/mainloop_bwd_sm80.hpp b/hopper/mainloop_bwd_sm80.hpp index eb0503c9373..0a79670f475 100644 --- a/hopper/mainloop_bwd_sm80.hpp +++ b/hopper/mainloop_bwd_sm80.hpp @@ -296,7 +296,7 @@ struct CollectiveMainloopBwdSm80 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -328,7 +328,7 @@ struct CollectiveMainloopBwdSm80 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; float const softcap_val; int const num_batch; int *const dq_semaphore; @@ -359,7 +359,7 @@ struct CollectiveMainloopBwdSm80 { args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; @@ -385,7 +385,7 @@ struct CollectiveMainloopBwdSm80 { }; auto m_block_min_max = BlockMN_t::get_m_block_min_max( seqlen_info, n_block, bidb, - params.window_size_left, params.window_size_right, params.sink_token_length); + params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); int const m_block_min = get<0>(m_block_min_max); int const m_block_max = get<1>(m_block_min_max); // It's possible to have m_block_max <= m_block_min. Exit early @@ -532,7 +532,7 @@ struct CollectiveMainloopBwdSm80 { int const seqlen_k = seqlen_info.seqlen_k; flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, params.qhead_per_khead_divmod ); diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index e3b2960685a..71cfb020469 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -310,7 +310,7 @@ struct CollectiveMainloopBwdSm90 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -337,7 +337,7 @@ struct CollectiveMainloopBwdSm90 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -394,7 +394,7 @@ struct CollectiveMainloopBwdSm90 { args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; @@ -428,7 +428,7 @@ struct CollectiveMainloopBwdSm90 { }; auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( seqlen_info, n_block, bidb, - params.window_size_left, params.window_size_right, params.sink_token_length); + params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { @@ -596,7 +596,7 @@ struct CollectiveMainloopBwdSm90 { }; auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( seqlen_info, n_block, bidb, params.window_size_left, - params.window_size_right, params.sink_token_length); + params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return; } @@ -686,7 +686,7 @@ struct CollectiveMainloopBwdSm90 { }; auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( seqlen_info, n_block, bidb, params.window_size_left, - params.window_size_right, params.sink_token_length); + params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return false; } @@ -792,7 +792,7 @@ struct CollectiveMainloopBwdSm90 { // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, params.qhead_per_khead_divmod ); diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index 909654d3426..a642fc74f9c 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -202,7 +202,7 @@ struct CollectiveMainloopFwdSm80 { float const softmax_scale; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - int const window_size_left = -1, window_size_right = -1, sink_token_length = 0; + int const window_size_left = -1, window_size_right = -1; float const softcap_val; int const num_splits; int const* const kv_batch_idx = nullptr; @@ -247,7 +247,7 @@ struct CollectiveMainloopFwdSm80 { float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; @@ -291,7 +291,7 @@ struct CollectiveMainloopFwdSm80 { args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, @@ -415,7 +415,10 @@ struct CollectiveMainloopFwdSm80 { params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k + params.page_size_divmod, + params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, + 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ ); auto load_K = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { @@ -541,7 +544,7 @@ struct CollectiveMainloopFwdSm80 { if constexpr (!Share_QV_Smem) { preprocess_Q(); } flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, params.qhead_per_khead_divmod ); @@ -640,12 +643,6 @@ struct CollectiveMainloopFwdSm80 { for (; n_block >= n_block_min; --n_block) { fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); } - // Disable sink token code for now - // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - // #pragma unroll 1 - // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - // fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); - // } } float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); @@ -704,8 +701,11 @@ struct CollectiveMainloopFwdSm80 { params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k + params.page_size_divmod, + params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, + bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position + 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ ); static_assert(std::is_same_v); diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 3589534c15f..6a21078f77a 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -29,7 +29,7 @@ namespace flash { using namespace cute; template struct CollectiveMainloopFwdSm90 { @@ -46,7 +46,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = Has_softcap_; static constexpr bool Varlen = Varlen_; - static constexpr bool PagedKV = PagedKV_; + static constexpr bool PagedKVNonTMA = PagedKVNonTMA_; static constexpr bool AppendKV = AppendKV_; static constexpr bool HasQv = HasQv_; static constexpr bool PackGQA = PackGQA_; @@ -54,7 +54,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool V_colmajor = V_colmajor_; static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; static constexpr bool Use_TMA_Q = !PackGQA; - static constexpr bool Use_TMA_KV = !PagedKV; + static constexpr bool Use_TMA_KV = !PagedKVNonTMA; static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; @@ -75,16 +75,17 @@ struct CollectiveMainloopFwdSm90 { static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); - static_assert(!(HasQv && !IntraWGOverlap), "HasQv requires IntraWGOverlap"); // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. static constexpr bool MmaQK_is_RS = false; // We can have MmaPV with P in smem in rmem to reduce register pressure at the cost of more smem. - static_assert(!(!MmaPV_is_RS && !IntraWGOverlap), "MmaPV must be RS if IntraWGOverlap is disabled"); static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); + // Slightly faster in this case to have WG1 use RS instead of SS to avoid waiting for the P smem write + static constexpr bool MmaPV_use_RS_WG1 = !MmaPV_is_RS && kHeadDim == 64 && kHeadDimV == 512; + using AtomLayoutQK = Layout, _1, _1>>; using TiledMmaQK = decltype(cute::make_tiled_mma( std::conditional_t< @@ -110,6 +111,10 @@ struct CollectiveMainloopFwdSm90 { using TiledMmaQV = decltype(cute::make_tiled_mma( cute::GMMA::ss_op_selector(), AtomLayoutQK{})); + // For hdim64,512, WG1 can use RS but WG2 must use SS + using TiledMmaPV_RS = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutPV{})); static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); static constexpr int NumMmaThreads = size(TiledMmaPV{}); @@ -130,17 +135,17 @@ struct CollectiveMainloopFwdSm90 { make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_PV{})), decltype(cute::get<2>(TileShape_MNK_PV{}))>()); + Int, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVt = decltype(tile_to_shape( SmemLayoutAtomVt{}, - make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), + make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_PV{})), decltype(cute::get<2>(TileShape_MNK_PV{}))>()); + Int, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVtMma = decltype(tile_to_shape( SmemLayoutAtomVtMma{}, - make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), + make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); using SmemLayoutVMmaQV = decltype(tile_to_shape( SmemLayoutAtomVMmaQV{}, - make_shape(shape<1>(TileShape_MNK_QV{}), shape<2>(TileShape_MNK_QV{}), Int{}))); + make_shape(shape<1>(TileShape_MNK_QV{}), Int{}, Int{}))); static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{})); // Only used if we're using cp.async to load V @@ -203,7 +208,7 @@ struct CollectiveMainloopFwdSm90 { using GmemTiledCopyQ = cute::SM90_TMA_LOAD; using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); - // We use CpAsync for K and V if PagedKV and AppendKV, since TMA doesn't work there + // We use CpAsync for K and V if PagedKVNonTMA and AppendKV, since TMA doesn't work there static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); @@ -216,7 +221,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); - // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKVNonTMA where // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); using GmemLayoutAtom = Layout, Int>, @@ -348,14 +353,14 @@ struct CollectiveMainloopFwdSm90 { ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128) : NumMmaWarpGroups == 2) && !LargeHeadDimV; - static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor); + static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor) && IntraWGOverlap; // Host side kernel arguments struct Arguments { Element const* const ptr_Q; ShapeQKV const shape_Q; StrideQK const stride_Q; - Element* const ptr_K; // Not Element const* since we might append to KV cache in-place + Element* const ptr_K; // not Element const* since we might append to KV cache in-place ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; @@ -380,7 +385,7 @@ struct CollectiveMainloopFwdSm90 { float const softmax_scale; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - int const window_size_left = -1, window_size_right = -1, sink_token_length = 0; + int const window_size_left = -1, window_size_right = -1; float const softcap_val; int const num_splits; int const* const kv_batch_idx = nullptr; @@ -424,6 +429,7 @@ struct CollectiveMainloopFwdSm90 { ShapePageTable const shape_pagetable; StridePageTable const stride_pagetable; cutlass::FastDivmod page_size_divmod; + cutlass::FastDivmod blockN_per_page_size_divmod; cutlass::FastDivmod qhead_per_khead_divmod; TMA_Q tma_load_Q; TMA_K tma_load_K; @@ -435,7 +441,7 @@ struct CollectiveMainloopFwdSm90 { float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; @@ -523,6 +529,11 @@ struct CollectiveMainloopFwdSm90 { assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); } assert(args.num_splits >= 1); + int page_size = !args.ptr_pagetable ? 1 : get<0>(args.shape_K); + if (!PagedKVNonTMA && args.ptr_pagetable != nullptr) { + assert(page_size % kBlockN == 0); + assert(!args.leftpad_k); + } // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val @@ -535,14 +546,15 @@ struct CollectiveMainloopFwdSm90 { args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, - cutlass::FastDivmod(int(get<0>(args.shape_K))), + cutlass::FastDivmod(page_size), // page_size_divmod + cutlass::FastDivmod(!args.ptr_pagetable ? 1 : cute::ceil_div(page_size, kBlockN)), // blockN_per_page_size_divmod cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, @@ -634,24 +646,24 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _); auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); - Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _); Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } - Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _)); // (N, K, _, _) + Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _)); // (K, N, _, _) auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x); - Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k) + Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k, batch) Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K.partition_D(sK)); // (TMA, PIPE) auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x); - Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k) + Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k, batch) Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) auto [tQvgQv, tQvsQv] = [&] { if constexpr (HasQv) { @@ -667,12 +679,16 @@ struct CollectiveMainloopFwdSm90 { } }(); + // This is used to index into the batch dimension of mK and mV + int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k + params.page_size_divmod, params.blockN_per_page_size_divmod, + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, bidb_kv_idx ); // Set up for transposing V, only used if Transpose_V @@ -724,9 +740,10 @@ struct CollectiveMainloopFwdSm90 { auto load_K = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { pipeline_k.producer_acquire(smem_pipe_write); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { + auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_K_TMA(); copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), - tKgK_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index())); + tKgK_TMA(_, n_block_idx, bidb_kv_idx), tKsK_TMA(_, smem_pipe_write.index())); } else { constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; paged_kv_manager.template load_K(n_block, sK_pi(_, _, smem_pipe_write.index())); @@ -737,9 +754,10 @@ struct CollectiveMainloopFwdSm90 { auto load_V = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { auto pipeline_v_load = cute::conditional_return(pipeline_v, pipeline_vt); pipeline_v_load.producer_acquire(smem_pipe_write); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { + auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_V_TMA(); copy(params.tma_load_V.with(*pipeline_v_load.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), - tVgVt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index())); + tVgVt_TMA(_, n_block_idx, bidb_kv_idx), tVsVt_TMA(_, smem_pipe_write.index())); } else { constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; paged_kv_manager.template load_V(n_block, sVcpasync(_, _, smem_pipe_write.index())); @@ -772,8 +790,10 @@ struct CollectiveMainloopFwdSm90 { bool should_load_KV = !Use_TMA_KV || ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()); if (should_load_KV) { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); + } else { + paged_kv_manager.template load_page_table_TMA(n_block); } if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } // if (thread_idx == 0) { printf("Producer: main load, before load_K, index = %d\n", smem_pipe_write.index());} @@ -834,8 +854,10 @@ struct CollectiveMainloopFwdSm90 { PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind ++smem_pipe_write; if (should_load_KV) { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); + } else { + paged_kv_manager.load_page_table_TMA(n_block); } if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); @@ -850,33 +872,6 @@ struct CollectiveMainloopFwdSm90 { n_block_prev = n_block; if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); } } - // if constexpr (Is_local) { - // Disable sink token code for now - if constexpr (false && Is_local) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - #pragma unroll 1 - for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind - ++smem_pipe_write; - if (should_load_KV) { - if constexpr (PagedKV) { - paged_kv_manager.template load_page_table(n_block); - } - if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } - load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); - if constexpr (!Transpose_V) { - if constexpr (IntraWGOverlap) { - load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/); - } else { - load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); - } - } - } - n_block_prev = n_block; - if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); } - } - } scheduler_prefetch(); if constexpr (!Transpose_V && IntraWGOverlap) { if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } @@ -1051,16 +1046,12 @@ struct CollectiveMainloopFwdSm90 { pipeline.consumer_wait(smem_pipe_read, barrier_token); }; - // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - clear(tOrO); - // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; - int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; int n_block = n_block_max - 1; flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, params.qhead_per_khead_divmod ); @@ -1070,12 +1061,27 @@ struct CollectiveMainloopFwdSm90 { float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; softcap_val *= q_descale * k_descale; } - // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn - // -inf to e.g. -50.0, which can affect the attention softmax. + // Softcapping needs to happen before masking since if we apply after masking, softcapping + // can turn -inf to e.g. -50.0, which can affect the attention softmax. auto scoremod_premask_fn = [&](auto& tSrS) { if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } }; + auto write_P_to_smem = [&](auto& tOrP) { + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } + cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); + }; + + auto arrive_on_P_write_barrier = [&] { + cutlass::arch::fence_view_async_shared(); + __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } + }; + auto &barrier_Q = shared_storage.pipelines.barrier_Q; if constexpr (!AppendKV) { barrier_Q.wait(work_idx % 2); @@ -1120,7 +1126,6 @@ struct CollectiveMainloopFwdSm90 { cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); } - // TODO: check the case where n_block_max <= n_block_min but there are sink tokens if constexpr (IntraWGOverlap) { Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); @@ -1144,19 +1149,14 @@ struct CollectiveMainloopFwdSm90 { Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!MmaPV_is_RS) { - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - } - cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); - cutlass::arch::fence_view_async_shared(); - __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - } - } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } --n_block; + // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block. auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) { static constexpr bool Check_inf = decltype(check_inf_type)::value; @@ -1192,18 +1192,9 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - } - if constexpr (!MmaPV_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!MmaPV_is_RS) { - cutlass::arch::fence_view_async_shared(); - __syncwarp(); - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - } - } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } }; if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking @@ -1234,12 +1225,6 @@ struct CollectiveMainloopFwdSm90 { for (; n_block >= n_block_min; --n_block) { fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/); } - // Disable sink token code for now - // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - // #pragma unroll 1 - // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - // fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/); - // } } // Tell producers that smem_q is ready cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); @@ -1266,27 +1251,48 @@ struct CollectiveMainloopFwdSm90 { auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; static constexpr bool Check_inf = decltype(check_inf_type)::value; + auto smem_pipe_read_prev = smem_pipe_read; + if constexpr (!Is_first_iter) { ++smem_pipe_read; } Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warp_scheduler_barrier_arrive(); - warpgroup_wait<0>(); - pipeline_k.consumer_release(smem_pipe_read); // release K + if constexpr (!HasQv) { + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read); // release K + } else { + if constexpr (Is_first_iter) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + } + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + pipeline_k.consumer_release(smem_pipe_read); // release K + warpgroup_wait<0>(); + } scoremod_premask_fn(tSrS); mask_fn(tSrS, n_block); Tensor scores_scale = softmax.template max_get_scale(tSrS); + if constexpr (LargeHeadDimV && !Is_first_iter) { store_scales(scores_scale, smem_pipe_read_prev.index()); } softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } - consumer_wait(pipeline_v, smem_pipe_read); + if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma_pv, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + if constexpr (!MmaPV_use_RS_WG1) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } else { + TiledMmaPV_RS tiled_mma_pv_rs; + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } + if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } + warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V - ++smem_pipe_read; }; auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; @@ -1319,20 +1325,20 @@ struct CollectiveMainloopFwdSm90 { for (; n_block >= n_block_min; --n_block) { fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); } - // Disable sink token code for now - // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - // #pragma unroll 1 - // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - // fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); - // } } warp_scheduler_barrier_arrive(); // Tell producers that smem_q is ready cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } softmax.rescale_o(tOrO, scores_scale); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; } ++work_idx; return true; @@ -1391,26 +1397,30 @@ struct CollectiveMainloopFwdSm90 { } }; - clear(tOrO); + // clear(tOrO); // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; typename Softmax::TensorT scores_scale; int n_block = n_block_max - 1; - pipeline_v.consumer_wait(smem_pipe_read); + // If HasQv, then by the time P is ready, V must have been ready as well + if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V --n_block; + #pragma unroll 1 for (; n_block >= n_block_min; --n_block) { cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); load_scales(scores_scale, smem_pipe_read.index()); softmax.rescale_o(tOrO, scores_scale); ++smem_pipe_read; - auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); - pipeline_v.consumer_wait(smem_pipe_read, barrier_token); + if constexpr (!HasQv) { + auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); + pipeline_v.consumer_wait(smem_pipe_read, barrier_token); + } flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V @@ -1576,12 +1586,16 @@ struct CollectiveMainloopFwdSm90 { params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); + // This is used to index into the batch dimension of mK and mV + int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k + params.page_size_divmod, params.blockN_per_page_size_divmod, + bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, bidb_kv_idx // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position ); @@ -1594,7 +1608,7 @@ struct CollectiveMainloopFwdSm90 { } static_assert(std::is_same_v); - static_assert(!PagedKV || std::is_same_v); + static_assert(!PagedKVNonTMA || std::is_same_v); GmemTiledCopyAppendKV gmem_tiled_copy_kv; auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_thread_slice(thread_idx); Tensor tKsK = gmem_thr_copy_kv.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N) @@ -1618,7 +1632,7 @@ struct CollectiveMainloopFwdSm90 { if (get<1>(params.shape_rotary) <= 0) { pipeline_k_new.consumer_wait(smem_pipe_read); Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_read.index()); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { Tensor tKgK_cur = tKgK(_, _, _, n_block); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( @@ -1629,15 +1643,15 @@ struct CollectiveMainloopFwdSm90 { } } else { Tensor gK_cur = gK(_, _, n_block); - auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); + auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); if (params.is_rotary_interleaved) { auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block); pipeline_k_new.consumer_wait(smem_pipe_read); - rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block); + rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block); } else { auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin(n_block); pipeline_k_new.consumer_wait(smem_pipe_read); - rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); + rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); } } // Without this sync I'm getting race condition when seqlen_k is large @@ -1653,7 +1667,7 @@ struct CollectiveMainloopFwdSm90 { pipeline_v_new.consumer_wait(smem_pipe_read); int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_read.index()); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { Tensor tVgV_cur = tVgV(_, _, _, n_block); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( @@ -1668,7 +1682,7 @@ struct CollectiveMainloopFwdSm90 { #pragma unroll 1 for (int n_block = n_block_new_max - 1; n_block >= n_block_new_min; --n_block) { - if constexpr (PagedKV) { paged_kv_manager.template load_page_table(n_block); } + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); } store_K(n_block, smem_pipe_read); // if (thread_idx == 0) { printf("Done storing K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } store_V(n_block, smem_pipe_read); diff --git a/hopper/paged_kv.h b/hopper/paged_kv.h index 80ee61b9a41..9ea59bcc2a2 100644 --- a/hopper/paged_kv.h +++ b/hopper/paged_kv.h @@ -78,9 +78,11 @@ struct PagedKVManager { GmemTiledCopyKVCpAsync gmem_tiled_copy_kv; cutlass::FastDivmod const &page_size_divmod; + cutlass::FastDivmod const &blockN_per_page_size_divmod; int const thread_idx; int const seqlen_k; int const leftpad_k; + int const* const ptr_page_table; GmemThrCopyKVCpAsync const gmem_thr_copy_kv; TensorPageTable mPageTable; TensorKV mK_paged, mV_paged; @@ -88,20 +90,27 @@ struct PagedKVManager { TensortVpV tVpV; TensorPageOffset tPrPageOffset; TensorKVPtr tPrVPtr; + int bidb_kv_idx, bidb_kv_idx_prev, n_block_idx, n_block_idx_prev; // Only used for TMA CUTLASS_DEVICE - PagedKVManager(int const* const ptr_page_table, + PagedKVManager(int const* const ptr_page_table_, ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, Element* const ptr_V, int const headdim_v, StrideKV const &stride_V, cutlass::FastDivmod const &page_size_divmod, - int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k + cutlass::FastDivmod const &blockN_per_page_size_divmod, + int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k, + int bidb_kv_idx ) : page_size_divmod(page_size_divmod) + , blockN_per_page_size_divmod(blockN_per_page_size_divmod) , thread_idx(thread_idx) , seqlen_k(seqlen_k) , leftpad_k(leftpad_k) + , ptr_page_table(ptr_page_table_) , gmem_thr_copy_kv(gmem_tiled_copy_kv.get_thread_slice(thread_idx)) + , bidb_kv_idx(bidb_kv_idx) + , bidb_kv_idx_prev(bidb_kv_idx) { mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); @@ -143,6 +152,38 @@ struct PagedKVManager { if constexpr (First_iter && !KV_Same_Iter) { compute_V_ptr(); } }; + template + CUTLASS_DEVICE + void load_page_table_TMA(const int n_block) { + // We require that page size is a multiple of kBlockN, and there's no leftpad_k + if (ptr_page_table) { + bidb_kv_idx = mPageTable[blockN_per_page_size_divmod.divmod(n_block_idx, n_block)]; + } else { + n_block_idx = n_block; + } + if constexpr (First_iter && !KV_Same_Iter) { + bidb_kv_idx_prev = bidb_kv_idx; + n_block_idx_prev = n_block_idx; + } + }; + + CUTLASS_DEVICE + cute::tuple get_indices_for_K_TMA() { + return {n_block_idx, bidb_kv_idx}; + }; + + CUTLASS_DEVICE + cute::tuple get_indices_for_V_TMA() { + if constexpr (KV_Same_Iter) { + return {n_block_idx, bidb_kv_idx}; + } else { + cute::tuple const indices = {n_block_idx_prev, bidb_kv_idx_prev}; + bidb_kv_idx_prev = bidb_kv_idx; + n_block_idx_prev = n_block_idx; + return indices; + } + }; + CUTLASS_DEVICE TensorKVPtr compute_K_ptr() { Tensor tPrKPtr = make_tensor(Shape>{}); diff --git a/hopper/rotary.h b/hopper/rotary.h index 5e30456c2d1..aa3602cc795 100644 --- a/hopper/rotary.h +++ b/hopper/rotary.h @@ -226,7 +226,7 @@ struct Rotary { // The main bottleneck here is actually instruction cache misses. - // Similar to PagedKV, it's expensive to compute the pointers. + // Similar to PagedKVNonTMA, it's expensive to compute the pointers. // We split the work among threads loading the same row, then __shfl_sync the pointers. static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size<1>(tRrCos)), kGmemThreadsPerRow); Tensor tPrCosPtr = make_tensor(Shape>{}); @@ -350,7 +350,7 @@ struct Rotary { } }; - template + template CUTLASS_DEVICE void apply_K_interleaved(TensorsK const &sK, // (kBlockN, kHeadDim) @@ -377,7 +377,7 @@ struct Rotary { CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); static_assert(decltype(size<0>(tKsK))::value == decltype(size<0>(tRrCos))::value * 2); static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); } @@ -385,7 +385,7 @@ struct Rotary { for (int m = 0; m < size<1>(tKsK); ++m) { int const row = get<0>(tKcK(_0{}, m, _0{})); auto mK_cur_copy = [&] { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); return cute::tiled_divide(mK_cur, Shape>{}); @@ -400,7 +400,7 @@ struct Rotary { Tensor rK = make_fragment_like(tKsK(_, m, k)); cute::copy(tiled_copy_k, tKsK(_, m, k), rK); if (tRpR(k)) { apply_rotary_interleaved(rK, tRrCos(_, m, k), tRrSin(_, m, k)); } - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { cute::copy(tiled_copy_k, rK, tKgK(_, m, k)); } else { int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; @@ -412,7 +412,7 @@ struct Rotary { } }; - template + template CUTLASS_DEVICE void apply_K_contiguous(TensorsK const &sK, // (kBlockN, kHeadDim) @@ -439,7 +439,7 @@ struct Rotary { CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); CUTE_STATIC_ASSERT_V(size<0>(tKcK) == size<0>(tRrCosCont)); static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); } @@ -449,7 +449,7 @@ struct Rotary { for (int m = 0; m < size<1>(tKcK); ++m) { int const row = get<0>(tKcK(_0{}, m, _0{})); Tensor gK_cur_copy = [&] { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); return cute::tiled_divide(mK_cur, Shape>{}); diff --git a/hopper/setup.py b/hopper/setup.py index 6798de67ad8..f87d809ebd5 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -90,7 +90,9 @@ def _write_ninja_file(path, objects, ldflags, library_target, - with_cuda) -> None: + with_cuda, + **kwargs, # kwargs (ignored) to absorb new flags in torch.utils.cpp_extension + ) -> None: r"""Write a ninja file that does the desired compiling and linking. `path`: Where to write this file @@ -374,7 +376,7 @@ def nvcc_threads_args(): # NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} -NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.61"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.93"} exe_extension = sysconfig.get_config_var("EXE") @@ -506,6 +508,7 @@ def nvcc_threads_args(): ) if not DISABLE_SPLIT: sources += ["flash_fwd_combine.cu"] + sources += ["flash_prepare_scheduler.cu"] nvcc_flags = [ "-O3", "-std=c++17", @@ -517,7 +520,7 @@ def nvcc_threads_args(): # f"--split-compile={os.getenv('NVCC_THREADS', '4')}", # split-compile is faster "-lineinfo", "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", # Necessary for the WGMMA shapes that we use - # "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL + "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging "-DNDEBUG", # Important, otherwise performance is severely impacted ] diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 16cfb238416..be27f14f624 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -19,7 +19,8 @@ generate_random_padding_mask, ) -from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine, flash_attn_with_kvcache +from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" @@ -103,8 +104,6 @@ def test_flash_attn_output( seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype ): - # sink_token_length = 0 if not local else 4 - sink_token_length = 0 if not local else 0 if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") device = "cuda" @@ -118,7 +117,10 @@ def test_flash_attn_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - for dv in [128, d] if d > 128 and d <= 192 else [d]: + dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + for dv in dv_vals: q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -151,7 +153,6 @@ def test_flash_attn_output( qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap ) out_pt, attn_pt = attention_ref( @@ -164,7 +165,6 @@ def test_flash_attn_output( qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap, upcast=False, reorder_ops=True, @@ -197,7 +197,6 @@ def test_flash_attn_output( qv=qv, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap, pack_gqa=pack_gqa, num_splits=num_splits @@ -229,7 +228,6 @@ def test_flash_attn_output( # d ** (-0.5), # causal, # window_size[0], window_size[1], - # sink_token_length, # softcap, # deterministic, # 0, # sm_margin @@ -338,7 +336,10 @@ def test_flash_attn_varlen_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - for dv in [128, d] if d > 128 and d <= 192 else [d]: + dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + for dv in dv_vals: q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -559,8 +560,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -@pytest.mark.parametrize("num_splits", [1] + ([0] if not DISABLE_SPLIT else [])) -# @pytest.mark.parametrize("num_splits", [1]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) @@ -582,13 +581,13 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_batch_idx", [False, True]) # @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [True]) +# @pytest.mark.parametrize("varlen_q", [False]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -623,7 +622,6 @@ def test_flash_attn_kvcache( local, new_kv, mha_type, - num_splits, dtype, ): if page_size is not None and seqlen_k % page_size != 0: @@ -645,9 +643,11 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else [d] - has_qv_vals = [False] - for dv, has_qv in itertools.product(dv_vals, has_qv_vals): + dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + for dv in dv_vals: + has_qv = d == 64 and dv == 512 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) @@ -695,7 +695,7 @@ def test_flash_attn_kvcache( v_cache_paged, num_blocks, ) = _generate_block_kvcache( - seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype_ref + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref ) cache_seqlens = torch.randint( 0 if new_kv else 1, @@ -823,98 +823,121 @@ def test_flash_attn_kvcache( qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None cos = cos.to(dtype) if cos is not None else None sin = sin.to(dtype) if sin is not None else None - out, lse, *rest = flash_attn_with_kvcache( - q if not varlen_q else q_unpad, - k_cache if page_size is None else k_cache_paged, - v_cache if page_size is None else v_cache_paged, - k if not new_kv or not varlen_q else k_unpad, - v if not new_kv or not varlen_q else v_unpad, - qv=qv if not varlen_q else qv_unpad, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - page_table=page_table, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - num_splits=num_splits, - return_softmax_lse=True - ) - if varlen_q: - out = output_pad_fn(out) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False, True] + for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): if page_size is None: - k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] - ) - v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] - ) + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) else: - k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: - assert torch.equal(v_cache_select, v_cache_ref) + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, + cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + max_seqlen_k_new=seqlen_new, page_size=page_size, + causal=causal, window_size=window_size, + num_splits=num_splits + ) else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + scheduler_metadata = None + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() - # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - if rotary_dim == 0: - assert torch.equal(k_cache_select, k_cache_ref) - else: - # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.equal(v_cache_select, v_cache_ref) else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) - mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 - mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() -def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype): +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 k_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype - ) + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype - ) + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index e67abf89a13..f713242721e 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -8,6 +8,7 @@ #include "cutlass/arch/barrier.h" #include "named_barrier.hpp" +#include "utils.h" namespace flash { @@ -19,10 +20,12 @@ struct TileSchedulerArguments { int const num_blocks, num_head, num_batch, num_splits; int const qhead_per_khead; int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr - int const seqlen_k, headdim, element_size; // Used to calculate L2 swizzling + int const seqlen_k, headdim, headdim_v, element_size; // Used to calculate L2 swizzling int* const tile_count_semaphore = nullptr; - int* const cu_seqlens = nullptr; - int* const seqused = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + // int const* const num_m_blocks_ptr = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; }; /////////////////////////////////////////////////////////////////////////////// @@ -40,16 +43,20 @@ class SingleTileScheduler { int const qhead_per_khead; int const seqlen; cutlass::FastDivmod nsplits_divmod; - int* const cu_seqlens; - int* const seqused; + int const* const cu_seqlens; + int const* const seqused; + int const* const num_splits_dynamic_ptr = nullptr; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { + assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr); + assert(!Split || !Varlen || args.num_splits < (1 << 16)); // We use the top 16 bits to store num_splits return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits, args.qhead_per_khead, args.seqlen, cutlass::FastDivmod(!Split ? 1 : args.num_splits), - !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused}; + !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused, + args.num_splits_dynamic_ptr}; } static dim3 @@ -61,24 +68,18 @@ class SingleTileScheduler { int block_idx = 0; int bidh = 0; int bidb = 0; - bool is_valid_tile = false; + int split_idx = 0; CUTLASS_DEVICE bool is_valid(Params const& params) const { - return is_valid_tile; + return bidb >= 0; } CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { - if constexpr (!Split) { - return {block_idx, bidh, bidb, 0 /*split_idx*/}; - } else { - int split_idx; - int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); - return {block_idx, bidh_actual, bidb, split_idx}; - } + return {block_idx, bidh, bidb, !Split ? 0 : split_idx}; } }; @@ -90,14 +91,27 @@ class SingleTileScheduler { CUTLASS_DEVICE WorkTileInfo get_initial_work(Params const& params) const { - WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true}; + WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), 0}; + if constexpr (Split) { + int split_idx; + work_info.bidh = params.nsplits_divmod.divmod(split_idx, work_info.bidh); + work_info.split_idx = split_idx; + } + bool is_valid_tile = true; if constexpr (Varlen) { int seqlen = params.seqused ? params.seqused[work_info.bidb] : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen); if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } - work_info.is_valid_tile = work_info.block_idx * kBlock < seqlen; + is_valid_tile = work_info.block_idx * kBlock < seqlen; } + if constexpr (Varlen && Split) { + int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits; + // Use the top 16 bits to store num_splits + work_info.split_idx |= (num_splits_dynamic << 16); + is_valid_tile &= work_info.split_idx < num_splits_dynamic; + } + work_info.bidb = is_valid_tile ? work_info.bidb : -1; return work_info; } @@ -113,7 +127,7 @@ class SingleTileScheduler { CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { - return {-1, -1, -1, false}; + return {0, 0, -1, 0}; } }; @@ -232,7 +246,7 @@ class DynamicPersistentTileScheduler { static Params to_underlying_arguments(TileSchedulerArguments const& args) { - int const size_one_kv_head = args.seqlen_k * args.headdim * args.element_size * 2; + int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size * 2; int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V // Swizzle is the size of each "section". Round swizzle to a power of 2 // If not PackGQA already, the size of each section can increase by qhead_per_khead @@ -341,7 +355,6 @@ class DynamicPersistentTileScheduler { }; - template class VarlenDynamicPersistentTileScheduler { @@ -361,10 +374,13 @@ class VarlenDynamicPersistentTileScheduler { int num_head, num_batch; int const qhead_per_khead; int const seqlen; + cutlass::FastDivmod head_divmod; cutlass::FastDivmod nsplits_divmod; int* const tile_count_semaphore; - int* const cu_seqlens; - int* const seqused; + int const* const cu_seqlens; + int const* const seqused; + // int* const num_m_blocks_ptr; + int const* const num_splits_dynamic_ptr; }; static Params @@ -372,10 +388,15 @@ class VarlenDynamicPersistentTileScheduler { // If Split, for the purpose of scheduling, we pretend that instead there are // (args.num_splits * args.num_head) number of heads. assert(args.tile_count_semaphore != nullptr); - return {args.num_head * (!Split ? 1 : args.num_splits), args.num_batch, + assert(num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx + assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, + cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(!Split ? 1 : args.num_splits), - args.tile_count_semaphore, args.cu_seqlens, args.seqused}; + args.tile_count_semaphore, args.cu_seqlens, args.seqused, + // args.num_m_blocks_ptr, + args.num_splits_dynamic_ptr}; } static dim3 @@ -399,8 +420,18 @@ class VarlenDynamicPersistentTileScheduler { if constexpr (!Split) { return {block, bidh, bidb, 0 /*split_idx*/}; } else { - int split_idx; - int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); + // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh); + uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF; + int bidh_actual = reinterpret_cast(bidh_actual_u); + // Use the top 16 bits of split_idx to store num_splits and the next 16 bits to store split_idx + uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8); + int split_idx = reinterpret_cast(split_idx_u); + // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); + // if (threadIdx.x == 128) { + // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); + // } return {block, bidh_actual, bidb, split_idx}; } } @@ -412,43 +443,55 @@ class VarlenDynamicPersistentTileScheduler { CUTLASS_DEVICE WorkTileInfo tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const { - auto prefix_sum = [](int val) { - int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) { - int32_t partial_sum = __shfl_up_sync(0xffffffff, val, i); - if (lane >= i) { val += partial_sum; } + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + auto get_num_m_blocks = [&] (int bidb_start) { + int batch_idx = lane + bidb_start; + int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); + if (seqlen > kBlock) { + if (params.seqused) { + seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; + } else if (params.cu_seqlens) { + int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = params.seqlen; + } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } } - return val; + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? cute::ceil_div(seqlen, kBlock) : 0; + // ? params.num_m_blocks_ptr[batch_idx] : 0; }; - auto get_num_m_blocks = [&](int bidb_start) { - int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - int seqlen; - if (params.seqused) { - seqlen = lane + bidb_start < params.num_batch ? params.seqused[lane + bidb_start] : 0; - } else if (params.cu_seqlens) { - int cur_cu_seqlen = lane + bidb_start <= params.num_batch ? params.cu_seqlens[lane + bidb_start] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; - } else { - seqlen = params.seqlen; - } - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } - return lane + bidb_start < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; + auto get_num_splits = [&] (int bidb_start) { + int batch_idx = lane + bidb_start; + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? (!Split ? 1 : (params.num_splits_dynamic_ptr + ? params.num_splits_dynamic_ptr[batch_idx] + : params.nsplits_divmod.divisor)) + : 0; }; int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane + int num_splits = get_num_splits(current_work.bidb); + int num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; // Cumulative number of blocks for the next 31 batches - int num_m_blocks_cumulative = prefix_sum(num_m_blocks); + int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); // Total number of blocks for the next 31 batches int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); - int group_end_tile = current_work.tile_idx - current_work.block - current_work.bidh * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + // Only the lower 16 bits are the actual bidh + int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); + int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + if constexpr (Split) { + int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; + group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); + } int bidb = current_work.bidb; // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); + // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); // } + // if (threadIdx.x == 0 && blockIdx.x == 0) { printf("tile_idx = %d, group_end_tile = %d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\n", current_work.tile_idx, group_end_tile, num_m_blocks_cumulative, m_blocks_in_group); } while (group_end_tile <= next_tile_idx) { bidb += cutlass::NumThreadsPerWarp - 1; if (bidb >= params.num_batch) { @@ -458,7 +501,9 @@ class VarlenDynamicPersistentTileScheduler { return {next_tile_idx, 0, 0, params.num_batch}; } num_m_blocks = get_num_m_blocks(bidb); - num_m_blocks_cumulative = prefix_sum(num_m_blocks); + num_splits = get_num_splits(bidb); + num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; + num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); group_end_tile += m_blocks_in_group * params.num_head; // if (blockIdx.x <= 9 && threadIdx.x == 0) { @@ -469,13 +514,29 @@ class VarlenDynamicPersistentTileScheduler { // The next problem to process is the first one that does not have ending tile position // that is greater than or equal to tile index. int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx)); + // if (threadIdx.x == 31 || threadIdx.x == 0) { printf("blockIdx.x = %d, tidx %d, group_start_tile = %d, num_m_blocks_cumulative = %d, num_head = %d, next_tile_idx = %d, ballot = %x, batch_idx_in_group = %d\n", blockIdx.x, threadIdx.x, group_start_tile, num_m_blocks_cumulative, params.num_head, next_tile_idx, tmp, batch_idx_in_group); } bidb += batch_idx_in_group; num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); + if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; int bidh = mh_block / num_m_blocks; int block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + // TODO: idk why this gives wrong answer nondeterministically + // int bidh_actual, split_idx; + // split_idx = params.head_divmod.divmod(bidh_actual, bidh); + // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // if (threadIdx.x == 0) { + // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + // } + bidh = reinterpret_cast(bidh_packed); + } // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("blockIdx.x = %d, threadIdx.x = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); + // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); // } return {next_tile_idx, block, bidh, bidb}; } diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 5d0bd6e2634..2c440c6e210 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -9,25 +9,26 @@ // Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} constexpr std::tuple tile_size_fwd_sm90( int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, - bool v_colmajor=false, bool paged_kv=false, bool softcap=false) { + bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 - // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, true}; + // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim}; // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 // Switch to tile size 192 x 192 for now - return {same_hdim ? 192 : 64, same_hdim ? 192 : 64, false, true}; + bool const use_blockN_128 = is_causal || is_local; + return {same_hdim ? 192 : 64, same_hdim ? (use_blockN_128 ? 128 : 192) : 64, same_hdim && use_blockN_128, same_hdim}; // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { - return {192, is_local || paged_kv ? 128 : 144, false, true}; + return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; } else if (headdim <= 128) { - return {128, is_causal || is_local || paged_kv ? 128 : 176, true, true}; + return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; // {128, 192, false, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { - return {128, paged_kv || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem + return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem } else { return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem } @@ -37,11 +38,11 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 96) { return {192, 128, true, true}; } else if (headdim <= 128) { - return {128, paged_kv ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; + return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; } else if (headdim <= 192) { - return {128, (paged_kv || softcap) && is_local ? 128 : 160, true, true}; + return {128, (paged_kv_non_TMA || softcap) && is_local ? 128 : 160, true, true}; } else { - return {128, is_local ? 64 : 128, true, !paged_kv}; // PagedKV uses more registers so we disabled IntraWGOverlap + return {128, is_local ? 64 : 128, true, !paged_kv_non_TMA}; // PagedKV uses more registers so we disabled IntraWGOverlap } } } diff --git a/hopper/utils.h b/hopper/utils.h index e14ca157439..3f76ea66e97 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -21,17 +21,7 @@ #include #include - -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while(0) - -#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) +#include "cuda_check.h" namespace flash { @@ -272,9 +262,11 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const if constexpr (zero_init) { tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } + static constexpr int kNumKIters = CUTE_STATIC_V(size<2>(tCrA)); + static constexpr int kMaxKIters = 16; // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + for (int k_block = 0; k_block < std::min(kNumKIters, kMaxKIters); ++k_block) { if constexpr (!SwapAB) { cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); } else { @@ -282,6 +274,22 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } + // In the case of large kNumKIters, the compiler chooses to store the smem addresses + // in registers, causing spills. This loop forces the compiler to recompute the addresses. + if constexpr (kNumKIters > kMaxKIters) { + // This will always be zero, just a way to force the compiler to recompute the smem + // addresses. This results in USEL instructions. There's probably a better way to do this. + int const k_offset = cutlass::canonical_warp_group_idx() < 128 ? 0 : 1; + CUTLASS_PRAGMA_UNROLL + for (int k_block = kMaxKIters; k_block < kNumKIters; ++k_block) { + if constexpr (!SwapAB) { + cute::gemm(tiled_mma, tCrA(_,_,k_block + k_offset), tCrB(_,_,k_block + k_offset), tCrC); + } else { + cute::gemm(tiled_mma, tCrB(_,_,k_block + k_offset), tCrA(_,_,k_block + k_offset), tCrC); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } warpgroup_commit_batch(); if constexpr (wg_wait >= 0) { warpgroup_wait(); } warpgroup_fence_operand(tCrC); @@ -625,6 +633,19 @@ CUTLASS_DEVICE auto calculate_dtanh(Tensor &tensor){ //////////////////////////////////////////////////////////////////////////////////////////////////// +template +CUTE_DEVICE T warp_prefix_sum(T val) { + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) { + T partial_sum = __shfl_up_sync(0xffffffff, val, i); + if (lane >= i) { val += partial_sum; } + } + return val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template CUTE_DEVICE T warp_uniform(T a) { return __shfl_sync(0xffffffff, a, 0); diff --git a/tests/test_vllm_flash_attn.py b/tests/test_vllm_flash_attn.py index a49ce478294..ddc2ea03832 100644 --- a/tests/test_vllm_flash_attn.py +++ b/tests/test_vllm_flash_attn.py @@ -13,7 +13,8 @@ from vllm_flash_attn.flash_attn_interface import ( flash_attn_varlen_func, flash_attn_with_kvcache, - is_fa_version_supported + get_scheduler_metadata, + is_fa_version_supported, ) NUM_HEADS = [(4, 4), (8, 2), (16, 2)] @@ -185,6 +186,7 @@ def ref_paged_attn( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("aot_schedule", [True, False]) @pytest.mark.parametrize("fa_version", VERSIONS) @torch.inference_mode() def test_flash_attn_with_paged_kv( @@ -195,6 +197,7 @@ def test_flash_attn_with_paged_kv( block_size: int, soft_cap: Optional[float], num_blocks: int, + aot_schedule: bool, fa_version: int, ) -> None: torch.set_default_device("cuda") @@ -221,6 +224,24 @@ def test_flash_attn_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) + scheduler_metadata = None + if aot_schedule: + if fa_version == 2: + pytest.skip("AOT schedule is not supported in version 2") + scheduler_metadata = get_scheduler_metadata( + batch_size=num_seqs, + max_seqlen_q=1, + max_seqlen_k=max_kv_len, + num_heads_q=num_query_heads, + num_heads_kv=num_kv_heads, + headdim=head_size, + cache_seqlens=kv_lens_tensor, + qkv_dtype=dtype, + causal=True, + window_size=(-1, -1), + has_softcap=soft_cap is not None + ) + output = flash_attn_with_kvcache( query.unsqueeze(1), key_cache, @@ -230,6 +251,7 @@ def test_flash_attn_with_paged_kv( block_table=block_tables, cache_seqlens=kv_lens_tensor, softcap=soft_cap if soft_cap is not None else 0, + scheduler_metadata=scheduler_metadata, fa_version=fa_version, ).squeeze(1) @@ -255,6 +277,7 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("aot_schedule", [True, False]) @pytest.mark.parametrize("fa_version", VERSIONS) @torch.inference_mode() def test_varlen_with_paged_kv( @@ -266,6 +289,7 @@ def test_varlen_with_paged_kv( block_size: int, soft_cap: Optional[float], num_blocks: int, + aot_schedule: bool, fa_version: int, ) -> None: torch.set_default_device("cuda") @@ -303,6 +327,25 @@ def test_varlen_with_paged_kv( num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) + + scheduler_metadata = None + if aot_schedule: + if fa_version == 2: + pytest.skip("AOT schedule is not supported in version 2") + scheduler_metadata = get_scheduler_metadata( + batch_size=num_seqs, + max_seqlen_q=1, + max_seqlen_k=max_kv_len, + num_heads_q=num_query_heads, + num_heads_kv=num_kv_heads, + headdim=head_size, + cache_seqlens=seqused_k, + qkv_dtype=dtype, + causal=True, + window_size=(-1, -1), + has_softcap=soft_cap is not None + ) + output = flash_attn_varlen_func( q=query, k=key_cache, @@ -316,6 +359,7 @@ def test_varlen_with_paged_kv( window_size=window_size, block_table=block_tables, softcap=soft_cap if soft_cap is not None else 0, + scheduler_metadata=scheduler_metadata, fa_version=fa_version ) diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py index 806598ca489..6c524f9ed3b 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -73,6 +73,48 @@ def fa_version_unsupported_reason(fa_version: int, device = None) \ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x +# NOTE only used in FA3 +def get_scheduler_metadata( + batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication +): + cache_seqlens = maybe_contiguous(cache_seqlens) + if headdim_v is None: + headdim_v = headdim + scheduler_metadata = torch.ops._vllm_fa3_C.get_scheduler_metadata( + batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, + qkv_dtype, + cache_seqlens, + cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + cache_leftpad, + page_size, + max_seqlen_k_new, + causal, + window_size[0], window_size[1], + has_softcap, + num_splits, + pack_gqa, + sm_margin, + ) + + return scheduler_metadata + def flash_attn_varlen_func( q, @@ -95,10 +137,13 @@ def flash_attn_varlen_func( block_table=None, return_softmax_lse=False, out=None, - fa_version: int = DEFAULT_FA_VERSION, + # FA3 Only + scheduler_metadata=None, q_descale=None, k_descale=None, v_descale=None, + # Version selector + fa_version: int = DEFAULT_FA_VERSION, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads @@ -173,6 +218,12 @@ def flash_attn_varlen_func( dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) if fa_version == 2: + if scheduler_metadata is not None and q_descale is not None \ + and k_descale is not None and v_descale is not None: + raise NotImplementedError( + "FA2 does not support scheduler_metadata, q_descale, " + "k_descale, v_descale" + ) out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd( q, k, v, out, @@ -216,9 +267,9 @@ def flash_attn_varlen_func( softmax_scale, causal, real_window_size[0], real_window_size[1], - 0, # sink_token_length softcap, True, # rotary_interleaved + scheduler_metadata, 0, # num_splits None, # pack_gqa 0, # sm_margin @@ -250,10 +301,13 @@ def flash_attn_with_kvcache( return_softmax_lse=False, *, out=None, - fa_version: int = DEFAULT_FA_VERSION, + # FA3 Only + scheduler_metadata=None, q_descale=None, k_descale=None, v_descale=None, + # Version selector + fa_version: int = DEFAULT_FA_VERSION, ): """ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from @@ -355,6 +409,12 @@ def flash_attn_with_kvcache( block_table = maybe_contiguous(block_table) if fa_version == 2: + if scheduler_metadata is not None and q_descale is not None \ + and k_descale is not None and v_descale is not None: + raise NotImplementedError( + "FA2 does not support scheduler_metadata, q_descale, " + "k_descale, v_descale" + ) out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache( q, k_cache, v_cache, k, v, # k_new, v_new @@ -393,9 +453,9 @@ def flash_attn_with_kvcache( softmax_scale, causal, window_size[0], window_size[1], - 0, # sink_token_length softcap, rotary_interleaved, # rotary_interleaved + scheduler_metadata, num_splits, # num_splits None, # pack_gqa 0, # sm_margin