From 630d80bc02554c393b19ac9a979c035ceb1fc470 Mon Sep 17 00:00:00 2001
From: Jay Shah
Date: Thu, 7 Aug 2025 08:32:49 -0700
Subject: [PATCH 01/20] use LPT order in varlen kernel

---
 hopper/tile_scheduler.hpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp
index 1f90f66adc2..f1471493a90 100644
--- a/hopper/tile_scheduler.hpp
+++ b/hopper/tile_scheduler.hpp
@@ -629,6 +629,9 @@ class VarlenDynamicPersistentTileScheduler {
         int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head;
         int bidh = mh_block / num_m_blocks;
         int block = mh_block - bidh * num_m_blocks;
+        // Longest-processing-time-first
+        next_tile_idx += num_m_blocks - 1 - 2 * block;
+        block = num_m_blocks - 1 - block;
         if constexpr (Split) {
             int bidh_actual = bidh / num_splits;
             int split_idx = bidh - bidh_actual * num_splits;

From e2d900a3f7bab2e262dda5c715696c6a73a7dfba Mon Sep 17 00:00:00 2001
From: Jay Shah
Date: Thu, 7 Aug 2025 08:39:16 -0700
Subject: [PATCH 02/20] add prefill decode benchmark script

---
 hopper/benchmark_prefill_decode.py | 217 +++++++++++++++++++++++++++++
 1 file changed, 217 insertions(+)
 create mode 100644 hopper/benchmark_prefill_decode.py

diff --git a/hopper/benchmark_prefill_decode.py b/hopper/benchmark_prefill_decode.py
new file mode 100644
index 00000000000..55b74c1a60e
--- /dev/null
+++ b/hopper/benchmark_prefill_decode.py
@@ -0,0 +1,217 @@
+from collections import namedtuple
+from functools import partial
+import math
+import os
+from typing import NamedTuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import time
+
+Timing = NamedTuple('timing', [('mean', float)])
+
+from einops import rearrange, repeat
+
+# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
+from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
+from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+# from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3
+from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
+
+from triton.testing import do_bench
+
+cudnn = None
+triton_attention = None
+
+DISABLE_BACKWARD = True
+
+def time_fwd(func, *args, repeats=10, verbose=True, desc="", **kwargs):
+    # Warmup
+    # for _ in range(5):
+    #     func(*args, **kwargs)
+    # time.sleep(1)
+    # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1]
+    # s = torch.cuda.Stream()
+    # s.wait_stream(torch.cuda.current_stream())
+    # with torch.cuda.stream(s):
+    #     for _ in range(2):
+    #         out = func(*args, **kwargs)
+    # torch.cuda.current_stream().wait_stream(s)
+    # graph = torch.cuda.CUDAGraph()
+    # with torch.cuda.graph(graph):
+    #     out = func(*args, **kwargs)
+    # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc)
+    # # return time_f[1].mean
+    # return time_f[1]
+    return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3)
+
+def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)):
+    if causal:
+        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
+    else:
+        if window_size
== (-1, -1): + avg_seqlen = seqlen_k + else: + row_idx = torch.arange(seqlen_q, device='cuda') + col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) + avg_seqlen = (col_right - col_left + 1).float().mean().item() + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + +torch.manual_seed(0) +repeats = 10 +dropout_p = 0.0 +causal = False +dtype = torch.bfloat16 +# dtype = torch.float8_e4m3fn +dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype +device = 'cuda' +verbose = True +varlen = True +page_size = None +softcap = 0.0 +V_colmajor = False +deterministic = False + +# decode_batches = 128 +# prefill_batches = 1 +# batch_size = decode_batches + prefill_batches +# decode_seqlen_k = 8192 if decode_batches > 0 else 0 +# prefill_seqlen_k = 2048 if prefill_batches > 0 else 0 +# # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative +# seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) +# decode_seqlen_q = 1 if decode_batches > 0 else 0 +# prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 +# seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches +# max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) + +time_f = {} +time_b = {} + +prefill_first_vals = [False, True] +# prefill_first_vals = [False] + +for headdim in [128]: + for prefill_batches in [0, 1]: + # for decode_batches in range(32, 128 + 32, 8): + for decode_batches in [128]: + batch_size = decode_batches + prefill_batches + decode_seqlen_k = 8192 if decode_batches > 0 else 0 + prefill_seqlen_k = 1024 if prefill_batches > 0 else 0 + # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative + seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) + decode_seqlen_q = 1 if decode_batches > 0 else 0 + prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 + seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches + max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) + + tp_degree=1 + nheads = 64//tp_degree + nheads_kv = 8//tp_degree + # nheads = 1 + # nheads_kv = 1 + headdim_v = headdim + has_qv = False + + # window_size = (128, 0) + window_size = (-1, -1) + # window_size = (seqlen // 2 - 1, 0) + + # print("Window size: ", window_size) + # print(f"Num query heads = {nheads}, kv heads = {nheads_kv}") + # print("Head dim: ", headdim) + # print("Batch size: ", batch_size) + # print("Prefill seqlen k: ", prefill_seqlen_k) + # print("Decode seqlen k: ", decode_seqlen_k) + # print("Seqlen k (max): ", seqlen_k) + # print("Prefill seqlen q: ", prefill_seqlen_q) + # print("Decode seqlen q: ", decode_seqlen_q) + # print("Seqlen q (total): ", seqlen_q) + + num_splits = 1 + pack_gqa = None + + # print(f"Num splits = {num_splits}, Pack GQA = {pack_gqa}") + + if prefill_batches == 0: + this_prefill_first_vals = [False] + else: + this_prefill_first_vals = prefill_first_vals + + for prefill_first in this_prefill_first_vals: + leftpad_k = None + # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) + q = torch.randn(seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype_gen, 
requires_grad=True) + q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] + v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() + v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None + + seqlen_q_decode_offset = decode_seqlen_q * decode_batches + seqlen_q_prefill_offset = prefill_seqlen_q * prefill_batches + seqlen_k_decode_offset = decode_seqlen_k * decode_batches + + if prefill_first: + cu_seqlens_q_prefill = torch.arange(prefill_batches, device=device, dtype=torch.int32) * prefill_seqlen_q + cu_seqlens_q_decode = torch.arange(decode_batches + 1, device=device, dtype=torch.int32) * decode_seqlen_q + seqlen_q_prefill_offset + cu_seqlens_q = torch.cat((cu_seqlens_q_prefill, cu_seqlens_q_decode), dim=0) + else: + cu_seqlens_q_decode = torch.arange(decode_batches, device=device, dtype=torch.int32) * decode_seqlen_q + cu_seqlens_q_prefill = torch.arange(prefill_batches + 1, device=device, dtype=torch.int32) * prefill_seqlen_q + seqlen_q_decode_offset + cu_seqlens_q = torch.cat((cu_seqlens_q_decode, cu_seqlens_q_prefill), dim=0) + + cache_seqlens_decode = torch.ones(decode_batches, dtype=torch.int32, device=device) * decode_seqlen_k + cache_seqlens_prefill = torch.ones(prefill_batches, dtype=torch.int32, device=device) * prefill_seqlen_k + + if prefill_first: + cache_seqlens = torch.cat((cache_seqlens_prefill, cache_seqlens_decode), dim=0) + else: + cache_seqlens = torch.cat((cache_seqlens_decode, cache_seqlens_prefill), dim=0) + + + # print("q: ", q.shape) + # print("k: ", k.shape) + # print("v: ", v.shape) + # print("cu seqlens q: ", cu_seqlens_q.shape) + # print("cache seqlens: ", cache_seqlens.shape) + # print("cu seqlens q vals: ", cu_seqlens_q) + # print("cache seqlens vals: ", cache_seqlens) + + page_table = None + + # for causal in [False, True]: + for causal in [True]: + print(f"\n### {headdim = }, {nheads = }, {nheads_kv = }, {causal = }, {prefill_seqlen_k = }, {decode_seqlen_k = }, {num_splits = }, {prefill_first = }, {decode_batches = }, {prefill_batches = } ###") + # nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen_k, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) + decode_nFLOPS = flops(decode_batches, nheads, decode_seqlen_q, decode_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) + prefill_nFLOPS = flops(prefill_batches, nheads, prefill_seqlen_q, prefill_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) + nFLOPS = decode_nFLOPS + prefill_nFLOPS + + bytes_kv = (decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches) * (nheads_kv * headdim * 4) + bytes_qo = (decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches) * (nheads * headdim * 4) # don't count split partials + bytes = bytes_kv + bytes_qo + print(f'{nFLOPS * 1e-9:.1f} GFLOPs, {bytes * 1e-9: .2f} GB, {nFLOPS/bytes:.1f} AI') + + time.sleep(1) + m1 = time_fwd(flash_attn_func_v3, + q, + k, + v, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + causal=causal, + window_size=window_size, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + repeats=repeats, verbose=verbose, desc='Fav3') + + time_f[(causal, headdim, batch_size, seqlen_k), "Flash3"] = m1.mean + + print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS, {bytes/ m1.mean * 1e-9: .2f} GB/s') + From 
f9b85d611f4659188d7b394fbe2d436828187737 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Tue, 12 Aug 2025 13:27:30 -0700 Subject: [PATCH 03/20] add sort in prepare --- hopper/flash.h | 5 +- hopper/flash_api.cpp | 18 +++-- hopper/flash_attn_interface.py | 2 + hopper/flash_fwd_launch_template.h | 2 +- hopper/flash_prepare_scheduler.cu | 116 ++++++++++++++++++++++++----- hopper/setup.py | 18 +++++ hopper/tile_scheduler.hpp | 7 +- 7 files changed, 139 insertions(+), 29 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index bee89e5f054..b78611930bd 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -152,9 +152,10 @@ struct Flash_fwd_params : public Qkv_params { bool pack_gqa; int * __restrict__ tile_count_semaphore; - // int * __restrict__ num_m_blocks_ptr; + int * __restrict__ num_m_blocks_ptr; // int * __restrict__ num_n_blocks_ptr; int * __restrict__ num_splits_dynamic_ptr; + int * __restrict__ batch_idx_ptr; // virtual -> actual bool skip_scheduler_metadata_computation; int arch; @@ -211,7 +212,7 @@ struct Flash_bwd_params : public Flash_fwd_params { template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl); +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl, bool sort_batches); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); template diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 33185bf2304..929e03c5e97 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -525,8 +525,8 @@ mha_fwd_get_scheduler_metadata( bool has_softcap, int64_t num_splits, std::optional pack_gqa_, - int64_t sm_margin - ) { + bool sort_batches, + int64_t sm_margin) { TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); @@ -604,7 +604,10 @@ mha_fwd_get_scheduler_metadata( at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; if (scheduler_needs_semaphore || use_dynamic_split) { - tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32)); + int num_metadata_batch_vectors = sort_batches ? 3 : 1; + tile_count_semaphore = torch::empty( + {int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b * num_metadata_batch_vectors}, + opts.dtype(torch::kInt32)); if (scheduler_needs_semaphore) { if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing params.tile_count_semaphore = tile_count_semaphore.data_ptr(); @@ -612,6 +615,8 @@ mha_fwd_get_scheduler_metadata( params.tile_count_semaphore = nullptr; } params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; + params.num_m_blocks_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + 1 + params.b : nullptr; + params.batch_idx_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + 1 + params.b * 2 : nullptr; } if (params.num_splits_dynamic_ptr) { @@ -620,7 +625,7 @@ mha_fwd_get_scheduler_metadata( int const kBlockM = params.arch >= 90 ? 
std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); auto stream = at::cuda::getCurrentCUDAStream().stream(); - prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); + prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/, sort_batches); CHECK_CUDA_KERNEL_LAUNCH(); } return tile_count_semaphore; @@ -956,7 +961,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); if (scheduler_needs_semaphore || use_dynamic_split) { - int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b; + int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b * 3; params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); if (scheduler_metadata_.has_value()) { at::Tensor scheduler_metadata = scheduler_metadata_.value(); @@ -973,6 +978,8 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() : nullptr; params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; + params.num_m_blocks_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 + params.b : nullptr; + params.batch_idx_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 + params.b * 2 : nullptr; } if (q_v_.has_value()) { @@ -1705,6 +1712,7 @@ TORCH_LIBRARY(flash_attn_3, m) { "bool has_softcap = False," "int num_splits = 0," "bool? pack_gqa = None," + "bool sort_batches = False," "int sm_margin = 0) -> Tensor"); } diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 5547f426da5..5ae04a05789 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -808,6 +808,7 @@ def get_scheduler_metadata( num_splits=0, # Can be tuned for speed pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication + sort_batches=False, ): cache_seqlens = maybe_contiguous(cache_seqlens) if headdim_v is None: @@ -829,6 +830,7 @@ def get_scheduler_metadata( has_softcap, num_splits, pack_gqa, + sort_batches, sm_margin, ) return scheduler_metadata diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index b8af2977f11..e7d7b14cc07 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -156,7 +156,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }; if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { - prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/, false /*sort_batches*/); CHECK_CUDA_KERNEL_LAUNCH(); } diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 7093fff32b6..8a471e58859 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -2,6 +2,7 @@ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
******************************************************************************/ +#include #include "cutlass/fast_math.h" #include "cutlass/barrier.h" #include "cutlass/arch/barrier.h" @@ -10,8 +11,42 @@ #include "flash.h" +#include "static_switch.h" + namespace flash { +// needs (tensor size = batches): +// 1. batch_idx_ptr: virtual_batch_idx -> batch_idx +// 2. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] +// 3. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] + +// Custom comparison functor for descending order +template +struct CustomMore +{ + __device__ bool operator()(const DataType &lhs, const DataType &rhs) + { + return lhs > rhs; + } +}; + +// Specialization for int2 +template <> +struct CustomMore { + __device__ bool operator()(const int2& lhs, const int2& rhs) const { + return lhs.x > rhs.x; + } +}; + +// Specialization for int4 +template <> +struct CustomMore { + __device__ bool operator()(const int4& lhs, const int4& rhs) const { + return lhs.x > rhs.x; + } +}; + +template __global__ void prepare_varlen_num_blocks_kernel( int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, @@ -19,15 +54,24 @@ __global__ void prepare_varlen_num_blocks_kernel( int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, int* const tile_count_semaphore, - // int* const num_m_blocks_ptr, + int* const num_m_blocks_ptr, // virtual_batch_idx -> num_m_blocks[batch_idx] int* const num_splits_dynamic_ptr, + int* const batch_idx_ptr, bool enable_pdl) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; static constexpr int kSmemSize = 1; + static constexpr int BLOCK_THREADS = 1024; + static constexpr int ITEMS_PER_THREAD = 1; + static_assert(BLOCK_THREADS * ITEMS_PER_THREAD == 1024); + using BlockMergeSort = cub::BlockMergeSort; + // Assume that there's only one block in the grid __shared__ int total_blocks_smem[kSmemSize]; + // Allocate shared memory for BlockMergeSort operations + __shared__ typename BlockMergeSort::TempStorage temp_storage; + // There's only 1 block in the grid, so might as well start launching the main attn kernel if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } @@ -38,8 +82,7 @@ __global__ void prepare_varlen_num_blocks_kernel( int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - auto get_num_m_blocks = [&](int bidb_start) { - int batch_idx = lane + bidb_start; + auto get_num_m_blocks = [&](int batch_idx) { int seqlen; if (seqused_q) { seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; @@ -55,8 +98,7 @@ __global__ void prepare_varlen_num_blocks_kernel( ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; }; - auto get_num_n_blocks = [&](int bidb_start) { - int batch_idx = lane + bidb_start; + auto get_num_n_blocks = [&](int batch_idx) { int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? 
leftpad_k_ptr[batch_idx] : 0; int seqlen; if (seqused_k) { @@ -84,8 +126,9 @@ __global__ void prepare_varlen_num_blocks_kernel( int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; int bidb_start = kNumBatchPerWarp * warp_idx; - int num_m_blocks = get_num_m_blocks(bidb_start); - int num_n_blocks = get_num_n_blocks(bidb_start); + int batch_idx = lane + bidb_start; + int num_m_blocks = get_num_m_blocks(batch_idx); + int num_n_blocks = get_num_n_blocks(batch_idx); int total_blocks = num_m_blocks * num_n_blocks; // Warp sum @@ -100,25 +143,58 @@ __global__ void prepare_varlen_num_blocks_kernel( int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); - if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { - num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; - // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + + if constexpr(Sort) { + num_n_blocks = cute::ceil_div(num_n_blocks, num_splits_dynamic); // num_n_blocks per batch accounting for splits + + // At this point, I have total blocks in shared memory, and each thread has its num_m_blocks, num_n_blocks, and num_splits + // thread batch_idx = bidb_start + lane, lane 31 is a dummy. + + // Goal: sort batches by num_n_blocks + if(lane == kNumBatchPerWarp) { + num_n_blocks = -1; // sort last + } + int4 thread_keys[ITEMS_PER_THREAD]; // 1 item per thread for now + thread_keys[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); + + BlockMergeSort(temp_storage).Sort(thread_keys, CustomMore()); + + if (threadIdx.x < num_batch) { + num_m_blocks_ptr[threadIdx.x] = thread_keys[0].y; + num_splits_dynamic_ptr[threadIdx.x] = thread_keys[0].z; + batch_idx_ptr[threadIdx.x] = thread_keys[0].w; + } + } else { + if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { + num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + } } } } // flash void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, - int blockM, int blockN, bool enable_pdl) { + int blockM, int blockN, bool enable_pdl, bool sort_batches) { // Only support batch <= 992 (32 warps, each with 31 batches) int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); - flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( - params.seqlen_q, params.seqlen_k, params.seqlen_knew, - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.b, !packgqa ? 
params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, - cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), - params.tile_count_semaphore, - // params.num_m_blocks_ptr, - params.num_splits_dynamic_ptr, enable_pdl); + printf("sort_batches = %d.\n", sort_batches); + BOOL_SWITCH(sort_batches, Sort, [&] { + if constexpr(Sort) { + printf("Sorting it!.\n"); + } else { + printf("No sorting.\n"); + } + flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( + params.seqlen_q, params.seqlen_k, params.seqlen_knew, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, + cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), + params.tile_count_semaphore, + params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, + params.batch_idx_ptr, + enable_pdl); + }); } diff --git a/hopper/setup.py b/hopper/setup.py index c15c438f56c..53fc9db103d 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -62,6 +62,24 @@ DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" DISABLE_SM8x = os.getenv("FLASH_ATTENTION_DISABLE_SM80", "FALSE") == "TRUE" +DISABLE_BACKWARD = True +# DISABLE_SPLIT = True +DISABLE_PAGEDKV = True +DISABLE_APPENDKV = True +DISABLE_LOCAL = True +DISABLE_SOFTCAP = True +# DISABLE_PACKGQA = True +DISABLE_FP16 = True +DISABLE_FP8 = True +# DISABLE_VARLEN = True +DISABLE_CLUSTER = True +DISABLE_HDIM64 = True +DISABLE_HDIM96 = True +# DISABLE_HDIM128 = True +DISABLE_HDIM192 = True +DISABLE_HDIM256 = True +DISABLE_SM8x = True + ENABLE_VCOLMAJOR = os.getenv("FLASH_ATTENTION_ENABLE_VCOLMAJOR", "FALSE") == "TRUE" diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index f1471493a90..f65699f28fa 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -513,6 +513,11 @@ class VarlenDynamicPersistentTileScheduler { } struct WorkTileInfo { + // for LPT scheduling: + // 1) tile_idx is offset by (reverse_block - block) + // 2) block <- reverse_block = num_m_blocks[bidb] - 1 - block + // NOTE: we only use tile_idx in tile_idx_to_work_tile to compute (tile_idx - block), + // so just need tile_idx' = tile_idx + block' - block for any block' replacing block int tile_idx, block, bidh, bidb; CUTLASS_DEVICE @@ -630,7 +635,7 @@ class VarlenDynamicPersistentTileScheduler { int bidh = mh_block / num_m_blocks; int block = mh_block - bidh * num_m_blocks; // Longest-processing-time-first - next_tile_idx += num_m_blocks - 1 - 2 * block; + next_tile_idx += num_m_blocks - 1 - 2 * block; // add (reverse_block - block) to linear tile_idx block = num_m_blocks - 1 - block; if constexpr (Split) { int bidh_actual = bidh / num_splits; From 13422594b52121fba713d01a8073ecaf176feb7b Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Tue, 12 Aug 2025 20:23:48 -0700 Subject: [PATCH 04/20] add full implementation: --- hopper/flash.h | 4 +- hopper/flash_api.cpp | 49 +++-- hopper/flash_attn_interface.py | 17 +- hopper/flash_fwd_combine_kernel.h | 11 +- hopper/flash_fwd_combine_launch_template.h | 2 +- hopper/flash_fwd_launch_template.h | 5 +- hopper/flash_prepare_scheduler.cu | 82 ++++---- hopper/prepare_varlen_bench.py | 234 +++++++++++++++++++++ hopper/setup.py | 8 +- hopper/static_switch.h | 23 ++ hopper/test_flash_attn.py | 84 +++++--- hopper/tile_scheduler.hpp | 49 +++-- 12 files changed, 455 insertions(+), 113 deletions(-) create mode 100644 
hopper/prepare_varlen_bench.py diff --git a/hopper/flash.h b/hopper/flash.h index b78611930bd..b33567ef383 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -155,8 +155,10 @@ struct Flash_fwd_params : public Qkv_params { int * __restrict__ num_m_blocks_ptr; // int * __restrict__ num_n_blocks_ptr; int * __restrict__ num_splits_dynamic_ptr; - int * __restrict__ batch_idx_ptr; // virtual -> actual + int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual bool skip_scheduler_metadata_computation; + bool varlen_sort_batches; + int tile_count_semaphore_offset; int arch; int num_sm; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 929e03c5e97..00b893ae18c 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -39,6 +39,8 @@ PyObject* PyInit__C(void) #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define PREPARE_VARLEN_MAX_BATCHES 992 + void set_params_fprop(Flash_fwd_params ¶ms, // sizes const size_t b, @@ -585,7 +587,7 @@ mha_fwd_get_scheduler_metadata( params.page_size = page_size.has_value() ? page_size.value() : 1; params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); - bool const use_dynamic_split = params.b <= 992; + bool const use_dynamic_split = params.b <= PREPARE_VARLEN_MAX_BATCHES; params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); @@ -603,20 +605,27 @@ mha_fwd_get_scheduler_metadata( // This needs to be set after get_num_splits at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; - if (scheduler_needs_semaphore || use_dynamic_split) { - int num_metadata_batch_vectors = sort_batches ? 3 : 1; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + params.varlen_sort_batches = sort_batches; + if (scheduler_needs_semaphore || use_dynamic_split) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_dynamic_split ? 1 : 0; + if(sort_batches) { num_prepare_batch_vectors += 2; } + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); tile_count_semaphore = torch::empty( - {int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b * num_metadata_batch_vectors}, + {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, opts.dtype(torch::kInt32)); + // {num_splits_dynamic, num_m_blocks, virtual_batch_indices, tile_count_semaphore} + params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; if (scheduler_needs_semaphore) { if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + params.tile_count_semaphore = tile_count_semaphore.data_ptr() + tile_count_semaphore_offset; } else { params.tile_count_semaphore = nullptr; } - params.num_splits_dynamic_ptr = use_dynamic_split ? 
tile_count_semaphore.data_ptr() + 1 : nullptr; - params.num_m_blocks_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + 1 + params.b : nullptr; - params.batch_idx_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + 1 + params.b * 2 : nullptr; } if (params.num_splits_dynamic_ptr) { @@ -674,6 +683,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql std::optional scheduler_metadata_, // (b + 1) int64_t num_splits, std::optional pack_gqa_, + bool sort_batches, int64_t sm_margin ) { @@ -945,7 +955,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel - bool const use_dynamic_split = is_varlen && params.b <= 992; + bool const use_dynamic_split = is_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES; // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); @@ -960,8 +970,14 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql bool const scheduler_needs_semaphore = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); + params.varlen_sort_batches = sort_batches; if (scheduler_needs_semaphore || use_dynamic_split) { - int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b * 3; + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_dynamic_split ? 1 : 0; + if(sort_batches) { num_prepare_batch_vectors += 2; } + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; + // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); if (scheduler_metadata_.has_value()) { at::Tensor scheduler_metadata = scheduler_metadata_.value(); @@ -976,10 +992,12 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql if (scheduler_needs_semaphore && !use_dynamic_split) { tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing } - params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() : nullptr; - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; - params.num_m_blocks_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 + params.b : nullptr; - params.batch_idx_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 + params.b * 2 : nullptr; + // {num_splits_dynamic, num_m_blocks, virtual_batch_indices, tile_count_semaphore} + params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + params.tile_count_semaphore = scheduler_needs_semaphore ? 
tile_count_semaphore.data_ptr() + tile_count_semaphore_offset : nullptr; + params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later } if (q_v_.has_value()) { @@ -1141,7 +1159,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql run_mha_fwd_combine(params, stream, true /*enable_pdl*/); } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { // need to zero out the semaphore in this case - tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_(); + tile_count_semaphore.index({torch::indexing::Slice(params.tile_count_semaphore_offset, params.tile_count_semaphore_offset + 1)}).zero_(); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. @@ -1659,6 +1677,7 @@ TORCH_LIBRARY(flash_attn_3, m) { "Tensor? scheduler_metadata = None," "int num_splits = 0," "bool? pack_gqa = None," + "bool sort_batches = False," "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); m.def("bwd(" "Tensor dout," diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 5ae04a05789..628717dbed6 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -50,7 +50,9 @@ def _flash_attn_forward( scheduler_metadata=None, num_splits=1, pack_gqa=None, - sm_margin=0): + sm_margin=0, + varlen_sort_batches=False, + ): q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ @@ -96,6 +98,7 @@ def _flash_attn_forward( scheduler_metadata, num_splits, pack_gqa, + varlen_sort_batches, sm_margin, ) return out, softmax_lse, *rest @@ -362,6 +365,7 @@ def forward( pack_gqa=None, deterministic=False, sm_margin=0, + varlen_sort_batches=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) @@ -391,6 +395,7 @@ def forward( num_splits=num_splits, pack_gqa=pack_gqa, sm_margin=sm_margin, + varlen_sort_batches=varlen_sort_batches, ) # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -436,7 +441,7 @@ def backward(ctx, dout, *args): dq = dq[..., : q.shape[-1]] # We could have padded the head dimension dk = dk[..., : k.shape[-1]] dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( @@ -599,6 +604,7 @@ def flash_attn_varlen_func( pack_gqa=None, deterministic=False, sm_margin=0, + varlen_sort_batches=False, ): return FlashAttnVarlenFunc.apply( q, @@ -621,6 +627,7 @@ def flash_attn_varlen_func( pack_gqa, deterministic, sm_margin, + varlen_sort_batches, ) @@ -659,6 +666,7 @@ def flash_attn_with_kvcache( pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication return_softmax_lse=False, + varlen_sort_batches=False, ): """ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from @@ -786,6 +794,7 @@ def flash_attn_with_kvcache( num_splits=num_splits, pack_gqa=pack_gqa, 
sm_margin=sm_margin, + varlen_sort_batches=varlen_sort_batches, ) # return (out, softmax_lse) if return_softmax_lse else out return (out, softmax_lse, *rest) if return_softmax_lse else out @@ -808,7 +817,7 @@ def get_scheduler_metadata( num_splits=0, # Can be tuned for speed pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication - sort_batches=False, + varlen_sort_batches=False, ): cache_seqlens = maybe_contiguous(cache_seqlens) if headdim_v is None: @@ -830,7 +839,7 @@ def get_scheduler_metadata( has_softcap, num_splits, pack_gqa, - sort_batches, + varlen_sort_batches, sm_margin, ) return scheduler_metadata diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index a22e05969d9..81370c731c7 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -145,6 +145,7 @@ class FlashAttnFwdCombine { int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; int* const semaphore_to_reset = nullptr; }; @@ -164,6 +165,7 @@ class FlashAttnFwdCombine { int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; int* const semaphore_to_reset = nullptr; }; @@ -187,7 +189,9 @@ class FlashAttnFwdCombine { args.cu_seqlens, args.seqused, args.num_splits_dynamic_ptr, - args.semaphore_to_reset + args.varlen_batch_idx_ptr, + args.semaphore_to_reset, + }; } @@ -203,8 +207,9 @@ class FlashAttnFwdCombine { int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; int const k_block = blockIdx.y; - int const batch = blockIdx.z; - int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); + int const virtual_batch_idx = blockIdx.z; + int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[virtual_batch_idx] : virtual_batch_idx; + int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[virtual_batch_idx] : get<1>(params.shape_LSE_partial); if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { cutlass::arch::wait_on_dependent_grids(); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 11d422924b4..a2ff25dcd5f 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -35,7 +35,7 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool e {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O static_cast(params.softmax_lse_ptr), {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? 
params.h * params.seqlen_q : 0}, // stride_LSE - params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, params.tile_count_semaphore }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index e7d7b14cc07..72ab4b27638 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -151,12 +151,13 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.seqlen_q, params.seqlen_k, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, - // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, + params.num_m_blocks_ptr, + params.varlen_batch_idx_ptr }; if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { - prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/, false /*sort_batches*/); + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/, params.varlen_sort_batches); CHECK_CUDA_KERNEL_LAUNCH(); } diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 8a471e58859..e0a1f9ae94f 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -16,7 +16,7 @@ namespace flash { // needs (tensor size = batches): -// 1. batch_idx_ptr: virtual_batch_idx -> batch_idx +// 1. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx // 2. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] // 3. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] @@ -46,7 +46,7 @@ struct CustomMore { } }; -template +template __global__ void prepare_varlen_num_blocks_kernel( int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, @@ -56,15 +56,15 @@ __global__ void prepare_varlen_num_blocks_kernel( int* const tile_count_semaphore, int* const num_m_blocks_ptr, // virtual_batch_idx -> num_m_blocks[batch_idx] int* const num_splits_dynamic_ptr, - int* const batch_idx_ptr, + int* const varlen_batch_idx_ptr, bool enable_pdl) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; static constexpr int kSmemSize = 1; - static constexpr int BLOCK_THREADS = 1024; + static constexpr int BLOCK_DIM_X = NumWarps * 32; static constexpr int ITEMS_PER_THREAD = 1; - static_assert(BLOCK_THREADS * ITEMS_PER_THREAD == 1024); - using BlockMergeSort = cub::BlockMergeSort; + static_assert(BLOCK_DIM_X * ITEMS_PER_THREAD == NumWarps * 32); + using BlockMergeSort = cub::BlockMergeSort; // Assume that there's only one block in the grid __shared__ int total_blocks_smem[kSmemSize]; @@ -129,7 +129,6 @@ __global__ void prepare_varlen_num_blocks_kernel( int batch_idx = lane + bidb_start; int num_m_blocks = get_num_m_blocks(batch_idx); int num_n_blocks = get_num_n_blocks(batch_idx); - int total_blocks = num_m_blocks * num_n_blocks; // Warp sum #pragma unroll @@ -145,28 +144,36 @@ __global__ void prepare_varlen_num_blocks_kernel( int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); if constexpr(Sort) { - num_n_blocks = cute::ceil_div(num_n_blocks, num_splits_dynamic); // num_n_blocks per batch accounting 
for splits - - // At this point, I have total blocks in shared memory, and each thread has its num_m_blocks, num_n_blocks, and num_splits - // thread batch_idx = bidb_start + lane, lane 31 is a dummy. - - // Goal: sort batches by num_n_blocks - if(lane == kNumBatchPerWarp) { + // num_n_blocks per work tile for the batch + num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); + + if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { num_n_blocks = -1; // sort last } - int4 thread_keys[ITEMS_PER_THREAD]; // 1 item per thread for now - thread_keys[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); + int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread + batch_coords[0] = make_int4(num_n_blocks, num_splits_dynamic, num_m_blocks, batch_idx); + + // if (threadIdx.x == 0) { + // printf("Unsorted: num_n_blocks = %d, num_splits = %d, num_m_blocks = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); - BlockMergeSort(temp_storage).Sort(thread_keys, CustomMore()); + // Sort batches by num_n_blocks in descending order + BlockMergeSort(temp_storage).Sort(batch_coords, CustomMore()); + + // if (threadIdx.x == 0) { + // printf("Sorted: num_n_blocks = %d, num_splits = %d, num_m_blocks = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); if (threadIdx.x < num_batch) { - num_m_blocks_ptr[threadIdx.x] = thread_keys[0].y; - num_splits_dynamic_ptr[threadIdx.x] = thread_keys[0].z; - batch_idx_ptr[threadIdx.x] = thread_keys[0].w; + num_splits_dynamic_ptr[threadIdx.x] = batch_coords[0].y; + num_m_blocks_ptr[threadIdx.x] = batch_coords[0].z; + varlen_batch_idx_ptr[threadIdx.x] = batch_coords[0].w; } } else { - if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { - num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; + if (batch_idx < num_batch && lane < kNumBatchPerWarp) { + num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); } } @@ -178,23 +185,20 @@ void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bo int blockM, int blockN, bool enable_pdl, bool sort_batches) { // Only support batch <= 992 (32 warps, each with 31 batches) int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); - printf("sort_batches = %d.\n", sort_batches); + int num_warps = cutlass::ceil_div(params.b, 31); BOOL_SWITCH(sort_batches, Sort, [&] { - if constexpr(Sort) { - printf("Sorting it!.\n"); - } else { - printf("No sorting.\n"); - } - flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( - params.seqlen_q, params.seqlen_k, params.seqlen_knew, - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.b, !packgqa ? 
params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, - cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), - params.tile_count_semaphore, - params.num_m_blocks_ptr, - params.num_splits_dynamic_ptr, - params.batch_idx_ptr, - enable_pdl); + NUM_WARP_SWITCH(num_warps, NumWarps, [&] { + flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 32 * NumWarps /*block*/, 0, stream>>>( + params.seqlen_q, params.seqlen_k, params.seqlen_knew, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, + cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), + params.tile_count_semaphore, + params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, + params.varlen_batch_idx_ptr, + enable_pdl); + }); }); } diff --git a/hopper/prepare_varlen_bench.py b/hopper/prepare_varlen_bench.py new file mode 100644 index 00000000000..c0e9ea9ed14 --- /dev/null +++ b/hopper/prepare_varlen_bench.py @@ -0,0 +1,234 @@ +from collections import namedtuple +from functools import partial +import math +import os +from typing import NamedTuple +import torch +import torch.nn as nn +import torch.nn.functional as F + +import time + +Timing = NamedTuple('timing', [('mean', float)]) + +from einops import rearrange, repeat + +# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func +# from flash_attn_interface import flash_attn_func as flash_attn_func_v3 +from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3, get_scheduler_metadata +from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 + +from triton.testing import do_bench + +cudnn = None +triton_attention = None + +DISABLE_BACKWARD = True + +def time_fwd(func, *args, repeats=10, verbose=True, desc="", **kwargs): + # Warmup + # for _ in range(5): + # func(*args, **kwargs) + # time.sleep(1) + # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1] + # s = torch.cuda.Stream() + # s.wait_stream(torch.cuda.current_stream()) + # with torch.cuda.stream(s): + # for _ in range(2): + # out = func(*args, **kwargs) + # torch.cuda.current_stream().wait_stream(s) + # graph = torch.cuda.CUDAGraph() + # with torch.cuda.graph(graph): + # out = func(*args, **kwargs) + # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) + # # return time_f[1].mean + # return time_f[1] + return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) + +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): + if causal: + avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 + else: + if window_size == (-1, -1): + avg_seqlen = seqlen_k + else: + row_idx = torch.arange(seqlen_q, device='cuda') + col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) + avg_seqlen = (col_right - col_left + 1).float().mean().item() + return batch * nheads * 2 * seqlen_q * avg_seqlen 
* (headdim + headdim_v) + +torch.manual_seed(0) +repeats = 10 +dropout_p = 0.0 +causal = False +dtype = torch.bfloat16 +# dtype = torch.float8_e4m3fn +dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype +device = 'cuda' +verbose = True +varlen = True +page_size = None +softcap = 0.0 +V_colmajor = False +deterministic = False + +# decode_batches = 128 +# prefill_batches = 1 +# batch_size = decode_batches + prefill_batches +# decode_seqlen_k = 8192 if decode_batches > 0 else 0 +# prefill_seqlen_k = 2048 if prefill_batches > 0 else 0 +# # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative +# seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) +# decode_seqlen_q = 1 if decode_batches > 0 else 0 +# prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 +# seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches +# max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) + +time_f = {} +time_b = {} + +prefill_first_vals = [False, True] +# prefill_first_vals = [False] + +for headdim in [128]: + for prefill_batches in [0]: + # for decode_batches in range(32, 128 + 32, 8): + for decode_batches in [128]: + for sort_batches in [False, True]: + + batch_size = decode_batches + prefill_batches + decode_seqlen_k = 8192 if decode_batches > 0 else 0 + prefill_seqlen_k = 1024 if prefill_batches > 0 else 0 + # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative + seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) + decode_seqlen_q = 1 if decode_batches > 0 else 0 + prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 + seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches + max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) + + tp_degree=1 + nheads = 64//tp_degree + nheads_kv = 8//tp_degree + # nheads = 1 + # nheads_kv = 1 + headdim_v = headdim + has_qv = False + + # window_size = (128, 0) + window_size = (-1, -1) + # window_size = (seqlen // 2 - 1, 0) + + # print("Window size: ", window_size) + # print(f"Num query heads = {nheads}, kv heads = {nheads_kv}") + # print("Head dim: ", headdim) + # print("Batch size: ", batch_size) + # print("Prefill seqlen k: ", prefill_seqlen_k) + # print("Decode seqlen k: ", decode_seqlen_k) + # print("Seqlen k (max): ", seqlen_k) + # print("Prefill seqlen q: ", prefill_seqlen_q) + # print("Decode seqlen q: ", decode_seqlen_q) + # print("Seqlen q (total): ", seqlen_q) + + num_splits = 1 + pack_gqa = None + + # print(f"Num splits = {num_splits}, Pack GQA = {pack_gqa}") + + if prefill_batches == 0: + this_prefill_first_vals = [False] + else: + this_prefill_first_vals = prefill_first_vals + + for prefill_first in this_prefill_first_vals: + leftpad_k = None + # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) + q = torch.randn(seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) + q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] + v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() + v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None + + seqlen_q_decode_offset = 
decode_seqlen_q * decode_batches + seqlen_q_prefill_offset = prefill_seqlen_q * prefill_batches + seqlen_k_decode_offset = decode_seqlen_k * decode_batches + + if prefill_first: + cu_seqlens_q_prefill = torch.arange(prefill_batches, device=device, dtype=torch.int32) * prefill_seqlen_q + cu_seqlens_q_decode = torch.arange(decode_batches + 1, device=device, dtype=torch.int32) * decode_seqlen_q + seqlen_q_prefill_offset + cu_seqlens_q = torch.cat((cu_seqlens_q_prefill, cu_seqlens_q_decode), dim=0) + else: + cu_seqlens_q_decode = torch.arange(decode_batches, device=device, dtype=torch.int32) * decode_seqlen_q + cu_seqlens_q_prefill = torch.arange(prefill_batches + 1, device=device, dtype=torch.int32) * prefill_seqlen_q + seqlen_q_decode_offset + cu_seqlens_q = torch.cat((cu_seqlens_q_decode, cu_seqlens_q_prefill), dim=0) + + cache_seqlens_decode = torch.ones(decode_batches, dtype=torch.int32, device=device) * decode_seqlen_k + cache_seqlens_prefill = torch.ones(prefill_batches, dtype=torch.int32, device=device) * prefill_seqlen_k + + if prefill_first: + cache_seqlens = torch.cat((cache_seqlens_prefill, cache_seqlens_decode), dim=0) + else: + cache_seqlens = torch.cat((cache_seqlens_decode, cache_seqlens_prefill), dim=0) + + + # print("q: ", q.shape) + # print("k: ", k.shape) + # print("v: ", v.shape) + # print("cu seqlens q: ", cu_seqlens_q.shape) + # print("cache seqlens: ", cache_seqlens.shape) + # print("cu seqlens q vals: ", cu_seqlens_q) + # print("cache seqlens vals: ", cache_seqlens) + + page_table = None + + # for causal in [False, True]: + for causal in [True]: + print(f"\n### {headdim = }, {nheads = }, {nheads_kv = }, {causal = }, {prefill_seqlen_k = }, {decode_seqlen_k = }, {num_splits = }, {prefill_first = }, {decode_batches = }, {prefill_batches = } ###") + # nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen_k, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) + decode_nFLOPS = flops(decode_batches, nheads, decode_seqlen_q, decode_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) + prefill_nFLOPS = flops(prefill_batches, nheads, prefill_seqlen_q, prefill_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) + nFLOPS = decode_nFLOPS + prefill_nFLOPS + + bytes_kv = (decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches) * (nheads_kv * headdim * 4) + bytes_qo = (decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches) * (nheads * headdim * 4) # don't count split partials + bytes = bytes_kv + bytes_qo + # print(f'{nFLOPS * 1e-9:.1f} GFLOPs, {bytes * 1e-9: .2f} GB, {nFLOPS/bytes:.1f} AI') + + # time.sleep(1) + # m1 = time_fwd(flash_attn_func_v3, + # q, + # k, + # v, + # cache_seqlens=cache_seqlens, + # cu_seqlens_q=cu_seqlens_q, + # max_seqlen_q=max_seqlen_q, + # causal=causal, + # window_size=window_size, + # softcap=softcap, + # num_splits=num_splits, + # pack_gqa=pack_gqa, + # repeats=repeats, verbose=verbose, desc='Fav3') + + # time_f[(causal, headdim, batch_size, seqlen_k), "Flash3"] = m1.mean + + # print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS, {bytes/ m1.mean * 1e-9: .2f} GB/s') + + scheduler_metadata = get_scheduler_metadata( + batch_size, max_seqlen_q, seqlen_k, nheads, nheads_kv, headdim, + cache_seqlens, q.dtype, headdim_v=headdim, cu_seqlens_q=cu_seqlens_q, + causal=causal, num_splits=num_splits, varlen_sort_batches=sort_batches, ) + + # m1 = time_fwd(get_scheduler_metadata, + # batch_size, max_seqlen_q, seqlen_k, nheads, 
nheads_kv, headdim, + # cache_seqlens, q.dtype, headdim_v=headdim, cu_seqlens_q=cu_seqlens_q, + # causal=causal, num_splits=num_splits, sort_batches=sort_batches, + # repeats=repeats, verbose=verbose, desc='Prepare' + # ) + + # time_f[(causal, headdim, batch_size, seqlen_k), "Prepare"] = m1.mean + + # print(f'Prepare: {m1.mean * 1e3:.3f}ms') diff --git a/hopper/setup.py b/hopper/setup.py index 53fc9db103d..6d5d552b904 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -62,15 +62,15 @@ DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" DISABLE_SM8x = os.getenv("FLASH_ATTENTION_DISABLE_SM80", "FALSE") == "TRUE" -DISABLE_BACKWARD = True +# DISABLE_BACKWARD = True # DISABLE_SPLIT = True -DISABLE_PAGEDKV = True +# DISABLE_PAGEDKV = True DISABLE_APPENDKV = True -DISABLE_LOCAL = True +# DISABLE_LOCAL = True DISABLE_SOFTCAP = True # DISABLE_PACKGQA = True DISABLE_FP16 = True -DISABLE_FP8 = True +# DISABLE_FP8 = True # DISABLE_VARLEN = True DISABLE_CLUSTER = True DISABLE_HDIM64 = True diff --git a/hopper/static_switch.h b/hopper/static_switch.h index 5e13b5f93a8..15a7d51364b 100644 --- a/hopper/static_switch.h +++ b/hopper/static_switch.h @@ -179,3 +179,26 @@ return __VA_ARGS__(); \ } \ }() + +#define NUM_WARP_SWITCH(VALUE, CONST_NAME, ...) \ + [&] { \ + if (VALUE <= 1) { \ + constexpr static int CONST_NAME = 1; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 2) { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 4) { \ + constexpr static int CONST_NAME = 4; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 8) { \ + constexpr static int CONST_NAME = 8; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 16) { \ + constexpr static int CONST_NAME = 16; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 32; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index f1247e689da..aaabe78efa4 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -39,6 +39,24 @@ DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" +# DISABLE_BACKWARD = True +# DISABLE_SPLIT = True +# DISABLE_PAGEDKV = True +DISABLE_APPENDKV = True +# DISABLE_LOCAL = True +DISABLE_SOFTCAP = True +# DISABLE_PACKGQA = True +DISABLE_FP16 = True +# DISABLE_FP8 = True +# DISABLE_VARLEN = True +DISABLE_CLUSTER = True +DISABLE_HDIM64 = True +DISABLE_HDIM96 = True +# DISABLE_HDIM128 = True +DISABLE_HDIM192 = True +DISABLE_HDIM256 = True +DISABLE_SM8x = True + COMPILED_HDIMS = ( [] + ([64] if not DISABLE_HDIM64 else []) @@ -281,8 +299,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) -@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @@ -290,10 +308,10 @@ def test_flash_attn_output( @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", 
[False, True]) @pytest.mark.parametrize("deterministic", [False]) -@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) -# @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) -# @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +@pytest.mark.parametrize("softcap", [0.0]) +# @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +@pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("add_unused_qkv", [False, True]) @@ -304,8 +322,10 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) -@pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", COMPILED_HDIMS) +@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("varlen_sort_batches", [False, True]) +# @pytest.mark.parametrize("varlen_sort_batches", [True]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -332,14 +352,15 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, varlen_sort_batches, ): device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) # batch_size = 40 # nheads = 16 - batch_size = 9 if seqlen_q <= 2048 else 2 + # batch_size = 9 if seqlen_q <= 2048 else 2 + batch_size = 32 nheads = 6 # batch_size = 2 # nheads = 1 @@ -458,7 +479,13 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): rtol = 2 if softcap == 0.0 else 3 pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + # pack_gqa_vals = [True] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] + # num_splits_vals = [1] + # print("cu_seqlens_q: ", cu_seqlens_q) + # print("cu_seqlens_k: ", cu_seqlens_k) + # print("seqused_q: ", seqused_q) + # print("seqused_k: ", seqused_k) for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out_unpad = flash_attn_varlen_func( q_unpad, @@ -477,6 +504,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, + varlen_sort_batches=varlen_sort_batches, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -580,16 +608,16 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) -# @pytest.mark.parametrize("new_kv", [True]) -@pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) -# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +# @pytest.mark.parametrize("new_kv", [False]) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) 
+@pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) # @pytest.mark.parametrize("causal,local", [(False, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) -@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) -# @pytest.mark.parametrize("has_rotary_seqlens", [False]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) -# @pytest.mark.parametrize("rotary_interleaved", [True]) +# @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) @@ -597,15 +625,17 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_leftpad", [False, True]) # @pytest.mark.parametrize("has_leftpad", [False]) @pytest.mark.parametrize("has_batch_idx", [False, True]) -# @pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("has_batch_idx", [True]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("varlen_q", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) @pytest.mark.parametrize("d", [128]) # @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize("varlen_sort_batches", [False, True]) +# @pytest.mark.parametrize("varlen_sort_batches", [True]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -642,6 +672,7 @@ def test_flash_attn_kvcache( new_kv, mha_type, dtype, + varlen_sort_batches, ): if page_size is not None and seqlen_k % page_size != 0: pytest.skip() @@ -667,8 +698,11 @@ def test_flash_attn_kvcache( dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] - attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] + attention_chunk_vals = [0] # debug + # dv_vals = [d] # debug for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") has_qv = d == 64 and dv >= 256 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: @@ -850,9 +884,10 @@ def test_flash_attn_kvcache( sin = sin.to(dtype) if sin is not None else None k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() - num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] precompute_metadata_vals = [False, True] for num_splits, precompute_metadata in itertools.product(num_splits_vals, 
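            # num_splits semantics (assumed from the public interface): 1 = no split, >1 = fixed
            # split count, 0 = let the heuristic choose; the sweep below exercises all three.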
precompute_metadata_vals): + print(f"{num_splits = }, {precompute_metadata = }") if precompute_metadata: scheduler_metadata = get_scheduler_metadata( batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, @@ -860,7 +895,7 @@ def test_flash_attn_kvcache( cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, max_seqlen_k_new=seqlen_new, page_size=page_size, causal=causal, window_size=window_size, attention_chunk=attention_chunk, - num_splits=num_splits + num_splits=num_splits, varlen_sort_batches=varlen_sort_batches, ) else: scheduler_metadata = None @@ -895,7 +930,8 @@ def test_flash_attn_kvcache( rotary_interleaved=rotary_interleaved, scheduler_metadata=scheduler_metadata, num_splits=num_splits, - return_softmax_lse=True + return_softmax_lse=True, + varlen_sort_batches=varlen_sort_batches, ) if varlen_q: out = output_pad_fn(out) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index f65699f28fa..b0004697f8a 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -24,8 +24,9 @@ struct TileSchedulerArguments { int* const tile_count_semaphore = nullptr; int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; - // int const* const num_m_blocks_ptr = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const num_m_blocks_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; }; /////////////////////////////////////////////////////////////////////////////// @@ -487,8 +488,9 @@ class VarlenDynamicPersistentTileScheduler { int* const tile_count_semaphore; int const* const cu_seqlens; int const* const seqused; - // int* const num_m_blocks_ptr; int const* const num_splits_dynamic_ptr; + int const* const num_m_blocks_ptr; + int const* const varlen_batch_idx_ptr; }; static Params @@ -503,8 +505,9 @@ class VarlenDynamicPersistentTileScheduler { cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore, args.cu_seqlens, args.seqused, - // args.num_m_blocks_ptr, - args.num_splits_dynamic_ptr}; + args.num_splits_dynamic_ptr, + args.num_m_blocks_ptr, + args.varlen_batch_idx_ptr}; } static dim3 @@ -530,8 +533,9 @@ class VarlenDynamicPersistentTileScheduler { CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { + int actual_bidb = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[bidb] : bidb; if constexpr (!Split) { - return {block, bidh, bidb, 0 /*split_idx*/}; + return {block, bidh, actual_bidb, 0 /*split_idx*/}; } else { // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift @@ -545,7 +549,7 @@ class VarlenDynamicPersistentTileScheduler { // if (threadIdx.x == 128) { // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); // } - return {block, bidh_actual, bidb, split_idx}; + return {block, bidh_actual, actual_bidb, split_idx}; } } }; @@ -559,22 +563,27 @@ class VarlenDynamicPersistentTileScheduler { int lane = threadIdx.x % cutlass::NumThreadsPerWarp; auto get_num_m_blocks = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); - if (seqlen > kBlock) { - if (params.seqused) { - seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; - } else if (params.cu_seqlens) { - int cur_cu_seqlen = batch_idx <= params.num_batch ? 
params.cu_seqlens[batch_idx] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; - } else { - seqlen = params.seqlen; + if (params.num_m_blocks_ptr) { + int num_m_blocks = batch_idx < params.num_batch ? params.num_m_blocks_ptr[batch_idx] : 0; + return lane < cutlass::NumThreadsPerWarp - 1 ? num_m_blocks : 0; + } else { + int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); + if (seqlen > kBlock) { + if (params.seqused) { + seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; + } else if (params.cu_seqlens) { + int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = params.seqlen; + } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } } - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? cute::ceil_div(seqlen, kBlock) : 0; + // ? params.num_m_blocks_ptr[batch_idx] : 0; } - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; - // ? params.num_m_blocks_ptr[batch_idx] : 0; }; auto get_num_splits = [&] (int bidb_start) { From a587dd3bb61e2d600c1de19281a421083e71e46c Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Fri, 15 Aug 2025 00:00:54 -0700 Subject: [PATCH 05/20] add varlen kvhead swizzle --- hopper/flash.h | 1 + hopper/flash_api.cpp | 16 ++-- hopper/flash_fwd_launch_template.h | 7 +- hopper/flash_prepare_scheduler.cu | 42 ++++++++-- hopper/tile_scheduler.hpp | 129 ++++++++++++++++++++--------- 5 files changed, 141 insertions(+), 54 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index b33567ef383..cabff5c9798 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -156,6 +156,7 @@ struct Flash_fwd_params : public Qkv_params { // int * __restrict__ num_n_blocks_ptr; int * __restrict__ num_splits_dynamic_ptr; int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual + int * __restrict__ num_nheads_in_l2_ptr; bool skip_scheduler_metadata_computation; bool varlen_sort_batches; int tile_count_semaphore_offset; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 00b893ae18c..63f073212c0 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -609,7 +609,7 @@ mha_fwd_get_scheduler_metadata( params.varlen_sort_batches = sort_batches; if (scheduler_needs_semaphore || use_dynamic_split) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers - int num_prepare_batch_vectors = use_dynamic_split ? 1 : 0; + int num_prepare_batch_vectors = use_dynamic_split ? 2 : 0; if(sort_batches) { num_prepare_batch_vectors += 2; } int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); @@ -618,8 +618,10 @@ mha_fwd_get_scheduler_metadata( opts.dtype(torch::kInt32)); // {num_splits_dynamic, num_m_blocks, virtual_batch_indices, tile_count_semaphore} params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() : nullptr; - params.num_m_blocks_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; - params.varlen_batch_idx_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_dynamic_split ? 
tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.num_nheads_in_l2_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.num_m_blocks_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + params.varlen_batch_idx_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 3 : nullptr; if (scheduler_needs_semaphore) { if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing params.tile_count_semaphore = tile_count_semaphore.data_ptr() + tile_count_semaphore_offset; @@ -973,7 +975,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql params.varlen_sort_batches = sort_batches; if (scheduler_needs_semaphore || use_dynamic_split) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers - int num_prepare_batch_vectors = use_dynamic_split ? 1 : 0; + int num_prepare_batch_vectors = use_dynamic_split ? 2 : 0; if(sort_batches) { num_prepare_batch_vectors += 2; } int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; @@ -994,8 +996,10 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } // {num_splits_dynamic, num_m_blocks, virtual_batch_indices, tile_count_semaphore} params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() : nullptr; - params.num_m_blocks_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; - params.varlen_batch_idx_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.num_nheads_in_l2_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.num_m_blocks_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + params.varlen_batch_idx_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 3 : nullptr; params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() + tile_count_semaphore_offset : nullptr; params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later } diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 72ab4b27638..92f1b007e99 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -58,7 +58,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/>, + flash::VarlenDynamicPersistentTileScheduler= 90 /*WarpSpecialized*/, Is_causal /*LPT*/>, std::conditional_t, flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> @@ -149,11 +149,12 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { num_blocks_m, !PackGQA ? 
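        // With PackGQA the query heads that share a KV head are packed into the M dimension,
        // so the scheduler only iterates over the params.h_k KV heads here.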
params.h : params.h_k, params.b, params.num_splits, params.h / params.h_k, params.seqlen_q, - params.seqlen_k, params.d, params.dv, sizeof(Element), + params.seqlen_k, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.num_m_blocks_ptr, - params.varlen_batch_idx_ptr + params.varlen_batch_idx_ptr, + params.num_nheads_in_l2_ptr }; if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index e0a1f9ae94f..eaeade17410 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -57,7 +57,11 @@ __global__ void prepare_varlen_num_blocks_kernel( int* const num_m_blocks_ptr, // virtual_batch_idx -> num_m_blocks[batch_idx] int* const num_splits_dynamic_ptr, int* const varlen_batch_idx_ptr, - bool enable_pdl) { + // int* const num_n_blocks_ptr, + int* const num_nheads_in_l2_ptr, + bool enable_pdl, + bool packgqa, + int max_kvblocks_in_l2) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; static constexpr int kSmemSize = 1; @@ -93,7 +97,7 @@ __global__ void prepare_varlen_num_blocks_kernel( } else { seqlen = seqlen_q_static; } - seqlen *= qhead_per_khead; + if(packgqa) { seqlen *= qhead_per_khead; } return batch_idx < num_batch && lane < kNumBatchPerWarp ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; }; @@ -142,11 +146,20 @@ __global__ void prepare_varlen_num_blocks_kernel( int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + // num_n_blocks per work tile for the batch + num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); + + auto get_nheads_in_l2 = [&](int n_blocks) { + int nheads_in_l2 = n_blocks * 16 <= max_kvblocks_in_l2 ? 16 + : n_blocks * 8 <= max_kvblocks_in_l2 ? 8 + : n_blocks * 4 <= max_kvblocks_in_l2 ? 4 + : n_blocks * 2 <= max_kvblocks_in_l2 ? 
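            // Largest power-of-two head count (capped at 16) whose KV blocks still fit in the L2 budget,
            // e.g. with (hypothetical) max_kvblocks_in_l2 = 64 and n_blocks = 20 this picks 2 (2*20 <= 64 < 4*20).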
2 + : 1; + if(!packgqa) { nheads_in_l2 *= qhead_per_khead; } + return min(nheads_in_l2, num_head); + }; if constexpr(Sort) { - // num_n_blocks per work tile for the batch - num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); - if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { num_n_blocks = -1; // sort last } @@ -164,15 +177,19 @@ __global__ void prepare_varlen_num_blocks_kernel( // if (threadIdx.x == 0) { // printf("Sorted: num_n_blocks = %d, num_splits = %d, num_m_blocks = %d, batch_idx = %d.\n", // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); - // } __syncthreads(); + // } __syncthreads(); if (threadIdx.x < num_batch) { + // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); + num_nheads_in_l2_ptr[threadIdx.x] = get_nheads_in_l2(max(batch_coords[0].x, 1)); num_splits_dynamic_ptr[threadIdx.x] = batch_coords[0].y; num_m_blocks_ptr[threadIdx.x] = batch_coords[0].z; varlen_batch_idx_ptr[threadIdx.x] = batch_coords[0].w; } } else { if (batch_idx < num_batch && lane < kNumBatchPerWarp) { + // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); + num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); } @@ -184,8 +201,13 @@ __global__ void prepare_varlen_num_blocks_kernel( void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl, bool sort_batches) { // Only support batch <= 992 (32 warps, each with 31 batches) - int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); + // int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); + int qhead_per_khead = cutlass::ceil_div(params.h, params.h_k); int num_warps = cutlass::ceil_div(params.b, 31); + int const size_l2 = 50 * 1024 * 1024; // 50 MB + int const element_size = params.is_e4m3 ? 
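    // FP8 (e4m3) K/V elements are 1 byte; bf16/fp16 elements are 2 bytes.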
1 : 2; + int const size_one_kvblock = blockN * (params.d + params.dv) * element_size; + int const max_kvblocks_in_l2 = size_l2 / size_one_kvblock; BOOL_SWITCH(sort_batches, Sort, [&] { NUM_WARP_SWITCH(num_warps, NumWarps, [&] { flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 32 * NumWarps /*block*/, 0, stream>>>( @@ -198,7 +220,11 @@ void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bo params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, - enable_pdl); + // params.num_n_blocks_ptr, + params.num_nheads_in_l2_ptr, + enable_pdl, + packgqa, + max_kvblocks_in_l2); }); }); } diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index b0004697f8a..82e3b4cd480 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -27,6 +27,8 @@ struct TileSchedulerArguments { int const* const num_splits_dynamic_ptr = nullptr; int const* const num_m_blocks_ptr = nullptr; int const* const varlen_batch_idx_ptr = nullptr; + // int const* const num_n_blocks_ptr = nullptr; + int const* const num_nheads_in_l2_ptr = nullptr; }; /////////////////////////////////////////////////////////////////////////////// @@ -464,7 +466,7 @@ class SingleTileBwdLPTScheduler { /////////////////////////////////////////////////////////////////////////////// -template +template class VarlenDynamicPersistentTileScheduler { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); @@ -483,6 +485,7 @@ class VarlenDynamicPersistentTileScheduler { int num_head, num_batch; int const qhead_per_khead; int const seqlen; + // int const max_kvblocks_in_l2; cutlass::FastDivmod head_divmod; cutlass::FastDivmod nsplits_divmod; int* const tile_count_semaphore; @@ -491,6 +494,8 @@ class VarlenDynamicPersistentTileScheduler { int const* const num_splits_dynamic_ptr; int const* const num_m_blocks_ptr; int const* const varlen_batch_idx_ptr; + // int const* const num_n_blocks_ptr; + int const* const num_nheads_in_l2_ptr; }; static Params @@ -500,14 +505,20 @@ class VarlenDynamicPersistentTileScheduler { assert(args.tile_count_semaphore != nullptr); assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + // int const size_l2 = 50 * 1024 * 1024; // 50 MB + // int const size_one_kvblock = kBlockN * (args.headdim + args.headdim_v) * args.element_size; + // int max_kvblocks_in_l2 = size_l2 / size_one_kvblock; return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, + // max_kvblocks_in_l2, cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore, args.cu_seqlens, args.seqused, args.num_splits_dynamic_ptr, args.num_m_blocks_ptr, - args.varlen_batch_idx_ptr}; + args.varlen_batch_idx_ptr, + // aras.num_n_blocks_ptr, + args.num_nheads_in_l2_ptr}; } static dim3 @@ -516,11 +527,6 @@ class VarlenDynamicPersistentTileScheduler { } struct WorkTileInfo { - // for LPT scheduling: - // 1) tile_idx is offset by (reverse_block - block) - // 2) block <- reverse_block = num_m_blocks[bidb] - 1 - block - // NOTE: we only use tile_idx in tile_idx_to_work_tile to compute (tile_idx - block), - // so just need tile_idx' = tile_idx + block' - block for any block' replacing block int tile_idx, block, bidh, bidb; CUTLASS_DEVICE @@ -568,7 +574,7 @@ class VarlenDynamicPersistentTileScheduler { return lane < cutlass::NumThreadsPerWarp - 1 ? 
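            // Keep the last lane at 0 so each warp still covers only 31 batches, matching the
            // cu_seqlens path below where lane 31 has no __shfl_down neighbour.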
num_m_blocks : 0; } else { int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); - if (seqlen > kBlock) { + if (seqlen > kBlockM) { if (params.seqused) { seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; } else if (params.cu_seqlens) { @@ -581,7 +587,7 @@ class VarlenDynamicPersistentTileScheduler { if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } } return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; + ? cute::ceil_div(seqlen, kBlockM) : 0; // ? params.num_m_blocks_ptr[batch_idx] : 0; } }; @@ -603,12 +609,14 @@ class VarlenDynamicPersistentTileScheduler { // Total number of blocks for the next 31 batches int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); // Only the lower 16 bits are the actual bidh - int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); - int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes - if constexpr (Split) { - int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; - group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); - } + // int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); + // int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + // if constexpr (Split) { + // int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; + // group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); + // } + // NEW: current_work.tile_idx holds group_start_tile for starting batch + int group_end_tile = current_work.tile_idx + m_blocks_in_group * params.num_head; // Same for all lanes int bidb = current_work.bidb; // if (blockIdx.x <= 9 && threadIdx.x == 0) { // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); @@ -640,30 +648,77 @@ class VarlenDynamicPersistentTileScheduler { bidb += batch_idx_in_group; num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } - int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 
0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; - int bidh = mh_block / num_m_blocks; - int block = mh_block - bidh * num_m_blocks; - // Longest-processing-time-first - next_tile_idx += num_m_blocks - 1 - 2 * block; // add (reverse_block - block) to linear tile_idx - block = num_m_blocks - 1 - block; - if constexpr (Split) { - int bidh_actual = bidh / num_splits; - int split_idx = bidh - bidh_actual * num_splits; - // TODO: idk why this gives wrong answer nondeterministically - // int bidh_actual, split_idx; - // split_idx = params.head_divmod.divmod(bidh_actual, bidh); - // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx - // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift - uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); - // if (threadIdx.x == 0) { - // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + group_start_tile += (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; + int mh_block = next_tile_idx - group_start_tile; + int block, bidh; + if constexpr (LPT) { + // if constexpr(!Split) { + if (!Split || num_splits == 1) { + // NOTE: code for computing nheads_in_l2 directly left as demonstration + // int num_n_blocks = params.num_n_blocks_ptr ? params.num_n_blocks_ptr[bidb] : num_m_blocks; + // auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // int nheads_in_l2 = params.max_kvblocks_in_l2 < num_n_blocks + // ? 1 : 1 << find_log2_floor(params.max_kvblocks_in_l2 / num_n_blocks); + // if constexpr (!PackGQA) { nheads_in_l2 *= params.qhead_per_khead; } + // nheads_in_l2 = min(nheads_in_l2, params.num_head); + + int nheads_in_l2 = params.num_nheads_in_l2_ptr ? params.num_nheads_in_l2_ptr[bidb] : params.num_head; + int mh_in_l2 = nheads_in_l2 * num_m_blocks; + int section_idx = mh_block / mh_in_l2; + int l2_mod = mh_block - section_idx * mh_in_l2; + // tail section + int nheads_remainder = params.num_head - section_idx * nheads_in_l2; + int nheads_in_this_section = nheads_in_l2 <= nheads_remainder ? 
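                    // Worked example (hypothetical sizes): num_m_blocks = 4, nheads_in_l2 = 2, num_head = 6.
                    // Section 0 yields (bidh, block) = (0,0),(1,0),(0,1),(1,1),...,(0,3),(1,3) before the LPT
                    // reversal below; sections 1 and 2 repeat the pattern for heads 2-3 and 4-5. Heads within a
                    // section cycle fastest so their KV blocks stay hot in L2 while the m-block index advances;
                    // the tail section may be narrower (nheads_remainder).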
nheads_in_l2 : nheads_remainder; + block = l2_mod / nheads_in_this_section; + int bidh_residual = l2_mod - block * nheads_in_this_section; + bidh = section_idx * nheads_in_l2 + bidh_residual; + + if constexpr(Split) { + // remember to set num_splits = 1 in work tile + uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(num_splits) << 24); + bidh = reinterpret_cast(bidh_packed); + } + } else { + // bidh = mh_block / num_m_blocks; + // block = mh_block - bidh * num_m_blocks; + // if constexpr (Split) { + // int bidh_actual = bidh / num_splits; + // int split_idx = bidh - bidh_actual * num_splits; + // uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // bidh = reinterpret_cast(bidh_packed); + // } + // assume enough space in L2 cache if split + block = params.head_divmod.divmod(bidh, mh_block); + if constexpr (Split) { + int split_idx = block / num_m_blocks; + block = block - split_idx * num_m_blocks; + uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + bidh = reinterpret_cast(bidh_packed); + } + } + block = num_m_blocks - 1 - block; + } else { + bidh = mh_block / num_m_blocks; + block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + // TODO: idk why this gives wrong answer nondeterministically + // int bidh_actual, split_idx; + // split_idx = params.head_divmod.divmod(bidh_actual, bidh); + // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // if (threadIdx.x == 0) { + // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + // } + bidh = reinterpret_cast(bidh_packed); + } + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); // } - bidh = reinterpret_cast(bidh_packed); } - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); - // } - return {next_tile_idx, block, bidh, bidb}; + return {group_start_tile, block, bidh, bidb}; } template From 0890577c3e63dbaab166478451ecc6e234ddf7d7 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Mon, 18 Aug 2025 10:33:31 -0700 Subject: [PATCH 06/20] add settings for swizzle 
ablation --- hopper/benchmark_prefill.py | 191 +++++++++++++++++++++++ hopper/benchmark_prefill_decode.py | 234 +++++++++++++++-------------- hopper/flash.h | 1 + hopper/flash_api.cpp | 6 +- hopper/flash_attn_interface.py | 10 +- hopper/flash_fwd_launch_template.h | 2 +- hopper/flash_prepare_scheduler.cu | 4 +- hopper/setup.py | 7 +- hopper/test_flash_attn.py | 50 +++--- hopper/tile_scheduler.hpp | 29 ++-- 10 files changed, 376 insertions(+), 158 deletions(-) create mode 100644 hopper/benchmark_prefill.py diff --git a/hopper/benchmark_prefill.py b/hopper/benchmark_prefill.py new file mode 100644 index 00000000000..096183d932f --- /dev/null +++ b/hopper/benchmark_prefill.py @@ -0,0 +1,191 @@ +from collections import namedtuple +from functools import partial +import math +import os +from typing import NamedTuple +import torch +import torch.nn as nn +import torch.nn.functional as F + +import time + +Timing = NamedTuple('timing', [('mean', float)]) + +from einops import rearrange, repeat + +# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func +# from flash_attn_interface import flash_attn_func as flash_attn_func_v3 +from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3 +from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 + +from triton.testing import do_bench + +cudnn = None +triton_attention = None + +DISABLE_BACKWARD = True + +def time_fwd(func, *args, repeats=10, verbose=True, desc="", **kwargs): + # Warmup + # for _ in range(5): + # func(*args, **kwargs) + # time.sleep(1) + # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1] + # s = torch.cuda.Stream() + # s.wait_stream(torch.cuda.current_stream()) + # with torch.cuda.stream(s): + # for _ in range(2): + # out = func(*args, **kwargs) + # torch.cuda.current_stream().wait_stream(s) + # graph = torch.cuda.CUDAGraph() + # with torch.cuda.graph(graph): + # out = func(*args, **kwargs) + # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) + # # return time_f[1].mean + # return time_f[1] + return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) + +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): + if causal: + avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 + else: + if window_size == (-1, -1): + avg_seqlen = seqlen_k + else: + row_idx = torch.arange(seqlen_q, device='cuda') + col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) + avg_seqlen = (col_right - col_left + 1).float().mean().item() + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + +torch.manual_seed(0) +repeats = 10 +dropout_p = 0.0 +causal = False +dtype = torch.bfloat16 +# dtype = torch.float8_e4m3fn +dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype +device = 'cuda' +verbose = True +varlen = True +page_size = None +softcap = 0.0 +V_colmajor = False +deterministic = False + +# decode_batches = 128 +# prefill_batches = 1 +# batch_size = 
decode_batches + prefill_batches +# decode_seqlen_k = 8192 if decode_batches > 0 else 0 +# prefill_seqlen_k = 2048 if prefill_batches > 0 else 0 +# # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative +# seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) +# decode_seqlen_q = 1 if decode_batches > 0 else 0 +# prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 +# seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches +# max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) + +time_f = {} +time_b = {} + + +for headdim in [128]: + for head_swizzle in [True, False]: + print("\nHead Swizzle is ", head_swizzle) + for prefill_batches in [16, 8, 4, 2, 1]: + # for prefill_batches in [4, 4, 4]: + batch_size = prefill_batches + # prefill_seqlen_k = (8192 * 8) // prefill_batches + prefill_seqlen_k = (8192 * 4) // prefill_batches + seqlen_k = prefill_seqlen_k + prefill_seqlen_q = prefill_seqlen_k + seqlen_q = prefill_seqlen_q * prefill_batches + max_seqlen_q = prefill_seqlen_q + + # tp_degree=4 + # nheads = 64//tp_degree + # nheads_kv = 8//tp_degree + nheads = 8 + nheads_kv = 8 + headdim_v = headdim + has_qv = False + + # window_size = (128, 0) + window_size = (-1, -1) + # window_size = (seqlen // 2 - 1, 0) + + # print("Window size: ", window_size) + # print(f"Num query heads = {nheads}, kv heads = {nheads_kv}") + # print("Head dim: ", headdim) + # print("Batch size: ", batch_size) + # print("Prefill seqlen k: ", prefill_seqlen_k) + # print("Decode seqlen k: ", decode_seqlen_k) + # print("Seqlen k (max): ", seqlen_k) + # print("Prefill seqlen q: ", prefill_seqlen_q) + # print("Decode seqlen q: ", decode_seqlen_q) + # print("Seqlen q (total): ", seqlen_q) + + num_splits = 1 + pack_gqa = False + + # print(f"Num splits = {num_splits}, Pack GQA = {pack_gqa}") + + leftpad_k = None + # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) + q = torch.randn(seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) + q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] + v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() + v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None + + cu_seqlens_q = torch.arange(prefill_batches + 1, device=device, dtype=torch.int32) * prefill_seqlen_q + cache_seqlens = torch.ones(prefill_batches, dtype=torch.int32, device=device) * prefill_seqlen_k + + # print("q: ", q.shape) + # print("k: ", k.shape) + # print("v: ", v.shape) + # print("cu seqlens q: ", cu_seqlens_q.shape) + # print("cache seqlens: ", cache_seqlens.shape) + # print("cu seqlens q vals: ", cu_seqlens_q) + # print("cache seqlens vals: ", cache_seqlens) + + page_table = None + + # for causal in [False, True]: + for causal in [True]: + print(f"\n### {headdim = }, {nheads = }, {nheads_kv = }, {causal = }, {prefill_seqlen_k = }, {prefill_batches = }, {num_splits = }, {head_swizzle = } ###") + # nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen_k, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) + + prefill_nFLOPS = flops(prefill_batches, nheads, prefill_seqlen_q, 
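                # causal=True makes flops() use the averaged KV length (max(0, k - q) + k) / 2 defined above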
prefill_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) + nFLOPS = prefill_nFLOPS + + bytes_kv = (prefill_seqlen_k * prefill_batches) * (nheads_kv * headdim * 4) + bytes_qo = (prefill_seqlen_q * prefill_batches) * (nheads * headdim * 4) # don't count split partials + bytes = bytes_kv + bytes_qo + print(f'{nFLOPS * 1e-9:.1f} GFLOPs, {bytes * 1e-9: .2f} GB, {nFLOPS/bytes:.1f} AI') + + time.sleep(1) + m1 = time_fwd(flash_attn_func_v3, + q, + k, + v, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + causal=causal, + window_size=window_size, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + head_swizzle=head_swizzle, + repeats=repeats, verbose=verbose, desc='Fav3') + + time_f[(causal, headdim, batch_size, seqlen_k), "Flash3"] = m1.mean + + print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS, {bytes/ m1.mean * 1e-9: .2f} GB/s') + time.sleep(2) + diff --git a/hopper/benchmark_prefill_decode.py b/hopper/benchmark_prefill_decode.py index 55b74c1a60e..d5345bb1534 100644 --- a/hopper/benchmark_prefill_decode.py +++ b/hopper/benchmark_prefill_decode.py @@ -90,128 +90,130 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, w time_f = {} time_b = {} -prefill_first_vals = [False, True] -# prefill_first_vals = [False] +# prefill_first_vals = [False, True] +prefill_first_vals = [False] for headdim in [128]: - for prefill_batches in [0, 1]: + for prefill_batches in [8]: # for decode_batches in range(32, 128 + 32, 8): - for decode_batches in [128]: - batch_size = decode_batches + prefill_batches - decode_seqlen_k = 8192 if decode_batches > 0 else 0 - prefill_seqlen_k = 1024 if prefill_batches > 0 else 0 - # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative - seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) - decode_seqlen_q = 1 if decode_batches > 0 else 0 - prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 - seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches - max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) - - tp_degree=1 - nheads = 64//tp_degree - nheads_kv = 8//tp_degree - # nheads = 1 - # nheads_kv = 1 - headdim_v = headdim - has_qv = False - - # window_size = (128, 0) - window_size = (-1, -1) - # window_size = (seqlen // 2 - 1, 0) - - # print("Window size: ", window_size) - # print(f"Num query heads = {nheads}, kv heads = {nheads_kv}") - # print("Head dim: ", headdim) - # print("Batch size: ", batch_size) - # print("Prefill seqlen k: ", prefill_seqlen_k) - # print("Decode seqlen k: ", decode_seqlen_k) - # print("Seqlen k (max): ", seqlen_k) - # print("Prefill seqlen q: ", prefill_seqlen_q) - # print("Decode seqlen q: ", decode_seqlen_q) - # print("Seqlen q (total): ", seqlen_q) - - num_splits = 1 - pack_gqa = None - - # print(f"Num splits = {num_splits}, Pack GQA = {pack_gqa}") - - if prefill_batches == 0: - this_prefill_first_vals = [False] - else: - this_prefill_first_vals = prefill_first_vals - - for prefill_first in this_prefill_first_vals: - leftpad_k = None - # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) - q = torch.randn(seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) - k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) - v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) - q, k, v 
= [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] - v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() - v_fa3 = v if not V_colmajor else v_colmajor - qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None + for decode_batches in [1]: + for head_swizzle in [False, True]: + batch_size = decode_batches + prefill_batches + decode_seqlen_k = 8192 if decode_batches > 0 else 0 + prefill_seqlen_k = 1024 * 8 if prefill_batches > 0 else 0 + # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative + seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) + decode_seqlen_q = 1 if decode_batches > 0 else 0 + prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 + seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches + max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) + + # tp_degree=8 + # nheads = 64//tp_degree + # nheads_kv = 8//tp_degree + nheads = 64 + nheads_kv = 8 + headdim_v = headdim + has_qv = False - seqlen_q_decode_offset = decode_seqlen_q * decode_batches - seqlen_q_prefill_offset = prefill_seqlen_q * prefill_batches - seqlen_k_decode_offset = decode_seqlen_k * decode_batches + # window_size = (128, 0) + window_size = (-1, -1) + # window_size = (seqlen // 2 - 1, 0) + + # print("Window size: ", window_size) + # print(f"Num query heads = {nheads}, kv heads = {nheads_kv}") + # print("Head dim: ", headdim) + # print("Batch size: ", batch_size) + # print("Prefill seqlen k: ", prefill_seqlen_k) + # print("Decode seqlen k: ", decode_seqlen_k) + # print("Seqlen k (max): ", seqlen_k) + # print("Prefill seqlen q: ", prefill_seqlen_q) + # print("Decode seqlen q: ", decode_seqlen_q) + # print("Seqlen q (total): ", seqlen_q) + + num_splits = 1 + pack_gqa = None + + # print(f"Num splits = {num_splits}, Pack GQA = {pack_gqa}") - if prefill_first: - cu_seqlens_q_prefill = torch.arange(prefill_batches, device=device, dtype=torch.int32) * prefill_seqlen_q - cu_seqlens_q_decode = torch.arange(decode_batches + 1, device=device, dtype=torch.int32) * decode_seqlen_q + seqlen_q_prefill_offset - cu_seqlens_q = torch.cat((cu_seqlens_q_prefill, cu_seqlens_q_decode), dim=0) - else: - cu_seqlens_q_decode = torch.arange(decode_batches, device=device, dtype=torch.int32) * decode_seqlen_q - cu_seqlens_q_prefill = torch.arange(prefill_batches + 1, device=device, dtype=torch.int32) * prefill_seqlen_q + seqlen_q_decode_offset - cu_seqlens_q = torch.cat((cu_seqlens_q_decode, cu_seqlens_q_prefill), dim=0) - - cache_seqlens_decode = torch.ones(decode_batches, dtype=torch.int32, device=device) * decode_seqlen_k - cache_seqlens_prefill = torch.ones(prefill_batches, dtype=torch.int32, device=device) * prefill_seqlen_k - - if prefill_first: - cache_seqlens = torch.cat((cache_seqlens_prefill, cache_seqlens_decode), dim=0) + if prefill_batches == 0: + this_prefill_first_vals = [False] else: - cache_seqlens = torch.cat((cache_seqlens_decode, cache_seqlens_prefill), dim=0) - - - # print("q: ", q.shape) - # print("k: ", k.shape) - # print("v: ", v.shape) - # print("cu seqlens q: ", cu_seqlens_q.shape) - # print("cache seqlens: ", cache_seqlens.shape) - # print("cu seqlens q vals: ", cu_seqlens_q) - # print("cache seqlens vals: ", cache_seqlens) - - page_table = None - - # for causal in [False, True]: - for causal in [True]: - print(f"\n### {headdim = }, {nheads = }, {nheads_kv = }, {causal = }, {prefill_seqlen_k = }, {decode_seqlen_k = }, {num_splits = 
}, {prefill_first = }, {decode_batches = }, {prefill_batches = } ###") - # nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen_k, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) - decode_nFLOPS = flops(decode_batches, nheads, decode_seqlen_q, decode_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) - prefill_nFLOPS = flops(prefill_batches, nheads, prefill_seqlen_q, prefill_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) - nFLOPS = decode_nFLOPS + prefill_nFLOPS - - bytes_kv = (decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches) * (nheads_kv * headdim * 4) - bytes_qo = (decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches) * (nheads * headdim * 4) # don't count split partials - bytes = bytes_kv + bytes_qo - print(f'{nFLOPS * 1e-9:.1f} GFLOPs, {bytes * 1e-9: .2f} GB, {nFLOPS/bytes:.1f} AI') - - time.sleep(1) - m1 = time_fwd(flash_attn_func_v3, - q, - k, - v, - cache_seqlens=cache_seqlens, - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - causal=causal, - window_size=window_size, - softcap=softcap, - num_splits=num_splits, - pack_gqa=pack_gqa, - repeats=repeats, verbose=verbose, desc='Fav3') + this_prefill_first_vals = prefill_first_vals + + for prefill_first in this_prefill_first_vals: + leftpad_k = None + # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) + q = torch.randn(seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) + q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] + v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() + v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None + + seqlen_q_decode_offset = decode_seqlen_q * decode_batches + seqlen_q_prefill_offset = prefill_seqlen_q * prefill_batches + seqlen_k_decode_offset = decode_seqlen_k * decode_batches - time_f[(causal, headdim, batch_size, seqlen_k), "Flash3"] = m1.mean + if prefill_first: + cu_seqlens_q_prefill = torch.arange(prefill_batches, device=device, dtype=torch.int32) * prefill_seqlen_q + cu_seqlens_q_decode = torch.arange(decode_batches + 1, device=device, dtype=torch.int32) * decode_seqlen_q + seqlen_q_prefill_offset + cu_seqlens_q = torch.cat((cu_seqlens_q_prefill, cu_seqlens_q_decode), dim=0) + else: + cu_seqlens_q_decode = torch.arange(decode_batches, device=device, dtype=torch.int32) * decode_seqlen_q + cu_seqlens_q_prefill = torch.arange(prefill_batches + 1, device=device, dtype=torch.int32) * prefill_seqlen_q + seqlen_q_decode_offset + cu_seqlens_q = torch.cat((cu_seqlens_q_decode, cu_seqlens_q_prefill), dim=0) + + cache_seqlens_decode = torch.ones(decode_batches, dtype=torch.int32, device=device) * decode_seqlen_k + cache_seqlens_prefill = torch.ones(prefill_batches, dtype=torch.int32, device=device) * prefill_seqlen_k + + if prefill_first: + cache_seqlens = torch.cat((cache_seqlens_prefill, cache_seqlens_decode), dim=0) + else: + cache_seqlens = torch.cat((cache_seqlens_decode, cache_seqlens_prefill), dim=0) + + + # print("q: ", q.shape) + # print("k: ", k.shape) + # print("v: ", v.shape) + # print("cu seqlens q: ", cu_seqlens_q.shape) + # 
print("cache seqlens: ", cache_seqlens.shape) + # print("cu seqlens q vals: ", cu_seqlens_q) + # print("cache seqlens vals: ", cache_seqlens) - print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS, {bytes/ m1.mean * 1e-9: .2f} GB/s') + page_table = None + + # for causal in [False, True]: + for causal in [True]: + print(f"\n### {headdim = }, {nheads = }, {nheads_kv = }, {causal = }, {prefill_seqlen_k = }, {decode_seqlen_k = }, {num_splits = }, {prefill_first = }, {decode_batches = }, {prefill_batches = }, {head_swizzle = } ###") + # nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen_k, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) + decode_nFLOPS = flops(decode_batches, nheads, decode_seqlen_q, decode_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) + prefill_nFLOPS = flops(prefill_batches, nheads, prefill_seqlen_q, prefill_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) + nFLOPS = decode_nFLOPS + prefill_nFLOPS + + bytes_kv = (decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches) * (nheads_kv * headdim * 4) + bytes_qo = (decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches) * (nheads * headdim * 4) # don't count split partials + bytes = bytes_kv + bytes_qo + print(f'{nFLOPS * 1e-9:.1f} GFLOPs, {bytes * 1e-9: .2f} GB, {nFLOPS/bytes:.1f} AI') + + time.sleep(1) + m1 = time_fwd(flash_attn_func_v3, + q, + k, + v, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + causal=causal, + window_size=window_size, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + head_swizzle=head_swizzle, + repeats=repeats, verbose=verbose, desc='Fav3') + + time_f[(causal, headdim, batch_size, seqlen_k), "Flash3"] = m1.mean + + print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS, {bytes/ m1.mean * 1e-9: .2f} GB/s') diff --git a/hopper/flash.h b/hopper/flash.h index cabff5c9798..03f66bbc186 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -160,6 +160,7 @@ struct Flash_fwd_params : public Qkv_params { bool skip_scheduler_metadata_computation; bool varlen_sort_batches; int tile_count_semaphore_offset; + bool head_swizzle; int arch; int num_sm; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 63f073212c0..14bbc569590 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -686,7 +686,8 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql int64_t num_splits, std::optional pack_gqa_, bool sort_batches, - int64_t sm_margin + int64_t sm_margin, + bool head_swizzle ) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -1003,6 +1004,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() + tile_count_semaphore_offset : nullptr; params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later } + params.head_swizzle = head_swizzle; if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); @@ -1682,7 +1684,7 @@ TORCH_LIBRARY(flash_attn_3, m) { "int num_splits = 0," "bool? 
pack_gqa = None," "bool sort_batches = False," - "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); + "int sm_margin = 0, bool head_swizzle = False) -> (Tensor(out!), Tensor, Tensor, Tensor)"); m.def("bwd(" "Tensor dout," "Tensor q," diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 628717dbed6..0c7cc828639 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -52,6 +52,7 @@ def _flash_attn_forward( pack_gqa=None, sm_margin=0, varlen_sort_batches=False, + head_swizzle=False, ): q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v @@ -100,6 +101,7 @@ def _flash_attn_forward( pack_gqa, varlen_sort_batches, sm_margin, + head_swizzle, ) return out, softmax_lse, *rest @@ -366,6 +368,7 @@ def forward( deterministic=False, sm_margin=0, varlen_sort_batches=False, + head_swizzle=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) @@ -396,6 +399,7 @@ def forward( pack_gqa=pack_gqa, sm_margin=sm_margin, varlen_sort_batches=varlen_sort_batches, + head_swizzle=head_swizzle, ) # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -441,7 +445,7 @@ def backward(ctx, dout, *args): dq = dq[..., : q.shape[-1]] # We could have padded the head dimension dk = dk[..., : k.shape[-1]] dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( @@ -605,6 +609,7 @@ def flash_attn_varlen_func( deterministic=False, sm_margin=0, varlen_sort_batches=False, + head_swizzle=False, ): return FlashAttnVarlenFunc.apply( q, @@ -628,6 +633,7 @@ def flash_attn_varlen_func( deterministic, sm_margin, varlen_sort_batches, + head_swizzle ) @@ -667,6 +673,7 @@ def flash_attn_with_kvcache( sm_margin=0, # Can be tuned if some SMs are used for communication return_softmax_lse=False, varlen_sort_batches=False, + head_swizzle=False, ): """ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from @@ -795,6 +802,7 @@ def flash_attn_with_kvcache( pack_gqa=pack_gqa, sm_margin=sm_margin, varlen_sort_batches=varlen_sort_batches, + head_swizzle=head_swizzle, ) # return (out, softmax_lse) if return_softmax_lse else out return (out, softmax_lse, *rest) if return_softmax_lse else out diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 92f1b007e99..96e6e4d67f6 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -154,7 +154,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.num_splits_dynamic_ptr, params.num_m_blocks_ptr, params.varlen_batch_idx_ptr, - params.num_nheads_in_l2_ptr + params.head_swizzle ? 
params.num_nheads_in_l2_ptr : nullptr // can toggle for swizzle ablation testing
    };
    if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu
index eaeade17410..be3bd0de610 100644
--- a/hopper/flash_prepare_scheduler.cu
+++ b/hopper/flash_prepare_scheduler.cu
@@ -204,9 +204,11 @@ void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bo
     // int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k);
     int qhead_per_khead = cutlass::ceil_div(params.h, params.h_k);
     int num_warps = cutlass::ceil_div(params.b, 31);
-    int const size_l2 = 50 * 1024 * 1024; // 50 MB
+    // int const size_l2 = 50 * 1024 * 1024; // 50 MB
+    int const size_l2 = 8 * 1024 * 1024; // underestimate seems better in practice
     int const element_size = params.is_e4m3 ? 1 : 2;
     int const size_one_kvblock = blockN * (params.d + params.dv) * element_size;
+    // printf("block size = %d, element size = %d, headdim = %d, headdim_v = %d, size 1 kblock = %d.\n", blockN, element_size, params.d, params.dv, size_one_kvblock);
     int const max_kvblocks_in_l2 = size_l2 / size_one_kvblock;
     BOOL_SWITCH(sort_batches, Sort, [&] {
         NUM_WARP_SWITCH(num_warps, NumWarps, [&] {
diff --git a/hopper/setup.py b/hopper/setup.py
index 6d5d552b904..4b4b47e7587 100644
--- a/hopper/setup.py
+++ b/hopper/setup.py
@@ -65,9 +65,9 @@
 # DISABLE_BACKWARD = True
 # DISABLE_SPLIT = True
 # DISABLE_PAGEDKV = True
-DISABLE_APPENDKV = True
+# DISABLE_APPENDKV = True
 # DISABLE_LOCAL = True
-DISABLE_SOFTCAP = True
+# DISABLE_SOFTCAP = True
 # DISABLE_PACKGQA = True
 DISABLE_FP16 = True
 # DISABLE_FP8 = True
@@ -419,7 +419,8 @@ def nvcc_threads_args():
     # ptxas 12.8 gives the best perf currently
     # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8
     # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have. 
- if bare_metal_version != Version("12.8"): + if bare_metal_version != Version("12.8") and bare_metal_version != Version("12.9"): + print("Bare Metal Version is: ", bare_metal_version) download_and_copy( name="nvcc", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index aaabe78efa4..14ad659bbd4 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -42,9 +42,9 @@ # DISABLE_BACKWARD = True # DISABLE_SPLIT = True # DISABLE_PAGEDKV = True -DISABLE_APPENDKV = True +# DISABLE_APPENDKV = True # DISABLE_LOCAL = True -DISABLE_SOFTCAP = True +# DISABLE_SOFTCAP = True # DISABLE_PACKGQA = True DISABLE_FP16 = True # DISABLE_FP8 = True @@ -313,9 +313,9 @@ def test_flash_attn_output( # @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) -@pytest.mark.parametrize("add_unused_qkv", [False, True]) -# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("add_unused_qkv", [False, True]) +@pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) @@ -325,7 +325,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("d", COMPILED_HDIMS) @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize("varlen_sort_batches", [False, True]) -# @pytest.mark.parametrize("varlen_sort_batches", [True]) +# @pytest.mark.parametrize("varlen_sort_batches", [False]) +@pytest.mark.parametrize("head_swizzle", [False, True]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -348,29 +349,35 @@ def test_flash_attn_output( (1024, 1024), (1023, 1024), (1024, 1023), + (1024, 1024), (2048, 2048), + (4096, 4096), ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, varlen_sort_batches, + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, varlen_sort_batches, head_swizzle, ): device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) # batch_size = 40 # nheads = 16 - # batch_size = 9 if seqlen_q <= 2048 else 2 - batch_size = 32 + batch_size = 9 if seqlen_q <= 2048 else 2 + # batch_size = 32 nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) # batch_size = 2 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # nheads_kv = nheads + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] - attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, 
dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -479,7 +486,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): rtol = 2 if softcap == 0.0 else 3 pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - # pack_gqa_vals = [True] + # pack_gqa_vals = [False] num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] # num_splits_vals = [1] # print("cu_seqlens_q: ", cu_seqlens_q) @@ -487,6 +494,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # print("seqused_q: ", seqused_q) # print("seqused_k: ", seqused_k) for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") out_unpad = flash_attn_varlen_func( q_unpad, k_unpad, @@ -505,6 +513,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): attention_chunk=attention_chunk, softcap=softcap, varlen_sort_batches=varlen_sort_batches, + pack_gqa=pack_gqa, + num_splits=num_splits, + head_swizzle=head_swizzle, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -605,12 +616,12 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) # @pytest.mark.parametrize("new_kv", [False]) -# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) -@pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +@pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) # @pytest.mark.parametrize("causal,local", [(False, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) @@ -636,6 +647,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize("varlen_sort_batches", [False, True]) # @pytest.mark.parametrize("varlen_sort_batches", [True]) +@pytest.mark.parametrize("head_swizzle", [False, True]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -673,6 +685,7 @@ def test_flash_attn_kvcache( mha_type, dtype, varlen_sort_batches, + head_swizzle, ): if page_size is not None and seqlen_k % page_size != 0: pytest.skip() @@ -698,8 +711,8 @@ def test_flash_attn_kvcache( dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] - # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] - attention_chunk_vals = [0] # debug + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] + # attention_chunk_vals = [0] # debug # dv_vals = [d] # debug for dv, attention_chunk in itertools.product(dv_vals, 
attention_chunk_vals): print(f"{dv = }, {attention_chunk = }") @@ -932,6 +945,7 @@ def test_flash_attn_kvcache( num_splits=num_splits, return_softmax_lse=True, varlen_sort_batches=varlen_sort_batches, + head_swizzle=head_swizzle, ) if varlen_q: out = output_pad_fn(out) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 82e3b4cd480..31a493ab00f 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -652,17 +652,15 @@ class VarlenDynamicPersistentTileScheduler { int mh_block = next_tile_idx - group_start_tile; int block, bidh; if constexpr (LPT) { - // if constexpr(!Split) { - if (!Split || num_splits == 1) { - // NOTE: code for computing nheads_in_l2 directly left as demonstration + if ((!Split || num_splits == 1) && params.num_nheads_in_l2_ptr) { + // NOTE: code for computing nheads_in_l2 directly left as reference // int num_n_blocks = params.num_n_blocks_ptr ? params.num_n_blocks_ptr[bidb] : num_m_blocks; // auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; // int nheads_in_l2 = params.max_kvblocks_in_l2 < num_n_blocks // ? 1 : 1 << find_log2_floor(params.max_kvblocks_in_l2 / num_n_blocks); // if constexpr (!PackGQA) { nheads_in_l2 *= params.qhead_per_khead; } // nheads_in_l2 = min(nheads_in_l2, params.num_head); - - int nheads_in_l2 = params.num_nheads_in_l2_ptr ? params.num_nheads_in_l2_ptr[bidb] : params.num_head; + int nheads_in_l2 = params.num_nheads_in_l2_ptr[bidb]; int mh_in_l2 = nheads_in_l2 * num_m_blocks; int section_idx = mh_block / mh_in_l2; int l2_mod = mh_block - section_idx * mh_in_l2; @@ -672,27 +670,26 @@ class VarlenDynamicPersistentTileScheduler { block = l2_mod / nheads_in_this_section; int bidh_residual = l2_mod - block * nheads_in_this_section; bidh = section_idx * nheads_in_l2 + bidh_residual; - if constexpr(Split) { // remember to set num_splits = 1 in work tile uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(num_splits) << 24); bidh = reinterpret_cast(bidh_packed); } } else { - // bidh = mh_block / num_m_blocks; - // block = mh_block - bidh * num_m_blocks; + // NOTE: leave traverse heads first version for reference + // block = params.head_divmod.divmod(bidh, mh_block); // if constexpr (Split) { - // int bidh_actual = bidh / num_splits; - // int split_idx = bidh - bidh_actual * num_splits; - // uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // int split_idx = block / num_m_blocks; + // block = block - split_idx * num_m_blocks; + // uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); // bidh = reinterpret_cast(bidh_packed); // } - // assume enough space in L2 cache if split - block = params.head_divmod.divmod(bidh, mh_block); + bidh = mh_block / num_m_blocks; + block = mh_block - bidh * num_m_blocks; if constexpr (Split) { - int split_idx = block / num_m_blocks; - block = block - split_idx * num_m_blocks; - uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); bidh = reinterpret_cast(bidh_packed); } } From 044697f9158ddfee69ccafbfeb931f6f1cd3764e Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Mon, 18 Aug 2025 15:59:14 -0700 Subject: [PATCH 07/20] 
add correction term for sort when causal --- hopper/flash.h | 2 +- hopper/flash_api.cpp | 57 ++++++++++++++++-------------- hopper/flash_fwd_launch_template.h | 2 +- hopper/flash_prepare_scheduler.cu | 32 +++++++++++------ 4 files changed, 54 insertions(+), 39 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index 03f66bbc186..e3413cc032e 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -216,7 +216,7 @@ struct Flash_bwd_params : public Flash_fwd_params { template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl, bool sort_batches); +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); template diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 14bbc569590..3d66eaf8e9f 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -587,8 +587,8 @@ mha_fwd_get_scheduler_metadata( params.page_size = page_size.has_value() ? page_size.value() : 1; params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); - bool const use_dynamic_split = params.b <= PREPARE_VARLEN_MAX_BATCHES; - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + bool const use_prepare_varlen = params.b <= PREPARE_VARLEN_MAX_BATCHES; + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; @@ -607,23 +607,26 @@ mha_fwd_get_scheduler_metadata( bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; params.varlen_sort_batches = sort_batches; - if (scheduler_needs_semaphore || use_dynamic_split) { + params.head_swizzle = params.is_causal || params.is_local; + if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers - int num_prepare_batch_vectors = use_dynamic_split ? 2 : 0; - if(sort_batches) { num_prepare_batch_vectors += 2; } + int num_prepare_batch_vectors = use_prepare_varlen ? 1 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 2; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 1); int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); tile_count_semaphore = torch::empty( {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, opts.dtype(torch::kInt32)); - // {num_splits_dynamic, num_m_blocks, virtual_batch_indices, tile_count_semaphore} - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() : nullptr; - // params.num_n_blocks_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; - params.num_nheads_in_l2_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; - params.num_m_blocks_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; - params.varlen_batch_idx_ptr = use_dynamic_split && sort_batches ? 
tile_count_semaphore.data_ptr() + b_rounded * 3 : nullptr; + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; if (scheduler_needs_semaphore) { - if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + if (!use_prepare_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing params.tile_count_semaphore = tile_count_semaphore.data_ptr() + tile_count_semaphore_offset; } else { params.tile_count_semaphore = nullptr; @@ -636,7 +639,7 @@ mha_fwd_get_scheduler_metadata( int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); auto stream = at::cuda::getCurrentCUDAStream().stream(); - prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/, sort_batches); + prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } return tile_count_semaphore; @@ -958,9 +961,9 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel - bool const use_dynamic_split = is_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES; + bool const use_prepare_varlen = is_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES; // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; @@ -974,10 +977,13 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); params.varlen_sort_batches = sort_batches; - if (scheduler_needs_semaphore || use_dynamic_split) { + params.head_swizzle = head_swizzle; + if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers - int num_prepare_batch_vectors = use_dynamic_split ? 2 : 0; - if(sort_batches) { num_prepare_batch_vectors += 2; } + int num_prepare_batch_vectors = use_prepare_varlen ? 1 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 2; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 
3 : 1);
     int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;
     int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset;
     // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size);
@@ -992,19 +998,18 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
     } else {
         tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
     }
-    if (scheduler_needs_semaphore && !use_dynamic_split) {
+    if (scheduler_needs_semaphore && !use_prepare_varlen) {
         tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing
     }
-    // {num_splits_dynamic, num_m_blocks, virtual_batch_indices, tile_count_semaphore}
-    params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() : nullptr;
-    // params.num_n_blocks_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + b_rounded : nullptr;
-    params.num_nheads_in_l2_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + b_rounded : nullptr;
-    params.num_m_blocks_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr;
-    params.varlen_batch_idx_ptr = use_dynamic_split && sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 3 : nullptr;
+    // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2}
+    params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr;
+    params.num_m_blocks_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded : nullptr;
+    params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr;
+    // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr;
+    params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr;
     params.tile_count_semaphore = scheduler_needs_semaphore ? 
tile_count_semaphore.data_ptr() + tile_count_semaphore_offset : nullptr; params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later } - params.head_swizzle = head_swizzle; if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 96e6e4d67f6..ef205be4909 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -158,7 +158,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }; if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { - prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/, params.varlen_sort_batches); + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index be3bd0de610..95750af51d5 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -60,6 +60,7 @@ __global__ void prepare_varlen_num_blocks_kernel( // int* const num_n_blocks_ptr, int* const num_nheads_in_l2_ptr, bool enable_pdl, + bool is_causal, bool packgqa, int max_kvblocks_in_l2) { @@ -161,13 +162,16 @@ __global__ void prepare_varlen_num_blocks_kernel( if constexpr(Sort) { if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { - num_n_blocks = -1; // sort last + num_n_blocks = INT_MIN; // sort last + } else if (is_causal) { + // sort by shortest member to process + num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; } int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread - batch_coords[0] = make_int4(num_n_blocks, num_splits_dynamic, num_m_blocks, batch_idx); + batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); // if (threadIdx.x == 0) { - // printf("Unsorted: num_n_blocks = %d, num_splits = %d, num_m_blocks = %d, batch_idx = %d.\n", + // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); // } __syncthreads(); @@ -175,21 +179,26 @@ __global__ void prepare_varlen_num_blocks_kernel( BlockMergeSort(temp_storage).Sort(batch_coords, CustomMore()); // if (threadIdx.x == 0) { - // printf("Sorted: num_n_blocks = %d, num_splits = %d, num_m_blocks = %d, batch_idx = %d.\n", + // printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); - // } __syncthreads(); + // } __syncthreads(); + + if (is_causal) { + // reset value to num_n_blocks + batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); + } if (threadIdx.x < num_batch) { // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); - num_nheads_in_l2_ptr[threadIdx.x] = get_nheads_in_l2(max(batch_coords[0].x, 1)); - num_splits_dynamic_ptr[threadIdx.x] = batch_coords[0].y; - num_m_blocks_ptr[threadIdx.x] = batch_coords[0].z; + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[threadIdx.x] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } + num_m_blocks_ptr[threadIdx.x] = batch_coords[0].y; + num_splits_dynamic_ptr[threadIdx.x] = batch_coords[0].z; varlen_batch_idx_ptr[threadIdx.x] = 
batch_coords[0].w; } } else { if (batch_idx < num_batch && lane < kNumBatchPerWarp) { // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); - num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); } @@ -199,7 +208,7 @@ __global__ void prepare_varlen_num_blocks_kernel( } // flash void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, - int blockM, int blockN, bool enable_pdl, bool sort_batches) { + int blockM, int blockN, bool enable_pdl) { // Only support batch <= 992 (32 warps, each with 31 batches) // int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); int qhead_per_khead = cutlass::ceil_div(params.h, params.h_k); @@ -210,7 +219,7 @@ void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bo int const size_one_kvblock = blockN * (params.d + params.dv) * element_size; // printf("block size = %d, element size = %d, headdim = %d, headdim_v = %d, size 1 kblock = %d.\n", blockN, element_size, params.d, params.dv, size_one_kvblock); int const max_kvblocks_in_l2 = size_l2 / size_one_kvblock; - BOOL_SWITCH(sort_batches, Sort, [&] { + BOOL_SWITCH(params.varlen_sort_batches, Sort, [&] { NUM_WARP_SWITCH(num_warps, NumWarps, [&] { flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 32 * NumWarps /*block*/, 0, stream>>>( params.seqlen_q, params.seqlen_k, params.seqlen_knew, @@ -225,6 +234,7 @@ void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bo // params.num_n_blocks_ptr, params.num_nheads_in_l2_ptr, enable_pdl, + params.is_causal, packgqa, max_kvblocks_in_l2); }); From de166fd399dbc20edc2ec83f118499cea1085a55 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Mon, 18 Aug 2025 17:30:49 -0700 Subject: [PATCH 08/20] remove ablation options from frontend and clean up comments --- hopper/flash_api.cpp | 16 ++--- hopper/flash_attn_interface.py | 18 ----- hopper/flash_fwd_launch_template.h | 4 +- hopper/test_flash_attn.py | 101 ++++++++++++++--------------- 4 files changed, 56 insertions(+), 83 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 3d66eaf8e9f..8a1b2dd7337 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -527,7 +527,6 @@ mha_fwd_get_scheduler_metadata( bool has_softcap, int64_t num_splits, std::optional pack_gqa_, - bool sort_batches, int64_t sm_margin) { TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, @@ -606,7 +605,7 @@ mha_fwd_get_scheduler_metadata( at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - params.varlen_sort_batches = sort_batches; + params.varlen_sort_batches = !params.is_local; params.head_swizzle = params.is_causal || params.is_local; if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers @@ -688,9 +687,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql 
std::optional scheduler_metadata_, // (b + 1) int64_t num_splits, std::optional pack_gqa_, - bool sort_batches, - int64_t sm_margin, - bool head_swizzle + int64_t sm_margin ) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -976,8 +973,8 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql bool const scheduler_needs_semaphore = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - params.varlen_sort_batches = sort_batches; - params.head_swizzle = head_swizzle; + params.varlen_sort_batches = !params.is_local; + params.head_swizzle = params.is_causal || params.is_local; if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers int num_prepare_batch_vectors = use_prepare_varlen ? 1 : 0; @@ -1013,6 +1010,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "q_v is only supported for fp16 and bf16 data type"); TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); @@ -1688,8 +1686,7 @@ TORCH_LIBRARY(flash_attn_3, m) { "Tensor? scheduler_metadata = None," "int num_splits = 0," "bool? pack_gqa = None," - "bool sort_batches = False," - "int sm_margin = 0, bool head_swizzle = False) -> (Tensor(out!), Tensor, Tensor, Tensor)"); + "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); m.def("bwd(" "Tensor dout," "Tensor q," @@ -1742,7 +1739,6 @@ TORCH_LIBRARY(flash_attn_3, m) { "bool has_softcap = False," "int num_splits = 0," "bool? 
pack_gqa = None," - "bool sort_batches = False," "int sm_margin = 0) -> Tensor"); } diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 0c7cc828639..a6deb5cea49 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -51,8 +51,6 @@ def _flash_attn_forward( num_splits=1, pack_gqa=None, sm_margin=0, - varlen_sort_batches=False, - head_swizzle=False, ): q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v @@ -99,9 +97,7 @@ def _flash_attn_forward( scheduler_metadata, num_splits, pack_gqa, - varlen_sort_batches, sm_margin, - head_swizzle, ) return out, softmax_lse, *rest @@ -367,8 +363,6 @@ def forward( pack_gqa=None, deterministic=False, sm_margin=0, - varlen_sort_batches=False, - head_swizzle=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) @@ -398,8 +392,6 @@ def forward( num_splits=num_splits, pack_gqa=pack_gqa, sm_margin=sm_margin, - varlen_sort_batches=varlen_sort_batches, - head_swizzle=head_swizzle, ) # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -608,8 +600,6 @@ def flash_attn_varlen_func( pack_gqa=None, deterministic=False, sm_margin=0, - varlen_sort_batches=False, - head_swizzle=False, ): return FlashAttnVarlenFunc.apply( q, @@ -632,8 +622,6 @@ def flash_attn_varlen_func( pack_gqa, deterministic, sm_margin, - varlen_sort_batches, - head_swizzle ) @@ -672,8 +660,6 @@ def flash_attn_with_kvcache( pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication return_softmax_lse=False, - varlen_sort_batches=False, - head_swizzle=False, ): """ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from @@ -801,8 +787,6 @@ def flash_attn_with_kvcache( num_splits=num_splits, pack_gqa=pack_gqa, sm_margin=sm_margin, - varlen_sort_batches=varlen_sort_batches, - head_swizzle=head_swizzle, ) # return (out, softmax_lse) if return_softmax_lse else out return (out, softmax_lse, *rest) if return_softmax_lse else out @@ -825,7 +809,6 @@ def get_scheduler_metadata( num_splits=0, # Can be tuned for speed pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication - varlen_sort_batches=False, ): cache_seqlens = maybe_contiguous(cache_seqlens) if headdim_v is None: @@ -847,7 +830,6 @@ def get_scheduler_metadata( has_softcap, num_splits, pack_gqa, - varlen_sort_batches, sm_margin, ) return scheduler_metadata diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index ef205be4909..de94d1ec143 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -58,7 +58,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int NumProducerThreads = Arch >= 90 ? 
CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/, Is_causal /*LPT*/>, + flash::VarlenDynamicPersistentTileScheduler= 90 /*WarpSpecialized*/, Is_causal || Is_local /*LPT*/>, std::conditional_t, flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> @@ -154,7 +154,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.num_splits_dynamic_ptr, params.num_m_blocks_ptr, params.varlen_batch_idx_ptr, - params.head_swizzle ? params.num_nheads_in_l2_ptr : nullptr // can toggle for swizzle ablation testing + params.num_nheads_in_l2_ptr }; if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 14ad659bbd4..7ad02765264 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -39,22 +39,22 @@ DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" -# DISABLE_BACKWARD = True +DISABLE_BACKWARD = True # DISABLE_SPLIT = True # DISABLE_PAGEDKV = True -# DISABLE_APPENDKV = True +DISABLE_APPENDKV = True # DISABLE_LOCAL = True # DISABLE_SOFTCAP = True # DISABLE_PACKGQA = True -DISABLE_FP16 = True +# DISABLE_FP16 = True # DISABLE_FP8 = True # DISABLE_VARLEN = True -DISABLE_CLUSTER = True -DISABLE_HDIM64 = True -DISABLE_HDIM96 = True +# DISABLE_CLUSTER = True +# DISABLE_HDIM64 = True +# DISABLE_HDIM96 = True # DISABLE_HDIM128 = True -DISABLE_HDIM192 = True -DISABLE_HDIM256 = True +# DISABLE_HDIM192 = True +# DISABLE_HDIM256 = True DISABLE_SM8x = True COMPILED_HDIMS = ( @@ -68,21 +68,21 @@ # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) -@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [True]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) -@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) -# @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) -# @pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +@pytest.mark.parametrize("softcap", [0.0]) +# @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", 
[True]) # @pytest.mark.parametrize("V_colmajor", [False, True]) @pytest.mark.parametrize("V_colmajor", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -92,8 +92,8 @@ # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) -@pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", COMPILED_HDIMS) +@pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -125,6 +125,8 @@ def test_flash_attn_output( ): if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") device = "cuda" # set seed torch.random.manual_seed(0) @@ -139,8 +141,11 @@ def test_flash_attn_output( dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] + if has_qv: + dv_vals = [256, 512] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -211,6 +216,7 @@ def test_flash_attn_output( pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") out = flash_attn_func( q, k, @@ -299,23 +305,23 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) -# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) -# @pytest.mark.parametrize("add_unused_qkv", [False, True]) -@pytest.mark.parametrize("add_unused_qkv", 
[True]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) @@ -323,10 +329,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) # @pytest.mark.parametrize("d", COMPILED_HDIMS) -@pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("varlen_sort_batches", [False, True]) -# @pytest.mark.parametrize("varlen_sort_batches", [False]) -@pytest.mark.parametrize("head_swizzle", [False, True]) +@pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -355,8 +358,10 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, varlen_sort_batches, head_swizzle, + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, ): + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) @@ -374,8 +379,9 @@ def test_flash_attn_varlen_output( dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] - # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] - attention_chunk_vals = [0] + if has_qv: + dv_vals = [256, 512] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) @@ -512,10 +518,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, - varlen_sort_batches=varlen_sort_batches, pack_gqa=pack_gqa, num_splits=num_splits, - head_swizzle=head_swizzle, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -616,13 +620,13 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) # @pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) -# @pytest.mark.parametrize("causal,local", [(False, False)]) +# @pytest.mark.parametrize("causal,local", [(True, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] 
if not DISABLE_APPENDKV else [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) # @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) @@ -643,11 +647,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [64]) # @pytest.mark.parametrize("d", [192]) -@pytest.mark.parametrize("varlen_sort_batches", [False, True]) -# @pytest.mark.parametrize("varlen_sort_batches", [True]) -@pytest.mark.parametrize("head_swizzle", [False, True]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -684,8 +685,6 @@ def test_flash_attn_kvcache( new_kv, mha_type, dtype, - varlen_sort_batches, - head_swizzle, ): if page_size is not None and seqlen_k % page_size != 0: pytest.skip() @@ -712,8 +711,6 @@ def test_flash_attn_kvcache( if dtype == torch.float8_e4m3fn: dv_vals = [d] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] - # attention_chunk_vals = [0] # debug - # dv_vals = [d] # debug for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): print(f"{dv = }, {attention_chunk = }") has_qv = d == 64 and dv >= 256 @@ -908,7 +905,7 @@ def test_flash_attn_kvcache( cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, max_seqlen_k_new=seqlen_new, page_size=page_size, causal=causal, window_size=window_size, attention_chunk=attention_chunk, - num_splits=num_splits, varlen_sort_batches=varlen_sort_batches, + num_splits=num_splits, ) else: scheduler_metadata = None @@ -944,8 +941,6 @@ def test_flash_attn_kvcache( scheduler_metadata=scheduler_metadata, num_splits=num_splits, return_softmax_lse=True, - varlen_sort_batches=varlen_sort_batches, - head_swizzle=head_swizzle, ) if varlen_q: out = output_pad_fn(out) From e491a2a13994a44022bc2baa4cf491af72f2dea0 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Mon, 18 Aug 2025 17:31:39 -0700 Subject: [PATCH 09/20] add comments in prepare kernel --- hopper/flash_prepare_scheduler.cu | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 95750af51d5..9297b9aa8fc 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -15,14 +15,9 @@ namespace flash { -// needs (tensor size = batches): -// 1. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx -// 2. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] -// 3. 
num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] - -// Custom comparison functor for descending order +// Sort in descending order template -struct CustomMore +struct PrepareSortOp { __device__ bool operator()(const DataType &lhs, const DataType &rhs) { @@ -30,17 +25,15 @@ struct CustomMore } }; -// Specialization for int2 template <> -struct CustomMore { +struct PrepareSortOp { __device__ bool operator()(const int2& lhs, const int2& rhs) const { return lhs.x > rhs.x; } }; -// Specialization for int4 template <> -struct CustomMore { +struct PrepareSortOp { __device__ bool operator()(const int4& lhs, const int4& rhs) const { return lhs.x > rhs.x; } @@ -176,7 +169,7 @@ __global__ void prepare_varlen_num_blocks_kernel( // } __syncthreads(); // Sort batches by num_n_blocks in descending order - BlockMergeSort(temp_storage).Sort(batch_coords, CustomMore()); + BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); // if (threadIdx.x == 0) { // printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", @@ -188,6 +181,12 @@ __global__ void prepare_varlen_num_blocks_kernel( batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); } + // When sorting, we re-index some metadata by 'virtual batch index' + // and also store the vbidx -> bidx mapping. + // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx] + // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] + // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] + // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx if (threadIdx.x < num_batch) { // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[threadIdx.x] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } From ecfc7c349a3735d494c412cbef600bd9219681d8 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Mon, 18 Aug 2025 17:34:32 -0700 Subject: [PATCH 10/20] remove debug code and scripts --- hopper/benchmark_prefill.py | 191 ----------------------- hopper/benchmark_prefill_decode.py | 219 --------------------------- hopper/prepare_varlen_bench.py | 234 ----------------------------- hopper/setup.py | 18 --- hopper/test_flash_attn.py | 18 --- 5 files changed, 680 deletions(-) delete mode 100644 hopper/benchmark_prefill.py delete mode 100644 hopper/benchmark_prefill_decode.py delete mode 100644 hopper/prepare_varlen_bench.py diff --git a/hopper/benchmark_prefill.py b/hopper/benchmark_prefill.py deleted file mode 100644 index 096183d932f..00000000000 --- a/hopper/benchmark_prefill.py +++ /dev/null @@ -1,191 +0,0 @@ -from collections import namedtuple -from functools import partial -import math -import os -from typing import NamedTuple -import torch -import torch.nn as nn -import torch.nn.functional as F - -import time - -Timing = NamedTuple('timing', [('mean', float)]) - -from einops import rearrange, repeat - -# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func -# from flash_attn_interface import flash_attn_func as flash_attn_func_v3 -from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3 -from flash_attn_interface 
import flash_attn_varlen_func as flash_attn_varlen_func_v3 - -from triton.testing import do_bench - -cudnn = None -triton_attention = None - -DISABLE_BACKWARD = True - -def time_fwd(func, *args, repeats=10, verbose=True, desc="", **kwargs): - # Warmup - # for _ in range(5): - # func(*args, **kwargs) - # time.sleep(1) - # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1] - # s = torch.cuda.Stream() - # s.wait_stream(torch.cuda.current_stream()) - # with torch.cuda.stream(s): - # for _ in range(2): - # out = func(*args, **kwargs) - # torch.cuda.current_stream().wait_stream(s) - # graph = torch.cuda.CUDAGraph() - # with torch.cuda.graph(graph): - # out = func(*args, **kwargs) - # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) - # # return time_f[1].mean - # return time_f[1] - return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) - -def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): - if causal: - avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 - else: - if window_size == (-1, -1): - avg_seqlen = seqlen_k - else: - row_idx = torch.arange(seqlen_q, device='cuda') - col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) - col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) - avg_seqlen = (col_right - col_left + 1).float().mean().item() - return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) - -torch.manual_seed(0) -repeats = 10 -dropout_p = 0.0 -causal = False -dtype = torch.bfloat16 -# dtype = torch.float8_e4m3fn -dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype -device = 'cuda' -verbose = True -varlen = True -page_size = None -softcap = 0.0 -V_colmajor = False -deterministic = False - -# decode_batches = 128 -# prefill_batches = 1 -# batch_size = decode_batches + prefill_batches -# decode_seqlen_k = 8192 if decode_batches > 0 else 0 -# prefill_seqlen_k = 2048 if prefill_batches > 0 else 0 -# # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative -# seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) -# decode_seqlen_q = 1 if decode_batches > 0 else 0 -# prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 -# seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches -# max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) - -time_f = {} -time_b = {} - - -for headdim in [128]: - for head_swizzle in [True, False]: - print("\nHead Swizzle is ", head_swizzle) - for prefill_batches in [16, 8, 4, 2, 1]: - # for prefill_batches in [4, 4, 4]: - batch_size = prefill_batches - # prefill_seqlen_k = (8192 * 8) // prefill_batches - prefill_seqlen_k = (8192 * 4) // prefill_batches - seqlen_k = prefill_seqlen_k - prefill_seqlen_q = prefill_seqlen_k - seqlen_q = prefill_seqlen_q * prefill_batches - max_seqlen_q = prefill_seqlen_q - - # tp_degree=4 - # nheads = 64//tp_degree - # nheads_kv = 8//tp_degree - nheads = 8 - nheads_kv = 8 - headdim_v = headdim - has_qv = False - - # window_size = (128, 0) - window_size = (-1, -1) - # window_size = (seqlen // 2 - 1, 0) - - # print("Window size: ", window_size) - # print(f"Num query heads = {nheads}, kv heads = {nheads_kv}") - # print("Head dim: ", headdim) - # print("Batch size: ", batch_size) - # print("Prefill seqlen k: ", prefill_seqlen_k) - # print("Decode seqlen k: ", 
decode_seqlen_k) - # print("Seqlen k (max): ", seqlen_k) - # print("Prefill seqlen q: ", prefill_seqlen_q) - # print("Decode seqlen q: ", decode_seqlen_q) - # print("Seqlen q (total): ", seqlen_q) - - num_splits = 1 - pack_gqa = False - - # print(f"Num splits = {num_splits}, Pack GQA = {pack_gqa}") - - leftpad_k = None - # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) - q = torch.randn(seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) - k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) - v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) - q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] - v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() - v_fa3 = v if not V_colmajor else v_colmajor - qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None - - cu_seqlens_q = torch.arange(prefill_batches + 1, device=device, dtype=torch.int32) * prefill_seqlen_q - cache_seqlens = torch.ones(prefill_batches, dtype=torch.int32, device=device) * prefill_seqlen_k - - # print("q: ", q.shape) - # print("k: ", k.shape) - # print("v: ", v.shape) - # print("cu seqlens q: ", cu_seqlens_q.shape) - # print("cache seqlens: ", cache_seqlens.shape) - # print("cu seqlens q vals: ", cu_seqlens_q) - # print("cache seqlens vals: ", cache_seqlens) - - page_table = None - - # for causal in [False, True]: - for causal in [True]: - print(f"\n### {headdim = }, {nheads = }, {nheads_kv = }, {causal = }, {prefill_seqlen_k = }, {prefill_batches = }, {num_splits = }, {head_swizzle = } ###") - # nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen_k, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) - - prefill_nFLOPS = flops(prefill_batches, nheads, prefill_seqlen_q, prefill_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) - nFLOPS = prefill_nFLOPS - - bytes_kv = (prefill_seqlen_k * prefill_batches) * (nheads_kv * headdim * 4) - bytes_qo = (prefill_seqlen_q * prefill_batches) * (nheads * headdim * 4) # don't count split partials - bytes = bytes_kv + bytes_qo - print(f'{nFLOPS * 1e-9:.1f} GFLOPs, {bytes * 1e-9: .2f} GB, {nFLOPS/bytes:.1f} AI') - - time.sleep(1) - m1 = time_fwd(flash_attn_func_v3, - q, - k, - v, - cache_seqlens=cache_seqlens, - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - causal=causal, - window_size=window_size, - softcap=softcap, - num_splits=num_splits, - pack_gqa=pack_gqa, - head_swizzle=head_swizzle, - repeats=repeats, verbose=verbose, desc='Fav3') - - time_f[(causal, headdim, batch_size, seqlen_k), "Flash3"] = m1.mean - - print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS, {bytes/ m1.mean * 1e-9: .2f} GB/s') - time.sleep(2) - diff --git a/hopper/benchmark_prefill_decode.py b/hopper/benchmark_prefill_decode.py deleted file mode 100644 index d5345bb1534..00000000000 --- a/hopper/benchmark_prefill_decode.py +++ /dev/null @@ -1,219 +0,0 @@ -from collections import namedtuple -from functools import partial -import math -import os -from typing import NamedTuple -import torch -import torch.nn as nn -import torch.nn.functional as F - -import time - -Timing = NamedTuple('timing', [('mean', float)]) - -from einops import rearrange, repeat - -# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, 
benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func -# from flash_attn_interface import flash_attn_func as flash_attn_func_v3 -from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3 -from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 - -from triton.testing import do_bench - -cudnn = None -triton_attention = None - -DISABLE_BACKWARD = True - -def time_fwd(func, *args, repeats=10, verbose=True, desc="", **kwargs): - # Warmup - # for _ in range(5): - # func(*args, **kwargs) - # time.sleep(1) - # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1] - # s = torch.cuda.Stream() - # s.wait_stream(torch.cuda.current_stream()) - # with torch.cuda.stream(s): - # for _ in range(2): - # out = func(*args, **kwargs) - # torch.cuda.current_stream().wait_stream(s) - # graph = torch.cuda.CUDAGraph() - # with torch.cuda.graph(graph): - # out = func(*args, **kwargs) - # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) - # # return time_f[1].mean - # return time_f[1] - return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) - -def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): - if causal: - avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 - else: - if window_size == (-1, -1): - avg_seqlen = seqlen_k - else: - row_idx = torch.arange(seqlen_q, device='cuda') - col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) - col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) - avg_seqlen = (col_right - col_left + 1).float().mean().item() - return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) - -torch.manual_seed(0) -repeats = 10 -dropout_p = 0.0 -causal = False -dtype = torch.bfloat16 -# dtype = torch.float8_e4m3fn -dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype -device = 'cuda' -verbose = True -varlen = True -page_size = None -softcap = 0.0 -V_colmajor = False -deterministic = False - -# decode_batches = 128 -# prefill_batches = 1 -# batch_size = decode_batches + prefill_batches -# decode_seqlen_k = 8192 if decode_batches > 0 else 0 -# prefill_seqlen_k = 2048 if prefill_batches > 0 else 0 -# # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative -# seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) -# decode_seqlen_q = 1 if decode_batches > 0 else 0 -# prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 -# seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches -# max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) - -time_f = {} -time_b = {} - -# prefill_first_vals = [False, True] -prefill_first_vals = [False] - -for headdim in [128]: - for prefill_batches in [8]: - # for decode_batches in range(32, 128 + 32, 8): - for decode_batches in [1]: - for head_swizzle in [False, True]: - batch_size = decode_batches + prefill_batches - decode_seqlen_k = 8192 if decode_batches > 0 else 0 - prefill_seqlen_k = 1024 * 8 if prefill_batches > 0 else 0 - # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # 
for cumulative - seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) - decode_seqlen_q = 1 if decode_batches > 0 else 0 - prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 - seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches - max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) - - # tp_degree=8 - # nheads = 64//tp_degree - # nheads_kv = 8//tp_degree - nheads = 64 - nheads_kv = 8 - headdim_v = headdim - has_qv = False - - # window_size = (128, 0) - window_size = (-1, -1) - # window_size = (seqlen // 2 - 1, 0) - - # print("Window size: ", window_size) - # print(f"Num query heads = {nheads}, kv heads = {nheads_kv}") - # print("Head dim: ", headdim) - # print("Batch size: ", batch_size) - # print("Prefill seqlen k: ", prefill_seqlen_k) - # print("Decode seqlen k: ", decode_seqlen_k) - # print("Seqlen k (max): ", seqlen_k) - # print("Prefill seqlen q: ", prefill_seqlen_q) - # print("Decode seqlen q: ", decode_seqlen_q) - # print("Seqlen q (total): ", seqlen_q) - - num_splits = 1 - pack_gqa = None - - # print(f"Num splits = {num_splits}, Pack GQA = {pack_gqa}") - - if prefill_batches == 0: - this_prefill_first_vals = [False] - else: - this_prefill_first_vals = prefill_first_vals - - for prefill_first in this_prefill_first_vals: - leftpad_k = None - # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) - q = torch.randn(seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) - k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) - v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) - q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] - v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() - v_fa3 = v if not V_colmajor else v_colmajor - qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None - - seqlen_q_decode_offset = decode_seqlen_q * decode_batches - seqlen_q_prefill_offset = prefill_seqlen_q * prefill_batches - seqlen_k_decode_offset = decode_seqlen_k * decode_batches - - if prefill_first: - cu_seqlens_q_prefill = torch.arange(prefill_batches, device=device, dtype=torch.int32) * prefill_seqlen_q - cu_seqlens_q_decode = torch.arange(decode_batches + 1, device=device, dtype=torch.int32) * decode_seqlen_q + seqlen_q_prefill_offset - cu_seqlens_q = torch.cat((cu_seqlens_q_prefill, cu_seqlens_q_decode), dim=0) - else: - cu_seqlens_q_decode = torch.arange(decode_batches, device=device, dtype=torch.int32) * decode_seqlen_q - cu_seqlens_q_prefill = torch.arange(prefill_batches + 1, device=device, dtype=torch.int32) * prefill_seqlen_q + seqlen_q_decode_offset - cu_seqlens_q = torch.cat((cu_seqlens_q_decode, cu_seqlens_q_prefill), dim=0) - - cache_seqlens_decode = torch.ones(decode_batches, dtype=torch.int32, device=device) * decode_seqlen_k - cache_seqlens_prefill = torch.ones(prefill_batches, dtype=torch.int32, device=device) * prefill_seqlen_k - - if prefill_first: - cache_seqlens = torch.cat((cache_seqlens_prefill, cache_seqlens_decode), dim=0) - else: - cache_seqlens = torch.cat((cache_seqlens_decode, cache_seqlens_prefill), dim=0) - - - # print("q: ", q.shape) - # print("k: ", k.shape) - # print("v: ", v.shape) - # print("cu seqlens q: ", cu_seqlens_q.shape) - # print("cache seqlens: ", cache_seqlens.shape) - # print("cu seqlens q vals: ", cu_seqlens_q) - # print("cache seqlens vals: 
", cache_seqlens) - - page_table = None - - # for causal in [False, True]: - for causal in [True]: - print(f"\n### {headdim = }, {nheads = }, {nheads_kv = }, {causal = }, {prefill_seqlen_k = }, {decode_seqlen_k = }, {num_splits = }, {prefill_first = }, {decode_batches = }, {prefill_batches = }, {head_swizzle = } ###") - # nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen_k, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) - decode_nFLOPS = flops(decode_batches, nheads, decode_seqlen_q, decode_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) - prefill_nFLOPS = flops(prefill_batches, nheads, prefill_seqlen_q, prefill_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) - nFLOPS = decode_nFLOPS + prefill_nFLOPS - - bytes_kv = (decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches) * (nheads_kv * headdim * 4) - bytes_qo = (decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches) * (nheads * headdim * 4) # don't count split partials - bytes = bytes_kv + bytes_qo - print(f'{nFLOPS * 1e-9:.1f} GFLOPs, {bytes * 1e-9: .2f} GB, {nFLOPS/bytes:.1f} AI') - - time.sleep(1) - m1 = time_fwd(flash_attn_func_v3, - q, - k, - v, - cache_seqlens=cache_seqlens, - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - causal=causal, - window_size=window_size, - softcap=softcap, - num_splits=num_splits, - pack_gqa=pack_gqa, - head_swizzle=head_swizzle, - repeats=repeats, verbose=verbose, desc='Fav3') - - time_f[(causal, headdim, batch_size, seqlen_k), "Flash3"] = m1.mean - - print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS, {bytes/ m1.mean * 1e-9: .2f} GB/s') - diff --git a/hopper/prepare_varlen_bench.py b/hopper/prepare_varlen_bench.py deleted file mode 100644 index c0e9ea9ed14..00000000000 --- a/hopper/prepare_varlen_bench.py +++ /dev/null @@ -1,234 +0,0 @@ -from collections import namedtuple -from functools import partial -import math -import os -from typing import NamedTuple -import torch -import torch.nn as nn -import torch.nn.functional as F - -import time - -Timing = NamedTuple('timing', [('mean', float)]) - -from einops import rearrange, repeat - -# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func -# from flash_attn_interface import flash_attn_func as flash_attn_func_v3 -from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3, get_scheduler_metadata -from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 - -from triton.testing import do_bench - -cudnn = None -triton_attention = None - -DISABLE_BACKWARD = True - -def time_fwd(func, *args, repeats=10, verbose=True, desc="", **kwargs): - # Warmup - # for _ in range(5): - # func(*args, **kwargs) - # time.sleep(1) - # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1] - # s = torch.cuda.Stream() - # s.wait_stream(torch.cuda.current_stream()) - # with torch.cuda.stream(s): - # for _ in range(2): - # out = func(*args, **kwargs) - # torch.cuda.current_stream().wait_stream(s) - # graph = torch.cuda.CUDAGraph() - # with torch.cuda.graph(graph): - # out = func(*args, **kwargs) - # time_f = 
benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) - # # return time_f[1].mean - # return time_f[1] - return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) - -def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): - if causal: - avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 - else: - if window_size == (-1, -1): - avg_seqlen = seqlen_k - else: - row_idx = torch.arange(seqlen_q, device='cuda') - col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) - col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) - avg_seqlen = (col_right - col_left + 1).float().mean().item() - return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) - -torch.manual_seed(0) -repeats = 10 -dropout_p = 0.0 -causal = False -dtype = torch.bfloat16 -# dtype = torch.float8_e4m3fn -dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype -device = 'cuda' -verbose = True -varlen = True -page_size = None -softcap = 0.0 -V_colmajor = False -deterministic = False - -# decode_batches = 128 -# prefill_batches = 1 -# batch_size = decode_batches + prefill_batches -# decode_seqlen_k = 8192 if decode_batches > 0 else 0 -# prefill_seqlen_k = 2048 if prefill_batches > 0 else 0 -# # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative -# seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) -# decode_seqlen_q = 1 if decode_batches > 0 else 0 -# prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 -# seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches -# max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) - -time_f = {} -time_b = {} - -prefill_first_vals = [False, True] -# prefill_first_vals = [False] - -for headdim in [128]: - for prefill_batches in [0]: - # for decode_batches in range(32, 128 + 32, 8): - for decode_batches in [128]: - for sort_batches in [False, True]: - - batch_size = decode_batches + prefill_batches - decode_seqlen_k = 8192 if decode_batches > 0 else 0 - prefill_seqlen_k = 1024 if prefill_batches > 0 else 0 - # seqlen_k = decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches # for cumulative - seqlen_k = max(decode_seqlen_k, prefill_seqlen_k) - decode_seqlen_q = 1 if decode_batches > 0 else 0 - prefill_seqlen_q = prefill_seqlen_k if prefill_batches > 0 else 0 - seqlen_q = decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches - max_seqlen_q = max(decode_seqlen_q, prefill_seqlen_q) - - tp_degree=1 - nheads = 64//tp_degree - nheads_kv = 8//tp_degree - # nheads = 1 - # nheads_kv = 1 - headdim_v = headdim - has_qv = False - - # window_size = (128, 0) - window_size = (-1, -1) - # window_size = (seqlen // 2 - 1, 0) - - # print("Window size: ", window_size) - # print(f"Num query heads = {nheads}, kv heads = {nheads_kv}") - # print("Head dim: ", headdim) - # print("Batch size: ", batch_size) - # print("Prefill seqlen k: ", prefill_seqlen_k) - # print("Decode seqlen k: ", decode_seqlen_k) - # print("Seqlen k (max): ", seqlen_k) - # print("Prefill seqlen q: ", prefill_seqlen_q) - # print("Decode seqlen q: ", decode_seqlen_q) - # print("Seqlen q (total): ", seqlen_q) - - num_splits = 1 - pack_gqa = None - - # print(f"Num splits = {num_splits}, Pack GQA = {pack_gqa}") - - if prefill_batches == 0: - this_prefill_first_vals = [False] - else: - this_prefill_first_vals = 
prefill_first_vals - - for prefill_first in this_prefill_first_vals: - leftpad_k = None - # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) - q = torch.randn(seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) - k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) - v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) - q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] - v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() - v_fa3 = v if not V_colmajor else v_colmajor - qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None - - seqlen_q_decode_offset = decode_seqlen_q * decode_batches - seqlen_q_prefill_offset = prefill_seqlen_q * prefill_batches - seqlen_k_decode_offset = decode_seqlen_k * decode_batches - - if prefill_first: - cu_seqlens_q_prefill = torch.arange(prefill_batches, device=device, dtype=torch.int32) * prefill_seqlen_q - cu_seqlens_q_decode = torch.arange(decode_batches + 1, device=device, dtype=torch.int32) * decode_seqlen_q + seqlen_q_prefill_offset - cu_seqlens_q = torch.cat((cu_seqlens_q_prefill, cu_seqlens_q_decode), dim=0) - else: - cu_seqlens_q_decode = torch.arange(decode_batches, device=device, dtype=torch.int32) * decode_seqlen_q - cu_seqlens_q_prefill = torch.arange(prefill_batches + 1, device=device, dtype=torch.int32) * prefill_seqlen_q + seqlen_q_decode_offset - cu_seqlens_q = torch.cat((cu_seqlens_q_decode, cu_seqlens_q_prefill), dim=0) - - cache_seqlens_decode = torch.ones(decode_batches, dtype=torch.int32, device=device) * decode_seqlen_k - cache_seqlens_prefill = torch.ones(prefill_batches, dtype=torch.int32, device=device) * prefill_seqlen_k - - if prefill_first: - cache_seqlens = torch.cat((cache_seqlens_prefill, cache_seqlens_decode), dim=0) - else: - cache_seqlens = torch.cat((cache_seqlens_decode, cache_seqlens_prefill), dim=0) - - - # print("q: ", q.shape) - # print("k: ", k.shape) - # print("v: ", v.shape) - # print("cu seqlens q: ", cu_seqlens_q.shape) - # print("cache seqlens: ", cache_seqlens.shape) - # print("cu seqlens q vals: ", cu_seqlens_q) - # print("cache seqlens vals: ", cache_seqlens) - - page_table = None - - # for causal in [False, True]: - for causal in [True]: - print(f"\n### {headdim = }, {nheads = }, {nheads_kv = }, {causal = }, {prefill_seqlen_k = }, {decode_seqlen_k = }, {num_splits = }, {prefill_first = }, {decode_batches = }, {prefill_batches = } ###") - # nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen_k, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) - decode_nFLOPS = flops(decode_batches, nheads, decode_seqlen_q, decode_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) - prefill_nFLOPS = flops(prefill_batches, nheads, prefill_seqlen_q, prefill_seqlen_k, headdim, headdim_v, causal=causal, window_size=window_size) - nFLOPS = decode_nFLOPS + prefill_nFLOPS - - bytes_kv = (decode_seqlen_k * decode_batches + prefill_seqlen_k * prefill_batches) * (nheads_kv * headdim * 4) - bytes_qo = (decode_seqlen_q * decode_batches + prefill_seqlen_q * prefill_batches) * (nheads * headdim * 4) # don't count split partials - bytes = bytes_kv + bytes_qo - # print(f'{nFLOPS * 1e-9:.1f} GFLOPs, {bytes * 1e-9: .2f} GB, {nFLOPS/bytes:.1f} AI') - - # time.sleep(1) - # m1 = time_fwd(flash_attn_func_v3, - 
# q, - # k, - # v, - # cache_seqlens=cache_seqlens, - # cu_seqlens_q=cu_seqlens_q, - # max_seqlen_q=max_seqlen_q, - # causal=causal, - # window_size=window_size, - # softcap=softcap, - # num_splits=num_splits, - # pack_gqa=pack_gqa, - # repeats=repeats, verbose=verbose, desc='Fav3') - - # time_f[(causal, headdim, batch_size, seqlen_k), "Flash3"] = m1.mean - - # print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS, {bytes/ m1.mean * 1e-9: .2f} GB/s') - - scheduler_metadata = get_scheduler_metadata( - batch_size, max_seqlen_q, seqlen_k, nheads, nheads_kv, headdim, - cache_seqlens, q.dtype, headdim_v=headdim, cu_seqlens_q=cu_seqlens_q, - causal=causal, num_splits=num_splits, varlen_sort_batches=sort_batches, ) - - # m1 = time_fwd(get_scheduler_metadata, - # batch_size, max_seqlen_q, seqlen_k, nheads, nheads_kv, headdim, - # cache_seqlens, q.dtype, headdim_v=headdim, cu_seqlens_q=cu_seqlens_q, - # causal=causal, num_splits=num_splits, sort_batches=sort_batches, - # repeats=repeats, verbose=verbose, desc='Prepare' - # ) - - # time_f[(causal, headdim, batch_size, seqlen_k), "Prepare"] = m1.mean - - # print(f'Prepare: {m1.mean * 1e3:.3f}ms') diff --git a/hopper/setup.py b/hopper/setup.py index 4b4b47e7587..8142107edc7 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -62,24 +62,6 @@ DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" DISABLE_SM8x = os.getenv("FLASH_ATTENTION_DISABLE_SM80", "FALSE") == "TRUE" -# DISABLE_BACKWARD = True -# DISABLE_SPLIT = True -# DISABLE_PAGEDKV = True -# DISABLE_APPENDKV = True -# DISABLE_LOCAL = True -# DISABLE_SOFTCAP = True -# DISABLE_PACKGQA = True -DISABLE_FP16 = True -# DISABLE_FP8 = True -# DISABLE_VARLEN = True -DISABLE_CLUSTER = True -DISABLE_HDIM64 = True -DISABLE_HDIM96 = True -# DISABLE_HDIM128 = True -DISABLE_HDIM192 = True -DISABLE_HDIM256 = True -DISABLE_SM8x = True - ENABLE_VCOLMAJOR = os.getenv("FLASH_ATTENTION_ENABLE_VCOLMAJOR", "FALSE") == "TRUE" diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 7ad02765264..76608791543 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -39,24 +39,6 @@ DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" -DISABLE_BACKWARD = True -# DISABLE_SPLIT = True -# DISABLE_PAGEDKV = True -DISABLE_APPENDKV = True -# DISABLE_LOCAL = True -# DISABLE_SOFTCAP = True -# DISABLE_PACKGQA = True -# DISABLE_FP16 = True -# DISABLE_FP8 = True -# DISABLE_VARLEN = True -# DISABLE_CLUSTER = True -# DISABLE_HDIM64 = True -# DISABLE_HDIM96 = True -# DISABLE_HDIM128 = True -# DISABLE_HDIM192 = True -# DISABLE_HDIM256 = True -DISABLE_SM8x = True - COMPILED_HDIMS = ( [] + ([64] if not DISABLE_HDIM64 else []) From 495487d142f5342ec265494c7e9e2a1dc2add084 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Mon, 18 Aug 2025 17:44:05 -0700 Subject: [PATCH 11/20] put back defaults in tests --- hopper/flash_prepare_scheduler.cu | 2 +- hopper/test_flash_attn.py | 34 +++++++++++++++---------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 9297b9aa8fc..06f2599ac3b 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -47,7 +47,7 @@ __global__ void prepare_varlen_num_blocks_kernel( int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, cutlass::FastDivmod blockm_divmod, 
cutlass::FastDivmod blockn_divmod, int* const tile_count_semaphore, - int* const num_m_blocks_ptr, // virtual_batch_idx -> num_m_blocks[batch_idx] + int* const num_m_blocks_ptr, int* const num_splits_dynamic_ptr, int* const varlen_batch_idx_ptr, // int* const num_n_blocks_ptr, diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 76608791543..8ea2dee2897 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -50,21 +50,21 @@ # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) -# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("has_qv", [False, True]) # @pytest.mark.parametrize("has_qv", [True]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) -@pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) -@pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("V_colmajor", [False, True]) @pytest.mark.parametrize("V_colmajor", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -74,8 +74,8 @@ # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) -# @pytest.mark.parametrize("d", COMPILED_HDIMS) -@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -296,8 +296,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) -@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) @@ -310,8 +310,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # 
@pytest.mark.parametrize("d", [64, 96, 128]) -# @pytest.mark.parametrize("d", COMPILED_HDIMS) -@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -629,7 +629,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("d", [128]) # @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", From aadca4ede71fb8af3ee5a41cfd479b037f2442dc Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Mon, 18 Aug 2025 18:02:29 -0700 Subject: [PATCH 12/20] remove excess Nones returned in python interface for varlen --- hopper/flash_attn_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index a6deb5cea49..cae67446668 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -437,7 +437,7 @@ def backward(ctx, dout, *args): dq = dq[..., : q.shape[-1]] # We could have padded the head dimension dk = dk[..., : k.shape[-1]] dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, def flash_attn_qkvpacked_func( From 16d23554cfca400b340adae598e0fcb3384627a1 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Mon, 18 Aug 2025 18:04:47 -0700 Subject: [PATCH 13/20] revert opinionated change to setup.py on cuda version 12.9 --- hopper/flash_attn_interface.py | 2 +- hopper/setup.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index cae67446668..a2eb9594896 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -437,7 +437,7 @@ def backward(ctx, dout, *args): dq = dq[..., : q.shape[-1]] # We could have padded the head dimension dk = dk[..., : k.shape[-1]] dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( diff --git a/hopper/setup.py b/hopper/setup.py index 8142107edc7..c15c438f56c 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -401,8 +401,7 @@ def nvcc_threads_args(): # ptxas 12.8 gives the best perf currently # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8 # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have. 
- if bare_metal_version != Version("12.8") and bare_metal_version != Version("12.9"): - print("Bare Metal Version is: ", bare_metal_version) + if bare_metal_version != Version("12.8"): download_and_copy( name="nvcc", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", From 3349e11aa968abee5438464fd78dc36e21ec77ec Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Mon, 18 Aug 2025 19:06:56 -0700 Subject: [PATCH 14/20] force inline sort op and make east const --- hopper/flash_prepare_scheduler.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 06f2599ac3b..37e56b7a334 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -16,10 +16,10 @@ namespace flash { // Sort in descending order -template +template struct PrepareSortOp { - __device__ bool operator()(const DataType &lhs, const DataType &rhs) + __device__ __forceinline__ bool operator()(T const & lhs, T const & rhs) { return lhs > rhs; } @@ -27,14 +27,14 @@ struct PrepareSortOp template <> struct PrepareSortOp { - __device__ bool operator()(const int2& lhs, const int2& rhs) const { + __device__ __forceinline__ bool operator()(int2 const & lhs, int2 const & rhs) const { return lhs.x > rhs.x; } }; template <> struct PrepareSortOp { - __device__ bool operator()(const int4& lhs, const int4& rhs) const { + __device__ __forceinline__ bool operator()(int4 const & lhs, int4 const & rhs) const { return lhs.x > rhs.x; } }; From 4d9f78c49e79eabff072464c2d9a3dc314fb8be0 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Tue, 19 Aug 2025 16:34:28 -0700 Subject: [PATCH 15/20] more templating in varlen scheduler to cure some register spilling --- hopper/flash.h | 1 + hopper/flash_api.cpp | 10 ++++++---- hopper/flash_fwd_launch_template.h | 32 ++++++++++++++++-------------- hopper/tile_scheduler.hpp | 32 ++++++++++++++++++------------ hopper/tile_size.h | 7 ++++--- 5 files changed, 47 insertions(+), 35 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index e3413cc032e..d68a317ffd9 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -161,6 +161,7 @@ struct Flash_fwd_params : public Qkv_params { bool varlen_sort_batches; int tile_count_semaphore_offset; bool head_swizzle; + bool use_prepare_varlen; int arch; int num_sm; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 8a1b2dd7337..3e853d6bfaa 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -587,6 +587,7 @@ mha_fwd_get_scheduler_metadata( params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); bool const use_prepare_varlen = params.b <= PREPARE_VARLEN_MAX_BATCHES; + params.use_prepare_varlen = use_prepare_varlen; params.num_splits_dynamic_ptr = !use_prepare_varlen ? 
nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); @@ -605,8 +606,8 @@ mha_fwd_get_scheduler_metadata( at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - params.varlen_sort_batches = !params.is_local; - params.head_swizzle = params.is_causal || params.is_local; + params.varlen_sort_batches = !params.is_local; // Use for Sort value in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use for LPT value in scheduler template if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers int num_prepare_batch_vectors = use_prepare_varlen ? 1 : 0; @@ -959,6 +960,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel bool const use_prepare_varlen = is_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES; + params.use_prepare_varlen = use_prepare_varlen; // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); @@ -973,8 +975,8 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql bool const scheduler_needs_semaphore = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - params.varlen_sort_batches = !params.is_local; - params.head_swizzle = params.is_causal || params.is_local; + params.varlen_sort_batches = !params.is_local; // Use for Sort value in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use for LPT value in scheduler template if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers int num_prepare_batch_vectors = use_prepare_varlen ? 1 : 0; diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index de94d1ec143..0ad25cf6cf4 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -26,7 +26,7 @@ using namespace cute; template + bool PackGQA, bool Split, bool V_colmajor, bool PrepareVarlen> void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); @@ -58,7 +58,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int NumProducerThreads = Arch >= 90 ? 
CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/, Is_causal || Is_local /*LPT*/>, + flash::VarlenDynamicPersistentTileScheduler= 90 /*WarpSpecialized*/, Is_causal || Is_local /*LPT*/, PrepareVarlen, !Is_local /*Sort*/>, std::conditional_t, flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> @@ -157,7 +157,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.num_nheads_in_l2_ptr }; - if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { + if (PrepareVarlen && !params.skip_scheduler_metadata_computation) { prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -191,7 +191,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } // kernel<<>>(kernel_params); cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, - Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); + Arch >= 90 && PrepareVarlen && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); } CHECK_CUDA_KERNEL_LAUNCH(); } @@ -205,17 +205,19 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { - // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; - - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; - BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { - static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; - APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not varlen - CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { - static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + BOOL_SWITCH(params.use_prepare_varlen, PrepareVarlen_, [&] { + static constexpr bool PrepareVarlen = PrepareVarlen_ && Varlen; + // Only needed here to decide if we should use cluster + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 
1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; + run_flash_fwd(params, stream); + }); }); }); }); diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 31a493ab00f..9dfbf618bd2 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -466,7 +466,7 @@ class SingleTileBwdLPTScheduler { /////////////////////////////////////////////////////////////////////////////// -template +template class VarlenDynamicPersistentTileScheduler { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); @@ -539,9 +539,12 @@ class VarlenDynamicPersistentTileScheduler { CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { - int actual_bidb = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[bidb] : bidb; + auto get_actual_batch = [&](int virtual_batch) { + if constexpr(Prepared && Sort) { return params.varlen_batch_idx_ptr[virtual_batch]; } + else { return virtual_batch; } + }; if constexpr (!Split) { - return {block, bidh, actual_bidb, 0 /*split_idx*/}; + return {block, bidh, get_actual_batch(bidb), 0 /*split_idx*/}; } else { // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift @@ -555,7 +558,7 @@ class VarlenDynamicPersistentTileScheduler { // if (threadIdx.x == 128) { // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); // } - return {block, bidh_actual, actual_bidb, split_idx}; + return {block, bidh_actual, get_actual_batch(bidb), split_idx}; } } }; @@ -569,9 +572,9 @@ class VarlenDynamicPersistentTileScheduler { int lane = threadIdx.x % cutlass::NumThreadsPerWarp; auto get_num_m_blocks = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - if (params.num_m_blocks_ptr) { - int num_m_blocks = batch_idx < params.num_batch ? params.num_m_blocks_ptr[batch_idx] : 0; - return lane < cutlass::NumThreadsPerWarp - 1 ? num_m_blocks : 0; + if constexpr (Prepared && Sort) { + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? params.num_m_blocks_ptr[batch_idx] : 0; } else { int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); if (seqlen > kBlockM) { @@ -594,11 +597,14 @@ class VarlenDynamicPersistentTileScheduler { auto get_num_splits = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? (!Split ? 1 : (params.num_splits_dynamic_ptr - ? params.num_splits_dynamic_ptr[batch_idx] - : params.nsplits_divmod.divisor)) - : 0; + bool is_valid = batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1; + if constexpr (!Split) { + return is_valid ? 1 : 0; + } else if constexpr(Prepared) { + return is_valid ? params.num_splits_dynamic_ptr[batch_idx] : 0; + } else { + return is_valid ? params.nsplits_divmod.divisor : 0; + } }; int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane @@ -652,7 +658,7 @@ class VarlenDynamicPersistentTileScheduler { int mh_block = next_tile_idx - group_start_tile; int block, bidh; if constexpr (LPT) { - if ((!Split || num_splits == 1) && params.num_nheads_in_l2_ptr) { + if (Prepared && (!Split || num_splits == 1)) { // NOTE: code for computing nheads_in_l2 directly left as reference // int num_n_blocks = params.num_n_blocks_ptr ? 
params.num_n_blocks_ptr[bidb] : num_m_blocks; // auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; diff --git a/hopper/tile_size.h b/hopper/tile_size.h index e6cb31515c7..8353542c477 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -21,7 +21,7 @@ constexpr std::tuple tile_size_fwd_sm90( return {128, 96, true, false}; } else { // Switch to tile size 192 x 192 for now - bool const use_blockN_128 = is_causal || is_local; + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; } // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen @@ -29,8 +29,9 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 96) { return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; } else if (headdim <= 128) { - return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; - // {128, 192, false, false} and {192, 128, false, true} are quite good too + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; + return {128, use_blockN_128 ? 128 : 176, true, true}; + // {128, 192, true, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem From b04a82de33a487ee6d571d7458335ab8f5e8664d Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Tue, 19 Aug 2025 20:57:02 -0700 Subject: [PATCH 16/20] fix exploding build by splitting compilation and add qol macros for hdimdiff --- hopper/flash_api.cpp | 10 ++++++++ hopper/flash_fwd_launch_template.h | 13 ++++++---- hopper/setup.py | 26 +++++++++++++++++++- hopper/tile_scheduler.hpp | 38 +++++++++++++++++++++++++----- 4 files changed, 76 insertions(+), 11 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 3e853d6bfaa..09279a17f3a 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -252,6 +252,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if constexpr (Arch == 90) { if (params.dv > 256) { return run_mha_fwd_(params, stream); @@ -259,6 +260,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -270,11 +272,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -285,6 +289,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if constexpr (Arch == 90) { if (params.dv > 256) { return run_mha_fwd_(params, stream); @@ -292,6 +297,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -303,11 +309,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif 
#ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -331,11 +339,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } } + #endif return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 0ad25cf6cf4..db600c519fe 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -57,8 +57,11 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; + static constexpr bool LPT = Is_causal || Is_local; + // condition since we may want to sort without PrepareVarlen being true when arch is sm8x + static constexpr bool Sort = cute::conditional_return= 90>(!Is_local && PrepareVarlen, !Is_local); using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/, Is_causal || Is_local /*LPT*/, PrepareVarlen, !Is_local /*Sort*/>, + flash::VarlenDynamicPersistentTileScheduler= 90 /*WarpSpecialized*/, LPT, PrepareVarlen, Sort>, std::conditional_t, flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> @@ -157,7 +160,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.num_nheads_in_l2_ptr }; - if (PrepareVarlen && !params.skip_scheduler_metadata_computation) { + if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -191,7 +194,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } // kernel<<>>(kernel_params); cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, - Arch >= 90 && PrepareVarlen && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); + Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); } CHECK_CUDA_KERNEL_LAUNCH(); } @@ -206,7 +209,9 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { BOOL_SWITCH(params.use_prepare_varlen, PrepareVarlen_, [&] { - static constexpr bool PrepareVarlen = PrepareVarlen_ && Varlen; + // If arch is sm8x, don't compile for the PrepareVarlen option to save time. + // For sm8x we only use varlen dynamic persistent scheduler with split case anyway and memory-bound kernels aren't as impacted. + static constexpr bool PrepareVarlen = PrepareVarlen_ && Varlen && Arch >= 90; // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? 
std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; diff --git a/hopper/setup.py b/hopper/setup.py index c15c438f56c..850fb0b520c 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -64,6 +64,8 @@ ENABLE_VCOLMAJOR = os.getenv("FLASH_ATTENTION_ENABLE_VCOLMAJOR", "FALSE") == "TRUE" +DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" +DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" # HACK: we monkey patch pytorch's _write_ninja_file to pass # "-gencode arch=compute_sm90a,code=sm_90a" to files ending in '_sm90.cu', @@ -468,10 +470,13 @@ def nvcc_threads_args(): + (["-DFLASHATTENTION_DISABLE_HDIM256"] if DISABLE_HDIM256 else []) + (["-DFLASHATTENTION_DISABLE_SM8x"] if DISABLE_SM8x else []) + (["-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else []) ) DTYPE_FWD_SM80 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + (["e4m3"] if not DISABLE_FP8 else []) + HALF_DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) HEAD_DIMENSIONS_BWD = ( [] @@ -481,7 +486,18 @@ def nvcc_threads_args(): + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) - HEAD_DIMENSIONS_FWD = ["all", "diff"] + # build will now explode with this compilation grouping given all our templating + # HEAD_DIMENSIONS_FWD = ["all", "diff"] + HEAD_DIMENSIONS_FWD = HEAD_DIMENSIONS_BWD + HEAD_DIMENSIONS_DIFF64_FWD = ( + [] + + (["64_256"] if not DISABLE_HDIMDIFF64 else []) + + (["64_512"] if not DISABLE_HDIMDIFF64 else []) + ) + HEAD_DIMENSIONS_DIFF192_FWD = ( + [] + + (["192_128"] if not DISABLE_HDIMDIFF192 else []) + ) HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else []) PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else []) @@ -495,6 +511,14 @@ def nvcc_threads_args(): sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF64: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF64_FWD, HALF_DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF192: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF192_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] sources_bwd_sm80 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm80.cu" for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP)] sources_bwd_sm90 = 
[f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm90.cu" diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 9dfbf618bd2..0fd64e068ac 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -466,7 +466,8 @@ class SingleTileBwdLPTScheduler { /////////////////////////////////////////////////////////////////////////////// -template +template class VarlenDynamicPersistentTileScheduler { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); @@ -540,8 +541,14 @@ class VarlenDynamicPersistentTileScheduler { cute::tuple get_block_coord(Params const& params) const { auto get_actual_batch = [&](int virtual_batch) { - if constexpr(Prepared && Sort) { return params.varlen_batch_idx_ptr[virtual_batch]; } - else { return virtual_batch; } + if constexpr(Prepared && Sort) { + return params.varlen_batch_idx_ptr[virtual_batch]; + } else if constexpr (Sort) { + // use conditional for sm8x kernels + return params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[virtual_batch] : virtual_batch; + } else { + return virtual_batch; + } }; if constexpr (!Split) { return {block, bidh, get_actual_batch(bidb), 0 /*split_idx*/}; @@ -576,6 +583,13 @@ class VarlenDynamicPersistentTileScheduler { return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 ? params.num_m_blocks_ptr[batch_idx] : 0; } else { + if constexpr(!Prepared && Sort) { + // use conditional for sm8x kernels + if (params.num_m_blocks_ptr) { + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? params.num_m_blocks_ptr[batch_idx] : 0; + } + } int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); if (seqlen > kBlockM) { if (params.seqused) { @@ -603,7 +617,11 @@ class VarlenDynamicPersistentTileScheduler { } else if constexpr(Prepared) { return is_valid ? params.num_splits_dynamic_ptr[batch_idx] : 0; } else { - return is_valid ? params.nsplits_divmod.divisor : 0; + return is_valid + ? (params.num_splits_dynamic_ptr // use conditional for sm8x kernels + ? params.num_splits_dynamic_ptr[batch_idx] + : params.nsplits_divmod.divisor) + : 0; } }; @@ -658,7 +676,7 @@ class VarlenDynamicPersistentTileScheduler { int mh_block = next_tile_idx - group_start_tile; int block, bidh; if constexpr (LPT) { - if (Prepared && (!Split || num_splits == 1)) { + if (!Split || num_splits == 1) { // NOTE: code for computing nheads_in_l2 directly left as reference // int num_n_blocks = params.num_n_blocks_ptr ? params.num_n_blocks_ptr[bidb] : num_m_blocks; // auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; @@ -666,7 +684,15 @@ class VarlenDynamicPersistentTileScheduler { // ? 1 : 1 << find_log2_floor(params.max_kvblocks_in_l2 / num_n_blocks); // if constexpr (!PackGQA) { nheads_in_l2 *= params.qhead_per_khead; } // nheads_in_l2 = min(nheads_in_l2, params.num_head); - int nheads_in_l2 = params.num_nheads_in_l2_ptr[bidb]; + auto get_nheads_in_l2 = [&](int batch_idx) { + if constexpr(Prepared) { + return params.num_nheads_in_l2_ptr[batch_idx]; + } else { + // use conditional for sm8x kernels + return params.num_nheads_in_l2_ptr ? params.num_nheads_in_l2_ptr[batch_idx] : (!PackGQA ? 
params.qhead_per_khead : 1); + } + }; + int nheads_in_l2 = get_nheads_in_l2(bidb); int mh_in_l2 = nheads_in_l2 * num_m_blocks; int section_idx = mh_block / mh_in_l2; int l2_mod = mh_block - section_idx * mh_in_l2; From bb1fc099650e32f24d22d1e3e4b04d6da61a2bf0 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Tue, 19 Aug 2025 23:52:28 -0700 Subject: [PATCH 17/20] fix metadata mismatch with seqlenk in test script --- hopper/test_flash_attn.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 8ea2dee2897..0b5a0e2af98 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -882,7 +882,10 @@ def test_flash_attn_kvcache( print(f"{num_splits = }, {precompute_metadata = }") if precompute_metadata: scheduler_metadata = get_scheduler_metadata( - batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + batch_size, + max_seqlen_q if varlen_q else seqlen_q, + seqlen_k if page_size is None else page_table.shape[1] * page_size, + nheads, nheads_k, d, cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, max_seqlen_k_new=seqlen_new, page_size=page_size, From 0a57ef73389e5f6deb1ab6451c83312e9756ae6d Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Wed, 20 Aug 2025 15:25:13 -0700 Subject: [PATCH 18/20] extend prepare kernel to >992 batches and always call it for varlen --- hopper/flash.h | 2 +- hopper/flash_api.cpp | 37 ++++---- hopper/flash_fwd_launch_template.h | 38 ++++---- hopper/flash_prepare_scheduler.cu | 145 ++++++++++++++++------------- hopper/tile_scheduler.hpp | 25 +---- 5 files changed, 122 insertions(+), 125 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index d68a317ffd9..6848e8c9dbd 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -161,7 +161,7 @@ struct Flash_fwd_params : public Qkv_params { bool varlen_sort_batches; int tile_count_semaphore_offset; bool head_swizzle; - bool use_prepare_varlen; + bool prepare_varlen_pdl; int arch; int num_sm; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 09279a17f3a..8fc930c5da6 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -39,7 +39,7 @@ PyObject* PyInit__C(void) #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define PREPARE_VARLEN_MAX_BATCHES 992 +#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 void set_params_fprop(Flash_fwd_params ¶ms, // sizes @@ -596,8 +596,8 @@ mha_fwd_get_scheduler_metadata( params.page_size = page_size.has_value() ? page_size.value() : 1; params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); - bool const use_prepare_varlen = params.b <= PREPARE_VARLEN_MAX_BATCHES; - params.use_prepare_varlen = use_prepare_varlen; + bool const use_prepare_varlen = true; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; params.num_splits_dynamic_ptr = !use_prepare_varlen ? 
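        // non-null placeholder: the real pointer into the scheduler metadata tensor is assigned further below;
        // it is set here only so that get_num_splits sees that dynamic per-batch split counts will be available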
nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); @@ -616,14 +616,14 @@ mha_fwd_get_scheduler_metadata( at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - params.varlen_sort_batches = !params.is_local; // Use for Sort value in scheduler template - params.head_swizzle = params.is_causal || params.is_local; // Use for LPT value in scheduler template + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers - int num_prepare_batch_vectors = use_prepare_varlen ? 1 : 0; - if(params.varlen_sort_batches) { num_prepare_batch_vectors += 2; } + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } if(params.head_swizzle) { num_prepare_batch_vectors += 1; } - int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 1); + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); tile_count_semaphore = torch::empty( @@ -631,7 +631,7 @@ mha_fwd_get_scheduler_metadata( opts.dtype(torch::kInt32)); // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; - params.num_m_blocks_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; @@ -967,10 +967,9 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); } } - - // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel - bool const use_prepare_varlen = is_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES; - params.use_prepare_varlen = use_prepare_varlen; + + bool const use_prepare_varlen = is_varlen; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); @@ -985,14 +984,14 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql bool const scheduler_needs_semaphore = params.arch >= 90 ? 
(((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - params.varlen_sort_batches = !params.is_local; // Use for Sort value in scheduler template - params.head_swizzle = params.is_causal || params.is_local; // Use for LPT value in scheduler template + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers - int num_prepare_batch_vectors = use_prepare_varlen ? 1 : 0; - if(params.varlen_sort_batches) { num_prepare_batch_vectors += 2; } + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } if(params.head_swizzle) { num_prepare_batch_vectors += 1; } - int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 1); + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); @@ -1012,7 +1011,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; - params.num_m_blocks_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index db600c519fe..d48a4fd9562 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -26,7 +26,7 @@ using namespace cute; template + bool PackGQA, bool Split, bool V_colmajor> void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); @@ -58,10 +58,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int NumProducerThreads = Arch >= 90 ? 
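    // SM90 warp-specialized kernels have a dedicated producer warp group; on older arches the MMA threads
    // double as producers (hence the scheduler's NumProducerThreads == NumMmaThreads assertion)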
CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; static constexpr bool LPT = Is_causal || Is_local; - // condition since we may want to sort without PrepareVarlen being true when arch is sm8x - static constexpr bool Sort = cute::conditional_return= 90>(!Is_local && PrepareVarlen, !Is_local); + static constexpr bool Sort = !Is_local; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/, LPT, PrepareVarlen, Sort>, + flash::VarlenDynamicPersistentTileScheduler= 90 /*WarpSpecialized*/, LPT, Sort, true /*Prepared*/>, std::conditional_t, flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> @@ -160,8 +159,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.num_nheads_in_l2_ptr }; - if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { - prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); + if (Varlen && !params.skip_scheduler_metadata_computation) { + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 && params.prepare_varlen_pdl /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -194,7 +193,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } // kernel<<>>(kernel_params); cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, - Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); + Arch >= 90 && Varlen && !params.skip_scheduler_metadata_computation && params.prepare_varlen_pdl /*launch_with_pdl*/); } CHECK_CUDA_KERNEL_LAUNCH(); } @@ -208,21 +207,16 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { - BOOL_SWITCH(params.use_prepare_varlen, PrepareVarlen_, [&] { - // If arch is sm8x, don't compile for the PrepareVarlen option to save time. - // For sm8x we only use varlen dynamic persistent scheduler with split case anyway and memory-bound kernels aren't as impacted. - static constexpr bool PrepareVarlen = PrepareVarlen_ && Varlen && Arch >= 90; - // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; - BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { - static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; - APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not varlen - CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { - static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); - }); + // Only needed here to decide if we should use cluster + static constexpr int kBlockM = Arch >= 90 ? 
std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; + run_flash_fwd(params, stream); }); }); }); diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 37e56b7a334..a5caa607cf5 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -64,13 +64,11 @@ __global__ void prepare_varlen_num_blocks_kernel( static_assert(BLOCK_DIM_X * ITEMS_PER_THREAD == NumWarps * 32); using BlockMergeSort = cub::BlockMergeSort; - // Assume that there's only one block in the grid __shared__ int total_blocks_smem[kSmemSize]; // Allocate shared memory for BlockMergeSort operations __shared__ typename BlockMergeSort::TempStorage temp_storage; - // There's only 1 block in the grid, so might as well start launching the main attn kernel if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } @@ -123,25 +121,11 @@ __global__ void prepare_varlen_num_blocks_kernel( }; int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; - int bidb_start = kNumBatchPerWarp * warp_idx; + int batch_offset = int(blockIdx.x) * 992; + int bidb_start = batch_offset + kNumBatchPerWarp * warp_idx; int batch_idx = lane + bidb_start; int num_m_blocks = get_num_m_blocks(batch_idx); int num_n_blocks = get_num_n_blocks(batch_idx); - int total_blocks = num_m_blocks * num_n_blocks; - // Warp sum - #pragma unroll - for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { - total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); - } - if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } - __syncthreads(); - total_blocks = total_blocks_smem[0]; - // 10% margin - int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); - // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM - int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); - // num_n_blocks per work tile for the batch - num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); auto get_nheads_in_l2 = [&](int n_blocks) { int nheads_in_l2 = n_blocks * 16 <= max_kvblocks_in_l2 ? 
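        // how many heads' worth of KV blocks fit in L2 (capped at 16), per the reference computation kept in the
        // scheduler; drives the head-swizzle order so consecutive work tiles reuse KV already resident in L2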
16 @@ -153,53 +137,89 @@ __global__ void prepare_varlen_num_blocks_kernel( return min(nheads_in_l2, num_head); }; - if constexpr(Sort) { - if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { - num_n_blocks = INT_MIN; // sort last - } else if (is_causal) { - // sort by shortest member to process - num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; + if (blockIdx.x > 0) { + // trivially handle excess batches + if (batch_idx < num_batch && lane < kNumBatchPerWarp) { + num_splits_dynamic_ptr[batch_idx] = 1; + num_m_blocks_ptr[batch_idx] = num_m_blocks; + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } + if constexpr (Sort) { + varlen_batch_idx_ptr[batch_idx] = batch_idx; + } + } + } else { + int num_splits_dynamic; + if (int(gridDim.x) > 1 || num_splits_static == 1) { + // set num splits for all batches to 1 (note that user expects num_splits_static to mean upper bound on splits) + // for batch size > 992, we expect GPU occupancy to not be an issue except in degenerate cases (e.g., most are zero-length) + num_splits_dynamic = 1; + } else { + int total_blocks = num_m_blocks * num_n_blocks; + // Warp sum + #pragma unroll + for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { + total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + } + if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } + __syncthreads(); + total_blocks = total_blocks_smem[0]; + // 10% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); + // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM + num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + // num_n_blocks per work tile for the batch + num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); } - int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread - batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); - // if (threadIdx.x == 0) { - // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", - // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); - // } __syncthreads(); + if constexpr (Sort) { + if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { + num_n_blocks = INT_MIN; // sort last + } else if (is_causal) { + // sort by shortest member to process + num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; + } + int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread + batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); - // Sort batches by num_n_blocks in descending order - BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); + // if (threadIdx.x == 0) { + // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); - // if (threadIdx.x == 0) { - // printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", - // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); - // } __syncthreads(); + // Sort batches by num_n_blocks in descending order + BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); - if (is_causal) { - // reset value to 
num_n_blocks - batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); - } + // if (threadIdx.x == 0) { + // printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); - // When sorting, we re-index some metadata by 'virtual batch index' - // and also store the vbidx -> bidx mapping. - // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx] - // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] - // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] - // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx - if (threadIdx.x < num_batch) { - // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); - if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[threadIdx.x] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } - num_m_blocks_ptr[threadIdx.x] = batch_coords[0].y; - num_splits_dynamic_ptr[threadIdx.x] = batch_coords[0].z; - varlen_batch_idx_ptr[threadIdx.x] = batch_coords[0].w; - } - } else { - if (batch_idx < num_batch && lane < kNumBatchPerWarp) { - // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); - if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } - num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; - // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + if (is_causal) { + // reset value to num_n_blocks + batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); + } + + // When sorting, we re-index some metadata by 'virtual batch index' + // and also store the vbidx -> bidx mapping. + // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx] + // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] + // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] + // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx + if (threadIdx.x < min(num_batch, 992)) { + // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[threadIdx.x] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } + num_m_blocks_ptr[threadIdx.x] = batch_coords[0].y; + num_splits_dynamic_ptr[threadIdx.x] = batch_coords[0].z; + varlen_batch_idx_ptr[threadIdx.x] = batch_coords[0].w; + } + } else { + if (batch_idx < num_batch && lane < kNumBatchPerWarp) { + // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } + num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; + num_m_blocks_ptr[batch_idx] = num_m_blocks; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + } } } } @@ -208,10 +228,9 @@ __global__ void prepare_varlen_num_blocks_kernel( void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl) { - // Only support batch <= 992 (32 warps, each with 31 batches) - // int qhead_per_khead = !packgqa ? 
1 : cutlass::ceil_div(params.h, params.h_k); int qhead_per_khead = cutlass::ceil_div(params.h, params.h_k); - int num_warps = cutlass::ceil_div(params.b, 31); + int num_warps = cutlass::ceil_div(params.b, 31); // warp switch will cap this at 32 + int num_ctas = cutlass::ceil_div(params.b, 31 * 32); // int const size_l2 = 50 * 1024 * 1024; // 50 MB int const size_l2 = 8 * 1024 * 1024; // underestimate seems better in practice int const element_size = params.is_e4m3 ? 1 : 2; @@ -220,7 +239,7 @@ void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bo int const max_kvblocks_in_l2 = size_l2 / size_one_kvblock; BOOL_SWITCH(params.varlen_sort_batches, Sort, [&] { NUM_WARP_SWITCH(num_warps, NumWarps, [&] { - flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 32 * NumWarps /*block*/, 0, stream>>>( + flash::prepare_varlen_num_blocks_kernel<<>>( params.seqlen_q, params.seqlen_k, params.seqlen_knew, params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, params.seqused_q, params.seqused_k, params.leftpad_k, diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 0fd64e068ac..41e0bab1624 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -467,7 +467,7 @@ class SingleTileBwdLPTScheduler { /////////////////////////////////////////////////////////////////////////////// template + bool Split=false, bool PackGQA=false, bool WarpSpecialized=true, bool LPT = false, bool Sort = false, bool Prepared = true> class VarlenDynamicPersistentTileScheduler { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); @@ -543,10 +543,7 @@ class VarlenDynamicPersistentTileScheduler { auto get_actual_batch = [&](int virtual_batch) { if constexpr(Prepared && Sort) { return params.varlen_batch_idx_ptr[virtual_batch]; - } else if constexpr (Sort) { - // use conditional for sm8x kernels - return params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[virtual_batch] : virtual_batch; - } else { + } else { return virtual_batch; } }; @@ -579,17 +576,10 @@ class VarlenDynamicPersistentTileScheduler { int lane = threadIdx.x % cutlass::NumThreadsPerWarp; auto get_num_m_blocks = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - if constexpr (Prepared && Sort) { + if constexpr (Prepared) { return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 ? params.num_m_blocks_ptr[batch_idx] : 0; } else { - if constexpr(!Prepared && Sort) { - // use conditional for sm8x kernels - if (params.num_m_blocks_ptr) { - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? params.num_m_blocks_ptr[batch_idx] : 0; - } - } int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); if (seqlen > kBlockM) { if (params.seqused) { @@ -617,11 +607,7 @@ class VarlenDynamicPersistentTileScheduler { } else if constexpr(Prepared) { return is_valid ? params.num_splits_dynamic_ptr[batch_idx] : 0; } else { - return is_valid - ? (params.num_splits_dynamic_ptr // use conditional for sm8x kernels - ? params.num_splits_dynamic_ptr[batch_idx] - : params.nsplits_divmod.divisor) - : 0; + return is_valid ? params.nsplits_divmod.divisor : 0; } }; @@ -688,8 +674,7 @@ class VarlenDynamicPersistentTileScheduler { if constexpr(Prepared) { return params.num_nheads_in_l2_ptr[batch_idx]; } else { - // use conditional for sm8x kernels - return params.num_nheads_in_l2_ptr ? params.num_nheads_in_l2_ptr[batch_idx] : (!PackGQA ? params.qhead_per_khead : 1); + return !PackGQA ? 
params.qhead_per_khead : 1; } }; int nheads_in_l2 = get_nheads_in_l2(bidb); From 5df6e075369f3e7023a122be17314fb075c21219 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Wed, 20 Aug 2025 16:11:18 -0700 Subject: [PATCH 19/20] do inter-batch sort per every 992 batches --- hopper/flash_prepare_scheduler.cu | 146 ++++++++++++++---------------- 1 file changed, 68 insertions(+), 78 deletions(-) diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index a5caa607cf5..1d810c015ed 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -121,8 +121,8 @@ __global__ void prepare_varlen_num_blocks_kernel( }; int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; - int batch_offset = int(blockIdx.x) * 992; - int bidb_start = batch_offset + kNumBatchPerWarp * warp_idx; + int batch_cta_idx_offset = int(blockIdx.x) * 992; + int bidb_start = batch_cta_idx_offset + kNumBatchPerWarp * warp_idx; int batch_idx = lane + bidb_start; int num_m_blocks = get_num_m_blocks(batch_idx); int num_n_blocks = get_num_n_blocks(batch_idx); @@ -136,92 +136,82 @@ __global__ void prepare_varlen_num_blocks_kernel( if(!packgqa) { nheads_in_l2 *= qhead_per_khead; } return min(nheads_in_l2, num_head); }; - - if (blockIdx.x > 0) { - // trivially handle excess batches - if (batch_idx < num_batch && lane < kNumBatchPerWarp) { - num_splits_dynamic_ptr[batch_idx] = 1; - num_m_blocks_ptr[batch_idx] = num_m_blocks; - if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } - if constexpr (Sort) { - varlen_batch_idx_ptr[batch_idx] = batch_idx; - } - } + + int num_splits_dynamic; + if (int(gridDim.x) > 1 || num_splits_static == 1) { + // set num splits for all batches to 1 (note that user expects num_splits_static to mean upper bound on splits) + // for batch size > 992, we expect GPU occupancy to not be an issue except in degenerate cases (e.g., most are zero-length) + num_splits_dynamic = 1; } else { - int num_splits_dynamic; - if (int(gridDim.x) > 1 || num_splits_static == 1) { - // set num splits for all batches to 1 (note that user expects num_splits_static to mean upper bound on splits) - // for batch size > 992, we expect GPU occupancy to not be an issue except in degenerate cases (e.g., most are zero-length) - num_splits_dynamic = 1; - } else { - int total_blocks = num_m_blocks * num_n_blocks; - // Warp sum - #pragma unroll - for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { - total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); - } - if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } - __syncthreads(); - total_blocks = total_blocks_smem[0]; - // 10% margin - int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); - // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM - num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); - // num_n_blocks per work tile for the batch - num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); + int total_blocks = num_m_blocks * num_n_blocks; + // Warp sum + #pragma unroll + for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { + total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); } + if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } + __syncthreads(); + total_blocks = total_blocks_smem[0]; + // 10% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * 
float(num_head) / float(num_sm))); + // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM + num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + // num_n_blocks per work tile for the batch + num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); + } - if constexpr (Sort) { - if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { - num_n_blocks = INT_MIN; // sort last - } else if (is_causal) { - // sort by shortest member to process - num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; - } - int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread - batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); + if constexpr (Sort) { + if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { + num_n_blocks = INT_MIN; // sort last + } else if (is_causal) { + // sort by shortest member to process + num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; + } + int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread + batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); - // if (threadIdx.x == 0) { - // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", - // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); - // } __syncthreads(); + // if (threadIdx.x == 0) { + // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); - // Sort batches by num_n_blocks in descending order - BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); + // Sort batches by num_n_blocks in descending order + BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); - // if (threadIdx.x == 0) { - // printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", - // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); - // } __syncthreads(); + // if (threadIdx.x == 0) { + // printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); - if (is_causal) { - // reset value to num_n_blocks - batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); - } + if (is_causal) { + // reset value to num_n_blocks + batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); + } - // When sorting, we re-index some metadata by 'virtual batch index' - // and also store the vbidx -> bidx mapping. - // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx] - // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] - // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] - // 4. 
varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx - if (threadIdx.x < min(num_batch, 992)) { - // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); - if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[threadIdx.x] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } - num_m_blocks_ptr[threadIdx.x] = batch_coords[0].y; - num_splits_dynamic_ptr[threadIdx.x] = batch_coords[0].z; - varlen_batch_idx_ptr[threadIdx.x] = batch_coords[0].w; - } - } else { - if (batch_idx < num_batch && lane < kNumBatchPerWarp) { - // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); - if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } - num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; - num_m_blocks_ptr[batch_idx] = num_m_blocks; - // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); - } + // When sorting, we re-index some metadata by 'virtual batch index' + // and also store the vbidx -> bidx mapping. + // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx] + // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] + // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] + // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx + batch_idx = batch_cta_idx_offset + threadIdx.x; + if (batch_idx < num_batch && threadIdx.x < 992) { + // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } + num_m_blocks_ptr[batch_idx] = batch_coords[0].y; + num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z; + varlen_batch_idx_ptr[batch_idx] = batch_coords[0].w; + } + } else { + if (batch_idx < num_batch && lane < kNumBatchPerWarp) { + // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } + num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; + num_m_blocks_ptr[batch_idx] = num_m_blocks; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); } } + } } // flash From 3024f683c465336baa231ef467972bfe45d5b672 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Thu, 21 Aug 2025 19:39:03 -0700 Subject: [PATCH 20/20] better names in combine and fix prepare condition in api --- hopper/flash_api.cpp | 2 +- hopper/flash_fwd_combine_kernel.h | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 8fc930c5da6..8ffd0d0baf9 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -643,7 +643,7 @@ mha_fwd_get_scheduler_metadata( } } - if (params.num_splits_dynamic_ptr) { + if (use_prepare_varlen) { auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index 81370c731c7..05667698006 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -207,9 +207,9 @@ class FlashAttnFwdCombine { int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; int const k_block = blockIdx.y; - int const virtual_batch_idx = blockIdx.z; - int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[virtual_batch_idx] : virtual_batch_idx; - int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[virtual_batch_idx] : get<1>(params.shape_LSE_partial); + int const maybe_virtual_batch = blockIdx.z; + int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[maybe_virtual_batch] : maybe_virtual_batch; + int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[maybe_virtual_batch] : get<1>(params.shape_LSE_partial); if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { cutlass::arch::wait_on_dependent_grids();
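        // under programmatic dependent launch, the last CTA synchronizes on the grid dependency here
        // before resetting the tile scheduler's semaphore (params.semaphore_to_reset) for the next call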