diff --git a/csrc/fmha_v2/convert.cu b/csrc/fmha_v2/convert.cu deleted file mode 100644 index 345bd008f9..0000000000 --- a/csrc/fmha_v2/convert.cu +++ /dev/null @@ -1,196 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement - * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. - */ - -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -__global__ void convert_int32_to_int8_kernel(void* dst, void const* src, size_t n, float scale) { - // The step. - size_t step = (size_t)gridDim.x * blockDim.x; - - // Iterate over the elements. - for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { - // Load 4 integers. - int4 tmp = reinterpret_cast(src)[ii]; - - // Convert to float and scale. - float x = static_cast(tmp.x) * scale; - float y = static_cast(tmp.y) * scale; - float z = static_cast(tmp.z) * scale; - float w = static_cast(tmp.w) * scale; - - // Convert to int8. - uint32_t a; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(a) : "f"(x)); - uint32_t b; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(b) : "f"(y)); - uint32_t c; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(c) : "f"(z)); - uint32_t d; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(d) : "f"(w)); - - // Compact. - char4 out; - out.x = reinterpret_cast(a); - out.y = reinterpret_cast(b); - out.z = reinterpret_cast(c); - out.w = reinterpret_cast(d); - - // Store. 
- reinterpret_cast(dst)[ii] = reinterpret_cast(out); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, - float scale) { - size_t n = (size_t)s * b * h * d; - convert_int32_to_int8_kernel<<<512, 256>>>(dst, src, n, scale); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ inline typename fmha::Uint_from_size_in_bytes::Type pack_float4( - float4 const& f); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -__device__ inline uint2 pack_float4(float4 const& f) { - return fmha::float4_to_half4(f.x, f.y, f.z, f.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -__device__ inline uint2 pack_float4(float4 const& f) { - return fmha::float4_to_16bit_x4(f.x, f.y, f.z, f.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -__device__ inline uint32_t pack_float4(float4 const& f) { - return fmha::float4_to_e4m3x4(f.x, f.y, f.z, f.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -template <> -__device__ inline uint32_t pack_float4(float4 const& f) { - return fmha::float4_to_e5m2x4(f.x, f.y, f.z, f.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) { - using Dst = typename fmha::Uint_from_size_in_bytes::Type; - - // The step. - size_t step = (size_t)gridDim.x * blockDim.x; - - // Iterate over the elements. 
- for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { - // Load 4 floats. - float4 tmp = reinterpret_cast(src)[ii]; - // Scale. - tmp.x *= scale; - tmp.y *= scale; - tmp.z *= scale; - tmp.w *= scale; - // Convert to 4 Ts. - auto out = pack_float4(tmp); - - // Store. - reinterpret_cast(dst)[ii] = reinterpret_cast(out); - } -} - -template -__global__ void convert_T_to_fp32_kernel(void* dst, void const* src, size_t n, float scale = 1.f) { - using Src = typename fmha::Uint_from_size_in_bytes::Type; - - union { - Src raw; - T elt[4]; - } data; - - // The step. - size_t step = (size_t)gridDim.x * blockDim.x; - - // Iterate over the elements. - for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { - // Load 4 floats. - data.raw = reinterpret_cast(src)[ii]; - float4 out; - // Scale. - out.x = float(data.elt[0]) * scale; - out.y = float(data.elt[1]) * scale; - out.z = float(data.elt[2]) * scale; - out.w = float(data.elt[3]) * scale; - - // Store. - reinterpret_cast(dst)[ii] = reinterpret_cast(out); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_fp32_to_fp16(void* dst, void const* src, int s, int b, int h, int d) { - // No need to expose the scale factor for FP16/FP32. - size_t n = (size_t)s * b * h * d; - convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, 1.f); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_fp32_to_bf16(void* dst, void const* src, int s, int b, int h, int d) { - // No need to expose the scale factor for FP16/FP32. 
- size_t n = (size_t)s * b * h * d; - convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, 1.f); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_fp32_to_e4m3(void* dst, void const* src, size_t n, float scale_o) { - convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, scale_o); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_e4m3_to_fp32(void* dst, void const* src, size_t n, float scale_o) { - convert_T_to_fp32_kernel<<<512, 256>>>(dst, src, n, scale_o); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d, - float scale_o) { - run_conversion_fp32_to_e4m3(dst, src, s * b * h * d, scale_o); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_fp32_to_e5m2(void* dst, void const* src, size_t n, float scale_o) { - convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, scale_o); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_e5m2_to_fp32(void* dst, void const* src, size_t n, float scale_o) { - convert_T_to_fp32_kernel<<<512, 256>>>(dst, src, n, scale_o); -} diff --git a/csrc/fmha_v2/fmha/gmem_tile_o_packed.h b/csrc/fmha_v2/fmha/gmem_tile_o_packed.h index dc13b37f19..be5fbd8e45 100644 --- a/csrc/fmha_v2/fmha/gmem_tile_o_packed.h +++ b/csrc/fmha_v2/fmha/gmem_tile_o_packed.h @@ -114,7 +114,7 @@ struct Hmma_gmem_tile_o { // // row_offset += binfo.bidx * VALID_BYTES_PER_ROW; // - row_offset += binfo.bidx * valid_bytes_per_row; + row_offset += (int64_t)binfo.bidx * valid_bytes_per_row; // Assemble the final pointer. 
o_ptr_ += row_offset + col_in_bytes_; @@ -753,7 +753,7 @@ struct Gmem_tile_o_8bit { // The amount of bytes per row without padding (runtime). int const valid_bytes_per_row = params.dv * BYTES_PER_ELEMENT; // Take the batch/head offset into account. - row_offset += block_info.bidx * valid_bytes_per_row; + row_offset += (int64_t)block_info.bidx * valid_bytes_per_row; // Assemble the final pointer. o_ptr_ += row_offset + col_in_bytes_; @@ -1088,7 +1088,7 @@ struct Gmem_tile_o_16bit { // The amount of bytes per row without padding (runtime). int const valid_bytes_per_row = params.dv * BYTES_PER_ELEMENT; // Take the batch/head offset into account. - row_offset += block_info.bidx * valid_bytes_per_row; + row_offset += (int64_t)block_info.bidx * valid_bytes_per_row; // Assemble the final pointer. o_ptr_ += row_offset + col_in_bytes_; diff --git a/csrc/fmha_v2/fmha/gmem_tile_ps.h b/csrc/fmha_v2/fmha/gmem_tile_ps.h index de150ff293..a323a04447 100644 --- a/csrc/fmha_v2/fmha/gmem_tile_ps.h +++ b/csrc/fmha_v2/fmha/gmem_tile_ps.h @@ -558,7 +558,7 @@ struct Gmem_tile_ps { int col = warp / Cta_tile::WARPS_M * Mma_tile::N_PER_MMA + lane % 4 * ELEMENTS_PER_STG; // The offset of the 1st row written by the thread. We store the P matrix interleaved. - int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + bidx * BYTES_PER_ROW; + int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + (int64_t)bidx * BYTES_PER_ROW; // Finalize the pointer. ptr_ += row_offset + col * BYTES_PER_ELEMENT; } @@ -654,7 +654,7 @@ struct Gmem_tile_ps { // The offset of the 1st row written by the thread. We store the P matrix interleaved. int64_t row_offset = - (int64_t)row * params_stride_in_bytes_ + bidx * BYTES_PER_ROW + cta_row_offset; + (int64_t)row * params_stride_in_bytes_ + (int64_t)bidx * BYTES_PER_ROW + cta_row_offset; // Finalize the pointer. 
ptr_ += row_offset + col * BYTES_PER_ELEMENT; @@ -760,7 +760,7 @@ struct Gmem_tile_ps_hopper { int col = warpgroup_idx * Mma_tile::N_PER_MMA + lane % 4 * ELEMENTS_PER_STG; // The offset of the 1st row written by the thread. We store the P matrix interleaved. - int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + bidx * bytes_per_row; + int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + (int64_t)bidx * bytes_per_row; // Finalize the pointer. ptr_ += row_offset + col * BYTES_PER_ELEMENT; } diff --git a/csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h b/csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h index 7c9ac43bb8..14db4b8f50 100644 --- a/csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h +++ b/csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h @@ -96,7 +96,7 @@ struct Gmem_tile_o_hopper_16bits { // The offset of the 1st row written by the thread. We store the P matrix interleaved. int64_t row_offset = - (int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW; + (int64_t)row_ * params_o_stride_in_bytes_ + (int64_t)block_info.bidx * BYTES_PER_ROW; // Finalize the pointer. o_ptr_ += row_offset + col * BYTES_PER_ELEMENT; } @@ -599,7 +599,7 @@ struct Gmem_tile_o_gmma_32bit_8bit { // The offset of the 1st row written by the thread. We store the P matrix interleaved. int64_t row_offset = - (int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW; + (int64_t)row_ * params_o_stride_in_bytes_ + (int64_t)block_info.bidx * BYTES_PER_ROW; // Finalize the pointer. o_ptr_ += row_offset + col_ * BYTES_PER_ELEMENT; } @@ -1065,7 +1065,7 @@ struct Gmem_tile_o_qgmma_fp32_16bits { // The offset of the 1st row written by the thread. We store the P matrix interleaved. int64_t row_offset = - (int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW; + (int64_t)row_ * params_o_stride_in_bytes_ + (int64_t)block_info.bidx * BYTES_PER_ROW; // Finalize the pointer. 
o_ptr_ += row_offset + col * BYTES_PER_ELEMENT; } diff --git a/csrc/fmha_v2/fmha/kernel_traits.h b/csrc/fmha_v2/fmha/kernel_traits.h index 8e1d5cbb22..f91240aca2 100644 --- a/csrc/fmha_v2/fmha/kernel_traits.h +++ b/csrc/fmha_v2/fmha/kernel_traits.h @@ -195,13 +195,14 @@ struct Kernel_traits_ { // Compute the total BMM2_MMAS_K (might not the same as Mma_tile_o::MMAS_K if the granular tiling // is used). - static_assert(S % CTA_O_TILE_K == 0, ""); + // S=0 for flash attention (variable sequence length): tile counts are determined at runtime. + static_assert(S == 0 || S % CTA_O_TILE_K == 0, ""); - enum { TOTAL_BMM2_MMAS_K = Mma_tile_o::MMAS_K * (S / CTA_O_TILE_K) }; + enum { TOTAL_BMM2_MMAS_K = S == 0 ? 0 : Mma_tile_o::MMAS_K * (S / CTA_O_TILE_K) }; // Constraints on the K dimension. static_assert(Mma_tile_p::K_PER_MMA <= static_cast(D)); - static_assert(Mma_tile_o::K_PER_MMA <= S); + static_assert(S == 0 || Mma_tile_o::K_PER_MMA <= S); // The version. enum { VERSION = VERSION_ }; diff --git a/csrc/fmha_v2/fmha/warpspec/compute.h b/csrc/fmha_v2/fmha/warpspec/compute.h index 9aae70b2e7..941fdcf9d0 100644 --- a/csrc/fmha_v2/fmha/warpspec/compute.h +++ b/csrc/fmha_v2/fmha/warpspec/compute.h @@ -179,7 +179,7 @@ struct Compute { USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \ : (q_step_idx * STEP_Q + head_info.q_tile_offset), \ kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \ - kv_step_idx == kv_idx_end - 1); + &shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id], kv_step_idx == kv_idx_end - 1); //////////////////////////////////////////////////////////////////////////////////////////////// @@ -277,6 +277,12 @@ struct Compute { int const actual_kv_seqlen = SEPARATE_Q_KV_BUFFER ? 
head_info.actual_kv_seqlen : actual_q_seqlen; + // Update threshold of Skip-Softmax + if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) { + softmax.skip_softmax_threshold = + params.skip_softmax_threshold_scale_factor / actual_kv_seqlen; + } + // Calculate the alibi head_scaling_factor. float alibi_head_scale = APPLY_ALIBI ? get_alibi_head_scaling_factor( head_info.bidh, params.alibi_params) @@ -411,6 +417,12 @@ struct Compute { } } } +#ifdef SKIP_SOFTMAX_STAT + if (tidx == 0) { + atomicAdd(params.skip_softmax_total_blocks, softmax.total_blocks); + atomicAdd(params.skip_softmax_skipped_blocks, softmax.skipped_blocks); + } +#endif } //////////////////////////////////////////////////////////////////////////////////////////////// @@ -421,7 +433,14 @@ struct Compute { float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M], int const tidx, int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset, int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr, - Circular_buffer_kv_reader& cbr_v, OrderedMutexAccessor& mutex, bool complete = false) { + Circular_buffer_kv_reader& cbr_v, OrderedMutexAccessor& mutex, uint32_t* skip_softmax_vote, + bool complete = false) { + // Skip-softmax vote initialization + if (tidx == 0) { + // Note that we need a named_barrier_wait in compute_single_tile to make sure init is before + // voting. + *skip_softmax_vote = 1; + } // load the scales of K/V from global memory #define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \ if constexpr (block_size > 0) { \ @@ -453,6 +472,10 @@ struct Compute { // Ctile_p is only used once by each n step. ctile_p.clear(); + // If skip_softmax is enabled, make sure there is no racing between the initialization and + // writing of skip_softmax_vote. + named_barrier_wait(Kernel_traits::SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128, 128); + // BMM1 (Q x K'). 
warpgroup_arrive(); @@ -513,8 +536,27 @@ struct Compute { softmax.apply_alibi_and_mask(ctile_p, params.alibi_params, alibi_head_scale, actual_kv_seqlen, row_offset, col_offset); - // Softmax Exp, max/sum, and update scales. - softmax.compute_and_update_scale(p_max, p_sum); + // Softmax Exp, max/sum, and update scales. If returns false we skip the rest. + if (!softmax.compute_and_update_scale(p_max, p_sum, skip_softmax_vote)) { + if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1) { + // Notify another warpgroup to execute QGMMA. + mutex.named_bar_arrive(); + } + // Need to wait V, otherwise compute-sanitizer synccheck will fail. + int ready2 = cbr_v.peek(); + if (!ready2) { + cbr_v.wait(); + } + +#pragma unroll + // Advance V descriptor by the same amount as the BMM2 loop would, + // so that the descriptor stays in sync for subsequent KV steps. + for (int kbi = 0; kbi < BMM2_MMAS_K_GROUPS - 1; kbi++) { + ctile_o.increment_gmma_desc_group(); + } + + return; + } // experiments show that here is the best place to load scales of V float scales_v[SAGE_BLOCKS_PER_STEP_V]; diff --git a/csrc/fmha_v2/fmha/warpspec/dma.h b/csrc/fmha_v2/fmha/warpspec/dma.h index a14ccafdf3..6934087270 100644 --- a/csrc/fmha_v2/fmha/warpspec/dma.h +++ b/csrc/fmha_v2/fmha/warpspec/dma.h @@ -137,8 +137,11 @@ struct DMA { } // Early stop when causal mask is enabled. + // q_step_end is an *inclusive* upper bound, so the tile that contains it + // is q_step_end / STEP_KV. We need kv_idx_end (exclusive) to be one + // past that, i.e. q_step_end / STEP_KV + 1. 
if (SKIP_CAUSAL_MASK_TILES) { - kv_idx_end = (q_step_end + STEP_KV - 1) / STEP_KV; + kv_idx_end = q_step_end / STEP_KV + 1; } return std::make_pair(kv_idx_start, kv_idx_end); diff --git a/csrc/fmha_v2/fmha/warpspec/epilogue.h b/csrc/fmha_v2/fmha/warpspec/epilogue.h index 15f8636207..40248a51cc 100644 --- a/csrc/fmha_v2/fmha/warpspec/epilogue.h +++ b/csrc/fmha_v2/fmha/warpspec/epilogue.h @@ -16,6 +16,8 @@ #include #include +#include "fmha/hopper/arrive_wait.h" + namespace fmha { namespace ws { @@ -71,6 +73,9 @@ struct Softmax_base { // Whether we need to check if local_max could be -inf or not. enum { CHECK_IF_NEG_INF_EXISTS = SLIDING_OR_CHUNKED_ATTENTION || USE_CUSTOM_MASK }; + // There are 2 warpgroups so 0x3 and 0x4 are used + enum { SKIP_SOFTMAX_BARRIER = Kernel_traits::SKIP_SOFTMAX_BARRIER_ID }; + // Ctor. template inline __device__ Softmax_base(Params params, int tidx) @@ -80,7 +85,12 @@ struct Softmax_base { sliding_window_size_(params.sliding_window_size), log2_chunked_attention_size_(params.log2_chunked_attention_size), packed_mask_ptr_{reinterpret_cast(params.packed_mask_ptr)}, - params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes} { + params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes}, +#ifdef SKIP_SOFTMAX_STAT + total_blocks(0), + skipped_blocks(0), +#endif + skip_softmax_threshold(0) { int warp = tidx / 32; int lane = tidx % 32; // The corresponding row/col for each thread after MMA. @@ -253,25 +263,67 @@ struct Softmax_base { } // Calculate max/sum, and update flash-attention scales. + // Returns false if skipped due to skip-softmax attention feature. 
template - inline __device__ void compute_and_update_scale(float (&global_max)[Mma_tile_p::CORES_M], - float (&global_sum)[Mma_tile_p::CORES_M]) { + inline __device__ bool compute_and_update_scale(float (&global_max)[Mma_tile_p::CORES_M], + float (&global_sum)[Mma_tile_p::CORES_M], + uint32_t* skip_softmax_vote) { float const scale = reinterpret_cast(scale_bmm1_); + // whether this warpgroup skips the softmax + constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL; + bool skip = may_skip; + // Row-wise max of current tile. #pragma unroll for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++) { - if (IS_FIRST_COL) { - local_max_[mi] = elt_[mi][0]; - } else { - local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]); - } + local_max_[mi] = elt_[mi][0]; #pragma unroll for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++) { local_max_[mi] = fmaxf(local_max_[mi], elt_[mi][ni]); } local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]); local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]); + + if constexpr (may_skip) { + // AND(&) the CORES_M results, then `skip` means whether to skip + // the CORES_M(=2) rows + if constexpr (!EXP2F_OPTIMIZATION) { + skip &= expf(local_max_[mi] - global_max[mi]) < skip_softmax_threshold; + } else { + skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < skip_softmax_threshold; + } + } + + if (!IS_FIRST_COL) { + local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]); + } + } + + if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) { +#ifdef SKIP_SOFTMAX_STAT + total_blocks++; +#endif + if constexpr (may_skip) { + // AND(&) the results together in a warp, then `skip` means whether to skip + // all the 16 rows managed by this warp. + // each 4 threads (e.g. T0~T3) have the same `skip`, only 0x11111111 is needed + // instead of 0xffffffff. But the perf is the same. 
+ skip = __all_sync(0xffffffff, skip); + if (threadIdx.x % 32 == 0) { + // The leader of each warp votes. + atomicAnd(skip_softmax_vote, uint32_t(skip)); + } + // WG0 uses 0x3 barrier, WG1 uses 0x4 barrier + named_barrier_wait(SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128); + skip = *((uint32_t volatile*)skip_softmax_vote); + if (skip) { +#ifdef SKIP_SOFTMAX_STAT + skipped_blocks++; +#endif + return false; + } + } } // Softmax Exp. @@ -339,6 +391,7 @@ struct Softmax_base { global_max[mi] = max_new; } } + return true; } // Update flash attention scales and pack elements for BMM2. @@ -407,6 +460,13 @@ struct Softmax_base { float correction_[Mma_tile_p::CORES_M]; // The packed mask. uint4 packed_mask_; + // Skip softmax when exp(local_max - global_max) < skip_softmax_threshold. + float skip_softmax_threshold; +#ifdef SKIP_SOFTMAX_STAT + // Statistics of skip-softmax + uint32_t total_blocks; + uint32_t skipped_blocks; +#endif }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -676,29 +736,72 @@ struct Softmax inline __device__ Softmax(Params const& params, int tidx) : Base(params, tidx) {} // Calculate max/sum, and update flash-attention scales. + // Returns false if skipped due to skip-softmax attention feature. 
template - inline __device__ void compute_and_update_scale(float (&global_max)[Mma_tile_p::CORES_M], - float (&global_sum)[Mma_tile_p::CORES_M]) { + inline __device__ bool compute_and_update_scale(float (&global_max)[Mma_tile_p::CORES_M], + float (&global_sum)[Mma_tile_p::CORES_M], + uint32_t* skip_softmax_vote) { float const scale = reinterpret_cast(this->scale_bmm1_); float(&local_max_)[Mma_tile_p::CORES_M] = this->local_max_; float(&local_sum_)[Mma_tile_p::CORES_M] = this->local_sum_; float(&correction_)[Mma_tile_p::CORES_M] = this->correction_; float(&elt_)[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2] = this->elt_; + // whether this warpgroup skips the softmax + constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL; + bool skip = may_skip; + // Row-wise max of current tile. #pragma unroll for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++) { - if (IS_FIRST_COL) { - local_max_[mi] = elt_[mi][0]; - } else { - local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]); - } + local_max_[mi] = elt_[mi][0]; #pragma unroll for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++) { local_max_[mi] = fmaxf(local_max_[mi], elt_[mi][ni]); } local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]); local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]); + // AND(&) the CORES_M results, then `skip` means whether to skip + // the CORES_M(=2) rows + if constexpr (may_skip) { + // AND(&) the CORES_M results, then `skip` means whether to skip + // the CORES_M(=2) rows + if constexpr (!EXP2F_OPTIMIZATION) { + skip &= expf(local_max_[mi] - global_max[mi]) < this->skip_softmax_threshold; + } else { + skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < this->skip_softmax_threshold; + } + } + if (!IS_FIRST_COL) { + local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]); + } + } + + if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) { +#ifdef SKIP_SOFTMAX_STAT + this->total_blocks++; +#endif + + if 
constexpr (may_skip) { + // AND(&) the results together in a warp, then `skip` means whether to skip + // all the 16 rows managed by this warp. + // each 4 threads (e.g. T0~T3) have the same `skip`, only 0x11111111 is needed + // instead of 0xffffffff. But the perf is the same. + skip = __all_sync(0xffffffff, skip); + if (threadIdx.x % 32 == 0) { + // The leader of each warp votes. + atomicAnd(skip_softmax_vote, uint32_t(skip)); + } + // WG0 uses 0x3 barrier, WG1 uses 0x4 barrier + named_barrier_wait(Base::SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128); + skip = *((uint32_t volatile*)skip_softmax_vote); + if (skip) { +#ifdef SKIP_SOFTMAX_STAT + this->skipped_blocks++; +#endif + return false; + } + } } // Softmax Exp. @@ -774,6 +877,7 @@ struct Softmax global_max[mi] = max_new; } } + return true; } // Update flash attention scales and pack elements for BMM2. diff --git a/csrc/fmha_v2/fmha/warpspec/kernel_traits.h b/csrc/fmha_v2/fmha/warpspec/kernel_traits.h index 6d96968c61..f1a8ce35bc 100644 --- a/csrc/fmha_v2/fmha/warpspec/kernel_traits.h +++ b/csrc/fmha_v2/fmha/warpspec/kernel_traits.h @@ -65,6 +65,8 @@ template < bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false, // Save softmax stats ? bool RETURN_SOFTMAX_STATS_ = false, + // Enable skip softmax attention feature + bool ENABLE_SKIP_SOFTMAX_ = false, // The output type (only used by fp8 kernels). typename OutputType = typename Instruction_traits::A_type, // The sage attention block size for Q, K and V @@ -189,6 +191,9 @@ struct Kernel_traits { // Use the custom mask input ( attention_mask_type == 3.) enum { USE_CUSTOM_MASK = ATTENTION_MASK_TYPE_ == 3 }; + // Are we enabling skip softmax attention feature? 
+ enum { ENABLE_SKIP_SOFTMAX = ENABLE_SKIP_SOFTMAX_ }; + static_assert(!USE_CUSTOM_MASK || STEP_KV == 64 || STEP_KV == 128 || STEP_KV == 256, "Not implemented!"); @@ -250,6 +255,8 @@ struct Kernel_traits { // Named barrier ids static constexpr int DMA_SYNC_BARRIER_ID = 0x1; static constexpr int MMA_SYNC_BARRIER_ID = 0x2; + // There are 2 warpgroups so 0x3 and 0x4 are used for skip-softmax + static constexpr int SKIP_SOFTMAX_BARRIER_ID = 0x3; // How many threads get involved in the dma group. enum { NUM_THREADS_IN_DMA_GROUP = DMA_GROUP_TRANSPOSE_V ? 128 : (PAGED_KV_INPUT ? 1 : 32) }; @@ -383,6 +390,11 @@ struct Kernel_traits { // Mutex OrderedMutex compute_mutex; + // 4 warps in a warpgroup vote to an atomic variable in shared memory + // to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive + // KV_STEPS. + uint32_t skip_softmax_votes[2][NUM_COMPUTE_GROUPS]; + inline __device__ void init(int tid0) { #pragma unroll for (int i = 0; i < NUM_COMPUTE_GROUPS; i++) { @@ -439,24 +451,27 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2). bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false, // Save softmax stats ? bool RETURN_SOFTMAX_STATS_ = false, + // Enable skip softmax attention feature + bool ENABLE_SKIP_SOFTMAX_ = false, // The output type (only used by fp8 kernels). 
typename OutputType = e4m3_t, // The sage attention block size for Q, K and V int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0> struct Kernel_traits_Hopper_qgmma_e4m3_fp32 - : public Kernel_traits { + : public Kernel_traits< + Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_, + NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, + APPLY_ALIBI_, ENABLE_MUTEX_, SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, + ENABLE_BMM1_SOFTCAPPING_SCALE_, RETURN_SOFTMAX_STATS_, ENABLE_SKIP_SOFTMAX_, OutputType, + SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_> { // Base class. - using Base = Kernel_traits; + using Base = + Kernel_traits; enum { USE_TMA_STORE = USE_TMA_STORE_ }; @@ -549,6 +564,11 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32 // Mutex OrderedMutex compute_mutex; + // 4 warps in a warpgroup vote to an atomic variable in shared memory + // to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive + // STEP_KVs. + uint32_t skip_softmax_votes[2][Base::NUM_COMPUTE_GROUPS]; + inline __device__ void init(int tid0) { #pragma unroll for (int i = 0; i < Base::NUM_COMPUTE_GROUPS; i++) { diff --git a/csrc/fmha_v2/fused_multihead_attention.cpp b/csrc/fmha_v2/fused_multihead_attention.cpp deleted file mode 100644 index c6ebbdcfa1..0000000000 --- a/csrc/fmha_v2/fused_multihead_attention.cpp +++ /dev/null @@ -1,1982 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement - * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. 
Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. - */ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -using Launch_params = bert::Fused_multihead_attention_launch_params; -using Attention_mask_type = fmha::Attention_mask_type; -using Attention_input_layout = fmha::Attention_input_layout; -using Kv_block_array = fmha::Kv_block_array; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_q_seqlens_d, int s_inner, int s_outer, int b, - int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_q_seqlens_d, int s_inner, int s_outer, int b, - int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_q_seqlens_d, int s_inner, int s_outer, int b, - int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_q_seqlens_d, int s_inner, int s_outer, int b, - int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_q_seqlens_d, int s_inner, int s_outer, int b, - int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, - int warps_n, bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, - float scale); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_fp32_to_fp16(void* dst, void const* src, int s, int b, int h, int d); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_fp32_to_bf16(void* dst, void const* src, int s, int b, int h, int d); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d, - float scale_o); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_sage_quant(unsigned int batch_size, unsigned int head_num, unsigned int head_size, - unsigned int max_seq_len, - // device var - void const* q, void const* k, void const* v, int stride_q, int stride_k, - int stride_v, int const* cu_seqlens_q, int const* cu_seqlens_kv, - int block_size_q, int block_size_k, int block_size_v, - // output - void* quant_q, void* quant_k, void* quant_v, float* scales_q, float* scales_k, - float* scales_v); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type, - float const scale_bmm1, float 
const scale_softmax, float const scale_bmm2, - float const softcapping_scale_bmm1, void* qkv_d, void* vt_d, void* mask_d, - void* attention_sinks_d, void* p_d, void* s_d, void* tmp_d, void* o_d, - void* softmax_sum_d, void* cu_q_seqlens_d, const size_t b, const size_t s, - const size_t h, const size_t d, const size_t dv, int const runs, - int const warps_m, int const warps_n, bool const has_alibi) { - cudaStream_t stream = 0; - // The stride between rows of the QKV matrix. - size_t qkv_stride = get_size_in_bytes(d, data_type); - - // 1st GEMMd. - uint32_t alpha, beta = 0u; - - for (int ii = 0; ii < runs; ++ii) { - // If we run the INT8 kernel, defer the scaling of P to softmax. - set_alpha(alpha, data_type == DATA_TYPE_INT8 ? 1.f : scale_bmm1, acc_type); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - // P = Q x K' - bmm1(static_cast(qkv_d) + 0 * qkv_stride, static_cast(qkv_d) + 1 * qkv_stride, - p_d, &alpha, &beta, stream); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - // Softmax. - printf("Running softmax\n"); - if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16) { - run_softmax_fp16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, - h, softcapping_scale_bmm1, warps_n, has_alibi); - } else if (data_type == DATA_TYPE_BF16 && acc_type == DATA_TYPE_FP32) { - run_softmax_bf16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, - h, softcapping_scale_bmm1, warps_n, has_alibi); - } else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32) { - run_softmax_fp32(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, - h, softcapping_scale_bmm1, warps_n, has_alibi); - } else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32) { - run_softmax_e4m3(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, - h, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); - } else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32) { - 
run_softmax_int8(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, - h, scale_bmm1, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); - } else { - assert(false && "Reference Softmax: Unsupported type config"); - } - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - // 2nd GEMM. - set_alpha(alpha, 1.f, acc_type); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - void* out_d = o_d; - - // We may have to do a final conversion. - if (data_type != acc_type) { - out_d = tmp_d; - } - // O = S x V - bmm2(static_cast(s_d), - static_cast(vt_d), // static_cast(qkv_d) + 2 * qkv_stride, - out_d, &alpha, &beta, stream); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - // Conversion to output type. - if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16) { - // Noop. - } else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32) { - run_conversion_fp32_to_fp16(o_d, out_d, s, b, h, dv); - } else if (data_type == DATA_TYPE_BF16 && acc_type == DATA_TYPE_FP32) { - run_conversion_fp32_to_bf16(o_d, out_d, s, b, h, dv); - } else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32) { - run_conversion_fp32_to_e4m3(o_d, out_d, s, b, h, dv, scale_bmm2); - } else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32) { - // quantize output in second step - run_conversion_int32_to_int8(o_d, out_d, s, b, h, dv, scale_bmm2); - } - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline void set_params(bert::Fused_multihead_attention_params_v1& params, - // types - Data_type data_type, Data_type acc_type, - // sizes - const size_t b, const size_t s, const size_t h, const size_t d, - const size_t packed_mask_stride, - // device pointers - void* qkv_d, void* packed_mask_d, void* o_d, void* p_d, void* s_d, - // scale factors - float const scale_bmm1, float const scale_softmax, - float const scale_bmm2, - // flags - bool const 
has_alibi) { - memset(¶ms, 0, sizeof(params)); - - // Set the pointers. - params.qkv_ptr = qkv_d; - params.qkv_stride_in_bytes = get_size_in_bytes(b * h * 3 * d, data_type); - // params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type); - params.packed_mask_ptr = packed_mask_d; - // params.packed_mask_stride_in_bytes = mmas_m * threads_per_cta * sizeof(uint32_t); - params.packed_mask_stride_in_bytes = packed_mask_stride * sizeof(uint32_t); - params.o_ptr = o_d; - params.o_stride_in_bytes = get_size_in_bytes(b * h * d, data_type); - params.has_alibi = has_alibi; - params.alibi_params = fmha::AlibiParams(h); - -#if defined(STORE_P) - params.p_ptr = p_d; - params.p_stride_in_bytes = get_size_in_bytes(b * h * s, acc_type); -#endif // defined(STORE_P) - -#if defined(STORE_S) - params.s_ptr = s_d; - params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type); -#endif // defined(STORE_S) - - // Set the dimensions. - params.b = b; - params.h = h; - params.s = s; - params.d = d; - - // Set the different scale values. - Data_type scale_type1 = - (data_type == DATA_TYPE_FP16) || (data_type == DATA_TYPE_BF16) ? acc_type : DATA_TYPE_FP32; - Data_type scale_type2 = - (data_type == DATA_TYPE_FP16) || (data_type == DATA_TYPE_BF16) ? data_type : DATA_TYPE_FP32; - - set_alpha(params.scale_bmm1, scale_bmm1, scale_type1); - set_alpha(params.scale_softmax, scale_softmax, scale_type1); - set_alpha(params.scale_bmm2, scale_bmm2, scale_type2); - - // Do we enable the trick to replace I2F with FP math in the 2nd GEMM? 
- if (data_type == DATA_TYPE_INT8) { - params.enable_i2f_trick = -double(1 << 22) * double(scale_bmm2) <= -128.f && - double(1 << 22) * double(scale_bmm2) >= 127.f; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline void set_params(bert::Fused_multihead_attention_params_v2& params, - const Launch_params launch_params, - // types - Data_type data_type, Data_type acc_type, Data_type output_dtype, - // attention input layout - Attention_input_layout input_layout, - // sizes - const size_t b, const size_t s_q, const size_t s_kv, const size_t h, - const size_t h_kv, const size_t d, const size_t dv, - const size_t total, const size_t num_grouped_heads, - const size_t sliding_window_size, const size_t chunked_attention_size, - // paged kv cache block size. - const size_t tokens_per_block, - // device pointers - void* qkv_packed_d, - // contiguous q. - void* q_d, - // separate k. - void* k_d, - // separate v. - void* v_d, - // contiguous kv. - void* kv_d, - // start address of the paged kv pool. - void* paged_kv_pool_ptr, - // offsets for different blocks in terms of the start address. - int32_t* paged_block_offsets, - // mask input. - void* packed_mask_d, void* cu_mask_rows_d, - // attention sinks. 
- void* attention_sinks_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, - void* o_packed_d, void* p_d, void* s_d, void* softmax_stats_d, - void* scale_bmm2_d, - // scale factors - float const scale_bmm1, float const scale_softmax, - float const scale_bmm2, float const softcapping_scale_bmm1, - // flags - bool const use_int8_scale_max, bool const interleaved, - bool const is_s_padded, bool const has_alibi) { - memset(¶ms, 0, sizeof(params)); - - params.o_ptr = o_packed_d; - params.o_stride_in_bytes = get_size_in_bytes(h * dv, output_dtype); - - if (interleaved) { - params.q_stride_in_bytes = total; - params.o_stride_in_bytes = total; - } - - if (input_layout == Attention_input_layout::PACKED_QKV) { - // For grouped- or multi-query attention (h denotes num_q_heads; h' denotes h_kv): - // qkv_layout = [b, s, [q_hd, k_h'd, v_h'd]] - // qkv_stride = (h+2*h')d * bytes_per_elt - // Otherwise: - // qkv_layout = [b, s, 3, h, d] or [b, s, h, 3, d] - // qkv_stride = 3hd * bytes_per_elt - params.qkv_ptr = qkv_packed_d; - params.q_stride_in_bytes = params.k_stride_in_bytes = params.v_stride_in_bytes = - get_size_in_bytes(h * d + h_kv * d + h_kv * dv, data_type); - } else { - // Layout [B, S, H, D]. - params.q_ptr = q_d; - params.q_stride_in_bytes = get_size_in_bytes(h * d, data_type); - - if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV) { - // Layout [B, S, 2, H, D]. 
- params.kv_ptr = kv_d; - params.k_stride_in_bytes = params.v_stride_in_bytes = - get_size_in_bytes(h_kv * (d + dv), data_type); - } else if (input_layout == Attention_input_layout::Q_PAGED_KV) { - int max_blocks_per_sequence = (s_kv + tokens_per_block - 1) / tokens_per_block; - params.paged_kv_cache = - Kv_block_array(b, max_blocks_per_sequence, tokens_per_block, - get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type), - paged_kv_pool_ptr); - params.paged_kv_cache.mBlockOffsets = paged_block_offsets; - params.k_stride_in_bytes = get_size_in_bytes(tokens_per_block * d, data_type); - params.v_stride_in_bytes = get_size_in_bytes(tokens_per_block * dv, data_type); - } else if (input_layout == Attention_input_layout::SEPARATE_Q_K_V) { - // Layout [B, S, H_kv, D]. - params.k_ptr = k_d; - // Layout [B, S, H_kv, Dv]. - params.v_ptr = v_d; - params.k_stride_in_bytes = get_size_in_bytes(h_kv * d, data_type); - params.v_stride_in_bytes = get_size_in_bytes(h_kv * dv, data_type); - } - } - - // Packed mask. - params.packed_mask_ptr = packed_mask_d; - // The N dimension has to be aligned. - params.packed_mask_stride_in_bytes = - (align_to(int64_t(s_kv), int64_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT))) / 8; - - // Attention sinks. - params.attention_sinks = reinterpret_cast(attention_sinks_d); - -#if defined(STORE_P) - params.p_ptr = p_d; - params.p_stride_in_bytes = get_size_in_bytes(b * h * s_kv, acc_type); -#endif // defined(STORE_P) - -#if defined(STORE_S) - params.s_ptr = s_d; - params.s_stride_in_bytes = get_size_in_bytes(b * h * s_kv, data_type); -#endif // defined(STORE_S) - - params.softmax_stats_ptr = softmax_stats_d; - params.softmax_stats_stride_in_bytes = get_size_in_bytes(h * 2, DATA_TYPE_FP32); - - // Set the dimensions. 
- params.b = b; - params.h = h; - params.s = s_q; - params.d = d; - params.dv = dv; - params.num_grouped_heads = num_grouped_heads; - params.sliding_window_size = sliding_window_size; - assert((chunked_attention_size == 0 || - (chunked_attention_size & (chunked_attention_size - 1)) == 0) && - "chunked_attention_size has to be a power of 2"); - params.log2_chunked_attention_size = - chunked_attention_size > 0 ? std::log2(chunked_attention_size) : 0; - - // cumulative q or kv sequence lengths. - params.cu_q_seqlens = static_cast(cu_q_seqlens_d); - params.cu_kv_seqlens = static_cast(cu_kv_seqlens_d); - // cumulative mask sequence lengths. - params.cu_mask_rows = static_cast(cu_mask_rows_d); - - // Set the different scale values. - Data_type scale_type1 = - (data_type == DATA_TYPE_FP16) || (data_type == DATA_TYPE_BF16) ? acc_type : DATA_TYPE_FP32; - Data_type scale_softmax_type = scale_type1; - Data_type scale_type2 = - (data_type == DATA_TYPE_FP16) || (data_type == DATA_TYPE_BF16) ? data_type : DATA_TYPE_FP32; - if (data_type == DATA_TYPE_E4M3) { - scale_type1 = acc_type; - scale_type2 = acc_type; - } - - // Fuse 1.0f / softcapping_scale into scale_bmm1. - bool const enable_attn_logit_softcapping = softcapping_scale_bmm1 != 0.f; - float fused_scale_bmm1 = - enable_attn_logit_softcapping ? scale_bmm1 / softcapping_scale_bmm1 : scale_bmm1; - - // use specialized hopper kernels without alibi support. - // alibi or softcapping_scale cannot utilize the exp2f with fused_scale optimization. 
- if (launch_params.warp_specialization && !has_alibi && !enable_attn_logit_softcapping) { - set_alpha(params.scale_bmm1, fused_scale_bmm1 * float(M_LOG2E), DATA_TYPE_FP32); - } else { - set_alpha(params.scale_bmm1, fused_scale_bmm1, scale_type1); - } - set_alpha(params.scale_softmax, scale_softmax, scale_softmax_type); - set_alpha(params.scale_bmm2, scale_bmm2, scale_type2); - params.scale_bmm2_d = reinterpret_cast(scale_bmm2_d); - params.softcapping_scale_bmm1 = softcapping_scale_bmm1; - - FMHA_CHECK_CUDA(cudaMemcpy(params.scale_bmm2_d, ¶ms.scale_bmm2, sizeof(uint32_t), - cudaMemcpyHostToDevice)); - - // attention type, h_kv < h if MQA or GQA - params.h_kv = h_kv; - assert(h % h_kv == 0 && "MQA/GQA needs h to be divisible by h_kv!"); - params.h_q_per_kv = h / h_kv; - params.has_alibi = has_alibi; - params.alibi_params = fmha::AlibiParams(h); - - // Set flags - params.is_s_padded = is_s_padded; - params.use_int8_scale_max = use_int8_scale_max; - - // Do we enable the trick to replace I2F with FP math in the 2nd GEMM? 
- if (data_type == DATA_TYPE_INT8) { - params.enable_i2f_trick = -double(1 << 22) * double(scale_bmm2) <= -128.f && - double(1 << 22) * double(scale_bmm2) >= 127.f; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline void determine_launch_params( - Launch_params& launch_params, Data_type data_type, int sm, const size_t s, const size_t d, - const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout, - bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma, - bool const force_non_flash_attention, bool const force_non_warp_specialization, - bool const force_non_granular_tiling, bool const force_fp32_acc, - // device props - const cudaDeviceProp props) { - // Set launch params to choose kernels - launch_params.ignore_b1opt = ignore_b1opt; - launch_params.force_unroll = force_unroll; - launch_params.force_fp32_acc = force_fp32_acc; - launch_params.interleaved = interleaved; - launch_params.attention_mask_type = attention_mask_type; - launch_params.attention_input_layout = input_layout; - - // Set SM count and L2 cache size (used to determine launch blocks/grids to maximum performance) - launch_params.multi_processor_count = props.multiProcessorCount; - launch_params.device_l2_cache_size = props.l2CacheSize; - - // threshold for adopting flash attention or warp_specialized kernels. 
- launch_params.flash_attention = - (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) && - (s >= 16 && d >= 16) && !force_non_flash_attention; - - // enable warp_speialized kernels when s >= 512 on hopper - // note that warp_speialized kernels need flash attention + tma - launch_params.warp_specialization = - (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) && - sm == 90 && launch_params.flash_attention && !force_non_warp_specialization; - // warp specialization kernels on hopper need tma - launch_params.use_tma = use_tma || launch_params.warp_specialization; - - // use granular tiling on Ampere-style flash attention - launch_params.use_granular_tiling = !force_non_granular_tiling && launch_params.flash_attention && - !launch_params.warp_specialization && sm >= 80; - - if (launch_params.use_granular_tiling && (data_type == DATA_TYPE_E4M3 && sm == 80)) { - printf( - "Fallback to non-granular-tiling kernels as tiled e4m3 kernels" - "are not supported on Ada currently.\n"); - launch_params.use_granular_tiling = false; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -int main(int argc, char** argv) { - // The device. Reset on destruction - CudaDevice device; - int sm = device.sm; - cudaDeviceProp props = device.props; - - GpuTimer timer; - - // The batch size. - size_t b = 128; - // The number of heads. - size_t h = 16; - // The dimension of the Q, K and V vectors. - size_t d = 64; - // The dimension of V if set to non-zero, otherwise dimension of V equals to that of Q - size_t dv = 0; - // The length of the sequence. - size_t s = 384; - // Number of grouped heads in the seqlen dimension. - size_t num_grouped_heads = 1; - // Sliding Window Attention - // Only pay attention to [max(0, query_idx - sliding_window_size), query_idx]. - size_t sliding_window_size = size_t(INT_MAX); - // The chunked-attention size. 
- size_t chunked_attention_size = 0; - - // The data type of the kernel. - Data_type data_type = DATA_TYPE_FP16; - // The type of the intermediate P matrix. - Data_type acc_type = DATA_TYPE_FP16; - // The type of the output. - Data_type output_dtype = DATA_TYPE_FP16; - // Is the output type set ? - bool is_output_dtype_set = false; - - // The scaling factors. - float scale_bmm1 = 0.f, scale_softmax = 0.f, scale_bmm2 = 0.25f; - // The number of runs. - int runs = 1, warm_up_runs = 0; - // Do we use 1s for Q, K, V. - bool use_1s_q = false, use_1s_k = false, use_1s_v = false; - // The range of the different inputs. - int range_q = 5, range_k = 3, range_v = 5; - // The scale. - float scale_q = 0.f, scale_k = 0.f, scale_v = 0.f; - // The threshold for dropout. By default, drop 10%. - float dropout = 0.1f; - // Do we skip the checks. - bool skip_checks = false; - // The tolerance when checking results. - float epsilon = -1.f; // data_type == DATA_TYPE_FP16 ? 0.015f : 0.f; - // Use causal mask / padding_mask / sliding_or_chunked_causal mask / custom_mask input. - Attention_mask_type attention_mask_type = Attention_mask_type::PADDING; - // Use padded format for input QKV tensor & output O tensor. - // Instead of variable lengths [total, h, 3, d] where total = b1*s1 + b2*s2 + ... 
bn*sn, - // use padded length [b, max_s, h, 3, d] where max_s is the maximum expected seq len - bool is_s_padded = false; - - // minimum sequence length for sampling variable seqlens - uint32_t min_s = -1; - - // run interleaved kernels and transpose input and output accordingly - bool interleaved = false; - bool ignore_b1opt = false; - bool force_unroll = false; - // used by kernels that have different acc data types (like hmma, qmma) - bool force_fp32_acc = false; - bool force_non_flash_attention = false; - // enable warp specialization kernels on sm 90 - bool force_non_warp_specialization = (sm != 90); - bool use_int8_scale_max = false; - bool verbose = true; - bool save_softmax = false; - - // use granular tiling - // supported only by Ampere-based Flash Attention at this moment - bool force_non_granular_tiling = false; - - // set all sequence lengths to min(s, min_s) - bool fix_s = false; - - bool v1 = false; - - // use TMA or not. ignored if not in SM90 - bool use_tma = false; - - // use alibi. - bool has_alibi = false; - - // Use softcapping_scale_bmm1 (scale * __tanhf(x / scale)). - float softcapping_scale_bmm1 = 0.f; - - // In multi-query or grouped-query attention (MQA/GQA), several Q heads are associated with one KV - // head - bool multi_query_attention = false; - size_t h_kv = 0; - - // The attention input layout. - Attention_input_layout input_layout = Attention_input_layout::PACKED_QKV; - - // TRTLLM uses 64 by default in paged kv cache. - size_t tokens_per_block = 64; - - // Attention that has different q and kv lengths. - size_t s_q = 0; - // different q and kv sequence lengths. - bool different_q_kv_lengths = false; - - // SageAttention block sizes - int sage_block_size_q = 0, sage_block_size_k = 0, sage_block_size_v = 0; - - // Use attention sinks (added to the denominator of softmax) - bool use_attention_sinks = false; - - // Read the parameters from the command-line. 
- for (int ii = 1; ii < argc; ++ii) { - if (!strcmp(argv[ii], "-1s")) { - use_1s_k = use_1s_q = use_1s_v = true; - } else if (!strcmp(argv[ii], "-1s-k")) { - use_1s_k = true; - } else if (!strcmp(argv[ii], "-1s-q")) { - use_1s_q = true; - } else if (!strcmp(argv[ii], "-1s-v")) { - use_1s_v = true; - } else if (!strcmp(argv[ii], "-b") && ++ii < argc) { - b = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-d") && ++ii < argc) { - d = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-dv") && ++ii < argc) { - dv = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-s-q") && ++ii < argc) { - s_q = strtol(argv[ii], nullptr, 10); - different_q_kv_lengths = true; - } else if (!strcmp(argv[ii], "-dropout") && ++ii < argc) { - dropout = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-epsilon") && ++ii < argc) { - epsilon = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-h") && ++ii < argc) { - h = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-int8")) { - data_type = DATA_TYPE_INT8; - acc_type = DATA_TYPE_INT32; - } else if (!strcmp(argv[ii], "-fp16")) { - data_type = DATA_TYPE_FP16; - acc_type = DATA_TYPE_FP16; - } else if (!strcmp(argv[ii], "-fp16-fp32")) { - data_type = DATA_TYPE_FP16; - acc_type = DATA_TYPE_FP32; - force_fp32_acc = true; - } else if (!strcmp(argv[ii], "-bf16")) { - data_type = DATA_TYPE_BF16; - acc_type = DATA_TYPE_FP32; - force_fp32_acc = true; - } else if (!strcmp(argv[ii], "-e4m3")) { - data_type = DATA_TYPE_E4M3; - // Technically not the acc type. - acc_type = DATA_TYPE_FP32; - force_fp32_acc = true; - } else if (!strcmp(argv[ii], "-e4m3-fp16")) { // Ada QMMA only - data_type = DATA_TYPE_E4M3; - // Technically not the acc type. - acc_type = DATA_TYPE_FP16; - } else if (!strcmp(argv[ii], "-e4m3-fp32")) { - data_type = DATA_TYPE_E4M3; - // Technically not the acc type. 
- acc_type = DATA_TYPE_FP32; - force_fp32_acc = true; - } else if (!strcmp(argv[ii], "-fp16-output")) { - output_dtype = DATA_TYPE_FP16; - is_output_dtype_set = true; - } else if (!strcmp(argv[ii], "-bf16-output")) { - output_dtype = DATA_TYPE_BF16; - is_output_dtype_set = true; - } else if (!strcmp(argv[ii], "-num-grouped-heads") && ++ii < argc) { - num_grouped_heads = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-range-k") && ++ii < argc) { - range_k = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-range-q") && ++ii < argc) { - range_q = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-range-v") && ++ii < argc) { - range_v = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-runs") && ++ii < argc) { - runs = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-s") && ++ii < argc) { - s = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-sliding-window-size") && ++ii < argc) { - sliding_window_size = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-chunked-attention-size") && ++ii < argc) { - chunked_attention_size = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-scale-bmm1") && ++ii < argc) { - scale_bmm1 = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-scale-bmm2") && ++ii < argc) { - scale_bmm2 = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-scale-k") && ++ii < argc) { - scale_k = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-scale-softmax") && ++ii < argc) { - scale_softmax = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-scale-q") && ++ii < argc) { - scale_q = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-scale-v") && ++ii < argc) { - scale_v = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-skip-checks")) { - skip_checks = true; - } else if (!strcmp(argv[ii], "-warm-up-runs") && ++ii < argc) { - warm_up_runs = 
strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-min-s") && ++ii < argc) { - min_s = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-il")) { - interleaved = true; - } else if (!strcmp(argv[ii], "-causal-mask")) { - attention_mask_type = Attention_mask_type::CAUSAL; - } else if (!strcmp(argv[ii], "-sliding-or-chunked-causal-mask")) { - attention_mask_type = Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL; - } else if (!strcmp(argv[ii], "-custom-mask")) { - attention_mask_type = Attention_mask_type::CUSTOM_MASK; - } else if (!strcmp(argv[ii], "-multi-query-attention") || !strcmp(argv[ii], "-mqa")) { - h_kv = 1; - multi_query_attention = true; // subset of GQA - } else if ((!strcmp(argv[ii], "-grouped-query-attention") || !strcmp(argv[ii], "-gqa")) && - ++ii < argc) { - h_kv = strtol(argv[ii], nullptr, 10); - multi_query_attention = true; - } else if (!strcmp(argv[ii], "-contiguous-q-kv")) { - input_layout = Attention_input_layout::CONTIGUOUS_Q_KV; - } else if (!strcmp(argv[ii], "-paged-kv")) { - input_layout = Attention_input_layout::Q_PAGED_KV; - } else if (!strcmp(argv[ii], "-separate-q-k-v")) { - input_layout = Attention_input_layout::SEPARATE_Q_K_V; - } else if (!strcmp(argv[ii], "-tokens-per-block") && ++ii < argc) { - tokens_per_block = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-pad-s")) { - is_s_padded = true; - } else if (!strcmp(argv[ii], "-ignore-b1opt")) { - ignore_b1opt = true; - } else if (!strcmp(argv[ii], "-force-unroll")) { - force_unroll = true; - } else if (!strcmp(argv[ii], "-force-non-flash-attention")) { - force_non_flash_attention = true; - force_non_warp_specialization = true; - } else if (!strcmp(argv[ii], "-force-flash-attention")) { - fprintf(stderr, - "Deprecation warning: -force-flash-attention is no longer valid; use " - "-force-non-flash-attention instead, as Flash Attention is enabled by default.\n"); - } else if (!strcmp(argv[ii], "-force-non-warp-specialization")) { - 
force_non_warp_specialization = true; - } else if (!strcmp(argv[ii], "-force-non-granular-tiling") || - !strcmp(argv[ii], "-force-non-tiled")) { - force_non_granular_tiling = true; - } else if (!strcmp(argv[ii], "-fix-s")) { - fix_s = true; - } else if (!strcmp(argv[ii], "-scale-max")) { - use_int8_scale_max = true; - } else if (!strcmp(argv[ii], "-v") && ++ii < argc) { - int v = strtol(argv[ii], nullptr, 10); - verbose = v != 0; - } else if (!strcmp(argv[ii], "-v1")) { - v1 = true; - } else if (!strcmp(argv[ii], "-use-tma")) { - use_tma = true; - // flash attention + tma + non_warp_specialized kernels are not supported - // use non_flash_attention + tma + non_warp_specialized instead - if (force_non_warp_specialization) { - force_non_flash_attention = true; - } - } else if (!strcmp(argv[ii], "-alibi")) { - has_alibi = true; - } else if (!strcmp(argv[ii], "-softcapping-scale-bmm1") && ++ii < argc) { - softcapping_scale_bmm1 = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-save-softmax")) { - save_softmax = true; - } else if (!strcmp(argv[ii], "-sage-block-q") && ++ii < argc) { - sage_block_size_q = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-sage-block-k") && ++ii < argc) { - sage_block_size_k = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-sage-block-v") && ++ii < argc) { - sage_block_size_v = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-use-attention-sinks")) { - use_attention_sinks = true; - } else { - fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]); - return -1; - } - } - if (save_softmax == true) { - bool is_MLA = (d == 192 && dv == 128); - if (((!is_MLA) && input_layout != Attention_input_layout::CONTIGUOUS_Q_KV) || - (is_MLA && input_layout != Attention_input_layout::SEPARATE_Q_K_V)) { - fprintf(stderr, - "For normal attention, Only '--contiguous-q-kv' layout supports " - "'-save-softmax'. 
For MLA only '-separate-q-k-v' layout supports " - "'-save-softmax'.\n"); - exit(1); - } - } - // Sanitize - if (min_s == -1) min_s = s; - min_s = std::min(s, min_s); - h_kv = multi_query_attention ? h_kv : h; - - // Check if the options are valid. - if (different_q_kv_lengths) { - assert(input_layout != Attention_input_layout::PACKED_QKV && - "Packed QKV input layout is not supported with different q and kv lengths."); - assert(s >= s_q && "q seqlen has to be smaller than or equal to the kv seqlen !"); - } else { - s_q = s; - } - - // Sliding window attention (only pay attention to sliding-window-size long previous tokens). - if (sliding_window_size < s) { - assert(chunked_attention_size == 0 && - "chunked_attention_size should not be used when sliding_window_size is set"); - attention_mask_type = Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL; - } - // Chunked attention. - if (chunked_attention_size > 0) { - assert((chunked_attention_size & (chunked_attention_size - 1)) == 0 && - "chunked_attention_size has to be a power of 2"); - attention_mask_type = Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL; - } - - // Set the norm. - if (scale_bmm1 == 0.f) { - scale_bmm1 = 1.f / sqrtf((float)d); - } - - // Set the output type if not set by user. - if (!is_output_dtype_set) { - output_dtype = data_type; - } - - // Force the softmax scale to 1.f for the FP16 kernel. 
- if (data_type == DATA_TYPE_FP16) { - scale_softmax = 1.f; - } else if (data_type == DATA_TYPE_INT8 && scale_softmax == 0.f) { - scale_softmax = std::max(512.f, (float)s); - } else if (data_type == DATA_TYPE_E4M3 && scale_softmax == 0.f) { - scale_softmax = 1.f; // For E4M3 this is hardcoded as the largest power-of-2 below E4M3_MAX - } - - // Sage Attention uses the e4m3 data type - if (sage_block_size_q > 0 || sage_block_size_k > 0 || sage_block_size_v > 0) { - scale_softmax = 1.f; - scale_bmm2 = 1.f; - force_fp32_acc = true; - acc_type = DATA_TYPE_FP32; - } - - // Define the scaling factor for the different inputs. - if (scale_q == 0.f) { - scale_q = 1.f; - } - if (scale_k == 0.f) { - scale_k = 1.f; - } - if (scale_v == 0.f) { - // BF16 here just for debug. - scale_v = (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16) ? 0.125f : 1.f; - } - if (has_alibi && attention_mask_type == Attention_mask_type::PADDING) { - attention_mask_type = Attention_mask_type::CAUSAL; - } - - // BF16 only support FP32 acc_type. - if (data_type == DATA_TYPE_BF16 && acc_type != DATA_TYPE_FP32) { - fprintf(stderr, "Only FP32 accumulation is supported for BF16 I/O\n"); - exit(1); - } - - // Set the tolerance if not already set by the user. - if (epsilon < 0.f) { - switch (data_type) { - case DATA_TYPE_FP16: - epsilon = 0.015f; - break; - case DATA_TYPE_BF16: - epsilon = 0.025f; - break; - case DATA_TYPE_E4M3: - epsilon = 0.15f; - break; - default: - epsilon = 0.f; - } - // the accuracy of SageAttention may be between fp8 and fp16/bf16 ? - if (sage_block_size_q > 0 || sage_block_size_k > 0 || sage_block_size_v > 0) { - epsilon = 0.05f; - } - } - - // let the dimension of V equal to that of Q if not set by user - if (dv == 0) { - dv = d; - } - - // Debug info -- only in verbose mode. - if (verbose) { - // Running the following command. - printf("Command.......: %s", argv[0]); - for (int ii = 1; ii < argc; ++ii) { - printf(" %s", argv[ii]); - } - printf("\n"); - - // Device info. 
- printf("Device........: %s\n", props.name); - printf("Arch.(sm).....: %d\n", sm); - printf("#.of.SMs......: %d\n", props.multiProcessorCount); - - // Problem info. - printf("Batch ........: %lu\n", b); - printf("Heads ........: %lu\n", h); - printf("Dimension ....: %lu\n", d); - printf("Dimension of V ....: %lu\n", dv); - printf("Seq length ...: %lu\n", s); - printf("Warm-up runs .: %d\n", warm_up_runs); - printf("Runs..........: %d\n\n", runs); - - // The scaling factors for the 3 operations. - printf("Scale bmm1 ...: %.6f\n", scale_bmm1); - printf("Scale softmax.: %.6f\n", scale_softmax); - printf("Scale bmm2 ...: %.6f\n", scale_bmm2); - printf("\n"); - } - - // determine the launch params to select kernels - Launch_params launch_params; - determine_launch_params(launch_params, data_type, sm, s, d, attention_mask_type, input_layout, - interleaved, ignore_b1opt, force_unroll, use_tma, - force_non_flash_attention, force_non_warp_specialization, - force_non_granular_tiling, force_fp32_acc, props); - - // The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D. - const size_t qkv_size = s * b * h * (2 * d + dv); - // Allocate on the host. - float* qkv_h = (float*)malloc(qkv_size * sizeof(float)); - // The size in bytes. - const size_t qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type); - // Allocate on the device. - void *qkv_sbh3d_d = nullptr, *qkv_bsh3d_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&qkv_sbh3d_d, qkv_size_in_bytes)); - FMHA_CHECK_CUDA(cudaMalloc(&qkv_bsh3d_d, qkv_size_in_bytes)); - - // Contiguous KV cache buffer. - // The shape is [B, 2, S, H, D]. - const size_t kv_size = b * s * h_kv * (d + dv); - // The size in bytes. - const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); - // Allocate on the host. - void* contiguous_kv_h = malloc(kv_size_in_bytes); - // Memset the buffer. - memset(contiguous_kv_h, 0, kv_size_in_bytes); - // Allocate on the device. 
- void* contiguous_kv_d; - FMHA_CHECK_CUDA(cudaMalloc(&contiguous_kv_d, kv_size_in_bytes)); - - // Paged KV Cache buffer. - // The shape is [B, 2, Blocks_per_sequence], and each block's buffer shape is [H, - // Tokens_per_block, Dh]. - void** kv_cache_ptrs_h = nullptr; - void* kv_cache_pool_ptr = nullptr; - int32_t *kv_cache_block_offsets_h, *kv_cache_block_offsets_d = nullptr; - const size_t max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block; - const size_t num_total_blocks = b * 2 * max_blocks_per_seq; - kv_cache_ptrs_h = (void**)malloc(num_total_blocks * sizeof(void*)); - kv_cache_block_offsets_h = (int32_t*)malloc(num_total_blocks * sizeof(int32_t)); - const size_t paged_kv_block_size_in_bytes = - get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type); - FMHA_CHECK_CUDA( - cudaMalloc((void**)(&kv_cache_block_offsets_d), num_total_blocks * sizeof(int32_t))); - const size_t kv_cache_pool_sz = - get_size_in_bytes(num_total_blocks * tokens_per_block * h_kv * (d + dv) / 2, data_type); - FMHA_CHECK_CUDA(cudaMalloc((void**)(&kv_cache_pool_ptr), kv_cache_pool_sz)); - size_t ptr_index = 0; - size_t abs_offset = 0; - for (size_t bi = 0; bi < b; bi++) { - for (int kv_offset = 0; kv_offset < 2; kv_offset++) { - size_t block_size = - get_size_in_bytes(tokens_per_block * h_kv * (kv_offset == 0 ? 
d : dv), data_type); - for (size_t block_i = 0; block_i < max_blocks_per_seq; block_i++) { - kv_cache_ptrs_h[ptr_index] = - reinterpret_cast(reinterpret_cast(kv_cache_pool_ptr) + abs_offset); - assert(abs_offset % paged_kv_block_size_in_bytes == 0); - kv_cache_block_offsets_h[ptr_index] = abs_offset / paged_kv_block_size_in_bytes; - ptr_index++; - abs_offset += block_size; - } - } - } - assert(ptr_index == num_total_blocks && abs_offset == kv_cache_pool_sz); - FMHA_CHECK_CUDA(cudaMemcpy(kv_cache_block_offsets_d, kv_cache_block_offsets_h, - num_total_blocks * sizeof(int32_t), cudaMemcpyDefault)); - - // Q will always be [B, S, H, Dh] with paged kv cache. - void* q_d; - const size_t q_size = s * b * h * d; - FMHA_CHECK_CUDA(cudaMalloc(&q_d, get_size_in_bytes(q_size, data_type))); - - // K has [B, S, H_kv, D] with separate kv cache. - void* k_d; - const size_t k_size = s * b * h_kv * d; - FMHA_CHECK_CUDA(cudaMalloc(&k_d, get_size_in_bytes(k_size, data_type))); - - // V has [B, S, H_kv, Dv] with separate kv cache. - void* v_d; - const size_t v_size = s * b * h_kv * dv; - FMHA_CHECK_CUDA(cudaMalloc(&v_d, get_size_in_bytes(v_size, data_type))); - - // Scale bmm2 (per-tensor). - void* scale_bmm2_d; - FMHA_CHECK_CUDA(cudaMalloc(&scale_bmm2_d, sizeof(uint32_t))); - - // The mask for dropout or any mask patterns. - const size_t mask_size = s * b * s; - // Allocate on the host. - float* mask_h = (float*)malloc(mask_size * sizeof(float)); - // The size in bytes. - const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); - // Allocate on the device. - void* mask_d = nullptr; - if (!skip_checks) { - FMHA_CHECK_CUDA(cudaMalloc(&mask_d, mask_size_in_bytes)); - } - - // The decomposition of threads and warps for BMM1. - size_t warps_m, warps_n, warps_k; - std::tie(warps_m, warps_n, warps_k) = - get_warps(launch_params, sm, data_type, s, b, d, v1 ? 
1 : 2); - - // print launch configuration - printf( - "v1=%d il=%d s_q=%lu, s=%lu b=%lu h=%lu/%lu d=%lu/%lu dtype=%s, output_dtype=%s, " - "flash_attn=%s, " - "warp_spec=%s, mask=%s, " - "alibi=%s, attn=%s, qkv_layout=%s, wm=%lu wn=%lu\n", - v1, interleaved, s_q, s, b, h, h_kv, d, dv, data_type_to_name(data_type).c_str(), - data_type_to_name(output_dtype).c_str(), - launch_params.flash_attention ? (launch_params.use_granular_tiling ? "true_tiled" : "true") - : "false", - launch_params.warp_specialization ? "true" : "false", - mask_type_to_string(attention_mask_type).c_str(), has_alibi ? "true" : "false", - h_kv == 1 ? "mqa" : (h_kv == h ? "mha" : "gqa"), - attention_input_layout_to_string(input_layout).c_str(), warps_m, warps_n); - - // For multi-CTA cases, determine the size of the CTA wave. - int heads_per_wave, ctas_per_head; - get_grid_size(heads_per_wave, ctas_per_head, sm, data_type, b, s, h, d, - false, // disable multi-cta kernels by default - v1 ? 1 : 2); - - // The number of threads per CTA. - const size_t threads_per_cta = warps_m * warps_n * warps_k * 32; - // The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension. - size_t mmas_m = (s + 16 * warps_m - 1) / (16 * warps_m); - // The number of mmas in the N dimension. - size_t mmas_n = (s + 16 * warps_n - 1) / (16 * warps_n); - // We do not support more than 4 MMAS in the N dimension (as each MMA needs 8 bits in the mask). - assert(!v1 || mmas_n <= 4); - // The packed mask for dropout (in the fused kernel). Layout is B * MMAS_M * THREADS_PER_CTA. - size_t packed_mask_size = b * mmas_m * threads_per_cta; - // Flash attention on Ampere and Hopper, which supports multiple mmas_n - if (!v1 && !force_non_flash_attention && - attention_mask_type == Attention_mask_type::CUSTOM_MASK) { - // We need to align q and k sequence lengths. 
- size_t rounded_q_s = align_to(s, size_t(fmha::FLASH_ATTEN_MASK_M_ALIGNMENT)); - size_t rounded_k_s = align_to(s, size_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT)); - // The number of mmas in the M dimension (MMA_M = 64). - mmas_m = rounded_q_s / fmha::FLASH_ATTEN_MASK_MMA_M; - // The number of mmas in the N dimension (MMA_N = 64). - mmas_n = rounded_k_s / fmha::FLASH_ATTEN_MASK_MMA_N; - // Each thread holds 32 bit (2 rows, 16 cols -> 8 core MMAs) in one MMA here. - packed_mask_size = b * mmas_m * mmas_n * threads_per_cta; - } - // The size in bytes. - const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); - // Allocate on the host. - uint32_t* packed_mask_h = (uint32_t*)malloc(packed_mask_size_in_bytes); - // Set it to 0 (indicates that all elements are valid). - memset(packed_mask_h, 0, packed_mask_size_in_bytes); - // Allocate on the device. - void* packed_mask_d = nullptr; - - // The size of the attention sinks. - const size_t attention_sinks_size_in_bytes = h * sizeof(float); - - // The attention sinks. - void* attention_sinks_d = nullptr; - if (use_attention_sinks) { - // Allocate on the host. - float* attention_sinks_h = (float*)malloc(attention_sinks_size_in_bytes); - // Randomly initialize the attention sinks. - random_init("attention_sinks", attention_sinks_h, 1, h, 1, false, 5.f, 1.f, verbose); - // Allocate on the device. - FMHA_CHECK_CUDA(cudaMalloc(&attention_sinks_d, attention_sinks_size_in_bytes)); - // Copy from the host to the device. - FMHA_CHECK_CUDA(cudaMemcpy(attention_sinks_d, attention_sinks_h, attention_sinks_size_in_bytes, - cudaMemcpyDefault)); - } - - // The O matrix is packed as S * B * H * D. - const size_t o_size = s * b * h * dv; - // Allocate on the host. - float* o_h = (float*)malloc(o_size * sizeof(float)); - // The size in bytes. - const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type); - // Allocate on the device. 
- void* o_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes)); - - // The softmax_stats_d vector is used to store the max/sum of the softmax per token - void* softmax_stats_d; - FMHA_CHECK_CUDA(cudaMalloc(&softmax_stats_d, 2 * sizeof(float) * b * s * h)); - FMHA_CHECK_CUDA(cudaMemset(softmax_stats_d, 0x00, 2 * sizeof(float) * b * s * h)); - - // The size in bytes. - const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); - // Allocate on the device. - void* tmp_d = nullptr; - if (data_type != acc_type) { - FMHA_CHECK_CUDA(cudaMalloc(&tmp_d, tmp_size_in_bytes)); - } - - // Allocate the reference on the host. - float* o_ref_h = (float*)malloc(o_size * sizeof(float)); - float* softmax_stats_ref_h = (float*)malloc(2 * b * s * h * sizeof(float)); - float* softmax_stats_h = (float*)malloc(2 * b * s * h * sizeof(float)); - - // The P matrix is stored as one big matrix of size S x B x H x S. - const size_t p_size = s * b * h * s; - // The size in bytes. - const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type); - // Allocate on the device. - void* p_d = nullptr; - if (!skip_checks) { - FMHA_CHECK_CUDA(cudaMalloc(&p_d, p_size_in_bytes)); - } - - // Allocate the reference on the host. - float* p_ref_h = (float*)malloc(p_size * sizeof(float)); -#if defined(STORE_P) - // Allocate on the host. - float* p_h = (float*)malloc(p_size * sizeof(float)); -#endif // defined(STORE_P) - - // The size in bytes of the S matrix (the data type may be different from P for int8). - const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type); - // Allocate on the device. - void* s_d = nullptr; - if (!skip_checks) { - FMHA_CHECK_CUDA(cudaMalloc(&s_d, s_size_in_bytes)); - } - - // Allocate the reference on the host. - float* s_ref_h = (float*)malloc(p_size * sizeof(float)); - - // Allocate on the host. - float* s_h = (float*)malloc(p_size * sizeof(float)); - // Make sure we set the seed for reproducible results. 
- srand(1234UL); - - // Set the Q, K and V matrices. - random_init("Q", qkv_h + 0 * d, d, s * b * h, 2 * d + dv, use_1s_q, range_q, scale_q, verbose); - random_init("K", qkv_h + 1 * d, d, s * b * h, 2 * d + dv, use_1s_k, range_k, scale_k, verbose); - random_init("V", qkv_h + 2 * d, dv, s * b * h, 2 * d + dv, use_1s_v, range_v, scale_v, verbose); - // iota_init("Q", qkv_h + 0 * d, d, s * b * h, 3 * d, use_1s_q, range_q, scale_q, verbose, true, - // 0); iota_init("K", qkv_h + 1 * d, d, s * b * h, 3 * d, use_1s_k, range_k, scale_k, verbose, - // true, 128); iota_init("V", qkv_h + 2 * d, d, s * b * h, 3 * d, use_1s_v, range_v, scale_v, - // verbose, true, 256); - - // Multi-query or grouped-query attention for reference input - if (multi_query_attention) { - for (size_t sbi = 0; sbi < s * b; sbi++) { - for (size_t hi = 0; hi < h; hi++) { - for (size_t di = 0; di < d; di++) { - // E.g., h=8, h_kv=4 - // hi: 0, 1, 2, 3, 4, 5, 6, 7 - // hi_kv_scatter: 0, 0, 2, 2, 4, 4, 6, 6 - int const h_per_group = h / h_kv; - int const hi_kv_scatter = (hi / h_per_group) * h_per_group; - size_t src_offset = - sbi * h * 3 * d + hi_kv_scatter * 3 * d + di; // [sbi, hi_kv_scatter, 0, di] - size_t dst_offset = sbi * h * 3 * d + hi * 3 * d + di; // [sbi, hi, 0, di] - - // make sure all heads of kv in a group share the same d - qkv_h[dst_offset + 1 * d] = - qkv_h[src_offset + 1 * d]; // qkv[sbi, hi, 1, di] = qkv[sbi, hi_kv_scatter, 1, di] - qkv_h[dst_offset + 2 * d] = - qkv_h[src_offset + 2 * d]; // qkv[sbi, hi, 2, di] = qkv[sbi, hi_kv_scatter, 2, di] - } - } - } - } - - // WAR fOR MISSING CUBLAS FP8 NN SUPPORT. - // Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V. 
- float* vt_h = (float*)malloc(o_size * sizeof(float)); - void* vt_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&vt_d, o_size_in_bytes)); - for (size_t it = 0; it < o_size; it++) { - // vt is B x H x D x S - size_t si = it % s; - size_t di = (it / s) % dv; - size_t hi = ((it / s) / dv) % h; - size_t bi = (((it / s) / dv) / h) % b; - // qkv is S x B x H x 3 x D - size_t qkv_idx = si * b * h * (2 * d + dv) + bi * h * (2 * d + dv) + hi * (2 * d + dv) + - 2 * d // index V here - + di; - vt_h[it] = qkv_h[qkv_idx]; - } - FMHA_CHECK_CUDA(cuda_memcpy_h2d(vt_d, vt_h, o_size, data_type)); - - // // DEBUG. - // float sum = 0.f; - // for( size_t si = 0; si < s; ++si ) { - // float v = qkv_h[si*b*h*3*d + 2*d]; - // printf("V[%3d]=%8.3f\n", si, v); - // sum += v; - // } - // printf("Sum of V = %8.3f\n", sum); - // // END OF DEBUG. - - // Copy from the host to the device. - FMHA_CHECK_CUDA(cuda_memcpy_h2d(qkv_sbh3d_d, qkv_h, qkv_size, data_type)); - - // Create the buffer of mask. - // if(verbose) {printf("Init .........: mask\n"); } - // random_init_with_zeroes_or_ones(mask_h, b*s, false, 1.f - dropout, verbose); - - std::vector seqlens(b, 0); // randomly draw a batch of sequence lengths >= min_s - std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(), [=](const uint32_t) { - if (fix_s) { - return std::min(uint32_t(s), min_s); - } - if (s == min_s) { - return min_s; - } - uint32_t s_ = s - min_s + 1; - uint32_t ret = min_s + (rand() % s_); - assert(ret <= s); - return ret; - }); - - // Compute the prefix sum of the sequence lengths. - std::vector cu_seqlens(b + 1, 0); - for (int it = 0; it < b; it++) { - cu_seqlens[it + 1] = cu_seqlens[it] + seqlens[it]; - } - int total = cu_seqlens.back(); - seqlens.emplace_back(total); - - // Different q and kv sequence lengths. 
- std::vector q_seqlens = seqlens; - std::vector cu_q_seqlens = cu_seqlens; - if (different_q_kv_lengths) { - for (int it = 0; it < b; it++) { - q_seqlens[it] = s_q; - cu_q_seqlens[it + 1] = cu_q_seqlens[it] + q_seqlens[it]; - } - } - - // Compute the prefix sum of the mask sequence lengths. - std::vector cu_mask_rows(b + 1, 0); - // The mask_h row offset in each sequence to support s_q < s_kv. - // we only need the last s_q rows in the [s, s] mask_h. - std::vector mask_h_row_offsets(b); - for (int it = 0; it < b; it++) { - // The actual q sequence length. - int actual_q_seqlen = q_seqlens[it]; - // The mask_h row offset. - mask_h_row_offsets[it] = seqlens[it] - q_seqlens[it]; - // Round up the sequence length to multiple of 128. - int mask_seqlen = align_to(actual_q_seqlen, fmha::FLASH_ATTEN_MASK_M_ALIGNMENT); - cu_mask_rows[it + 1] = cu_mask_rows[it] + mask_seqlen; - } - - // transfer to device - void *cu_seqlens_d, *cu_q_seqlens_d, *cu_mask_rows_d; - FMHA_CHECK_CUDA(cudaMalloc(&cu_seqlens_d, sizeof(int) * cu_seqlens.size())); - FMHA_CHECK_CUDA(cudaMalloc(&cu_q_seqlens_d, sizeof(int) * cu_q_seqlens.size())); - FMHA_CHECK_CUDA(cudaMalloc(&cu_mask_rows_d, sizeof(int) * cu_mask_rows.size())); - FMHA_CHECK_CUDA(cudaMemcpy(cu_seqlens_d, cu_seqlens.data(), sizeof(int) * cu_seqlens.size(), - cudaMemcpyHostToDevice)); - FMHA_CHECK_CUDA(cudaMemcpy(cu_q_seqlens_d, cu_q_seqlens.data(), sizeof(int) * cu_q_seqlens.size(), - cudaMemcpyHostToDevice)); - FMHA_CHECK_CUDA(cudaMemcpy(cu_mask_rows_d, cu_mask_rows.data(), sizeof(int) * cu_mask_rows.size(), - cudaMemcpyHostToDevice)); - - size_t qkv_packed_size = cu_seqlens.back() * h * (2 * d + dv); - size_t qkv_packed_size_in_bytes = get_size_in_bytes(qkv_packed_size, data_type); - void* qkv_packed_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&qkv_packed_d, qkv_packed_size_in_bytes)); - - // Specify device buffers for multi-query attention or grouped-query attention - // TODO: Use the same buffer for all cases, and allow to set name to 
aid tracing/debugging - // e.g., - // Buffer qkv_buf(size); - // if( packed ) { qkv_buf.set_name("QKV_packed[total, h, 3, d]"); } - // else { qkv_buf.set_name("QKV_padded[b, s, h, 3, d]"); } - // qkv_buf.copy_to_device(); - // float *qkv_buf_d = qkv_buf.get_device_buf(); - // Or, more aggressively, use torch::Tensor from PyTorch ATen - size_t mqa_qkv_packed_size = cu_seqlens.back() * (h + 2 * h_kv) * d; - size_t mqa_qkv_packed_size_in_bytes = get_size_in_bytes(mqa_qkv_packed_size, data_type); - size_t mqa_qkv_size = b * s * (h + 2 * h_kv) * d; // original padded tensor - size_t mqa_qkv_size_in_bytes = get_size_in_bytes(mqa_qkv_size, data_type); - void* mqa_qkv_packed_d = nullptr; - void* mqa_qkv_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_packed_d, mqa_qkv_packed_size_in_bytes)); - FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_d, mqa_qkv_size_in_bytes)); - - const size_t o_packed_size = cu_seqlens.back() * h * dv; - // Allocate on the host. - float* o_packed_h = (float*)malloc(o_packed_size * sizeof(float)); - void* o_packed_d = nullptr; - - size_t o_packed_size_in_bytes = get_size_in_bytes(o_packed_size, output_dtype); - FMHA_CHECK_CUDA(cudaMalloc(&o_packed_d, o_packed_size_in_bytes)); - - // qkv_packed_h is TotalH3D - std::vector qkv_packed_h(qkv_packed_size); - extract_and_transpose_input(qkv_packed_h.data(), qkv_h, seqlens, s, b, h, d, dv, 3, false); - if (interleaved) { - x_vec32(true, qkv_packed_h.data(), h, total, 3); - } - - // qkv_h is SBH3D - // qkv_bsh3d_h is BSH3D - std::vector qkv_bsh3d_h(qkv_size); - extract_and_transpose_input(qkv_bsh3d_h.data(), qkv_h, seqlens, s, b, h, d, dv, 3, - is_s_padded); - if (interleaved) { - x_vec32(true, qkv_bsh3d_h.data(), h, b * h, 3); - } - - std::vector mqa_qkv_packed_h(mqa_qkv_packed_size); - std::vector mqa_qkv_h(mqa_qkv_size); - // for now MLA doesn't use MQA, may enable it in the future - if (d == dv) { - // from qkv[s, h, 3, d] to mqa_qkv[s, h + 2*h_kv, d] - // where - // Q is qkv[s, h, 0, d], - // K is qkv[s, h, 1, 
d], - // V is qkv[s, h, 2, d] - // and - // MQA_Q is mqa_qkv[s, h, [ 0 : h - 1], d], - // MQA_K is mqa_qkv[s, h, [ h : h + h_kv - 1], d], - // MQA_V is mqa_qkv[s, h, [h + h_kv : h + 2*h_kv - 1], d] - for (size_t si = 0; si < cu_seqlens.back(); si++) { - for (size_t hi = 0; hi < h; hi++) { - for (size_t di = 0; di < d; di++) { - // Q: [si, hi, di] <- [si, hi, 0, di] - mqa_qkv_packed_h[si * (h + 2 * h_kv) * d + hi * d + di] = - qkv_packed_h[si * h * 3 * d + hi * 3 * d + 0 * d + di]; - if (hi < h_kv) { - // E.g., h=8, h_kv=4 - // src kv id: 0, 0, 1, 1, 2, 2, 3, 3 - // hi: 0, 1, 2, 3, 4, 5, 6, 7 - // hi_kv_scatter: 0, 2, 4, 6, x, x, x, x - int const h_per_group = h / h_kv; - int const hi_kv_scatter = hi * h_per_group; - // K: [si, h + hi, di] <- [si, hi_kv_scatter, 1, di] - mqa_qkv_packed_h[si * (h + 2 * h_kv) * d + (h + hi) * d + di] = - qkv_packed_h[si * 3 * h * d + hi_kv_scatter * 3 * d + 1 * d + di]; - // V: [si, h + h_kv + hi, di] <- [si, hi_kv_scatter, 2, di] - mqa_qkv_packed_h[si * (h + 2 * h_kv) * d + (h + h_kv + hi) * d + di] = - qkv_packed_h[si * 3 * h * d + hi_kv_scatter * 3 * d + 2 * d + di]; - } - } - } - } - - // from qkv_bsh3d_h[b, s, h, 3, d] to mqa_qkv[b, s, h + 2*h_kv, d] - for (size_t bi = 0; bi < b; bi++) { - int actual_s = seqlens[bi]; - for (size_t si = 0; si < actual_s; si++) { - for (size_t hi = 0; hi < h; hi++) { - for (size_t di = 0; di < d; di++) { - mqa_qkv_h[bi * s * (h + 2 * h_kv) * d + si * (h + 2 * h_kv) * d + hi * d + di] = - qkv_bsh3d_h[bi * s * h * 3 * d + si * h * 3 * d + hi * 3 * d + 0 * d + di]; - if (hi < h_kv) { - // E.g., h=8, h_kv=4 - // src kv id: 0, 0, 1, 1, 2, 2, 3, 3 - // hi: 0, 1, 2, 3, 4, 5, 6, 7 - // hi_kv_scatter: 0, 2, 4, 6, x, x, x, x - int const h_per_group = h / h_kv; - int const hi_kv_scatter = hi * h_per_group; - mqa_qkv_h[bi * s * (h + 2 * h_kv) * d + si * (h + 2 * h_kv) * d + (h + hi) * d + di] = - qkv_bsh3d_h[bi * s * h * 3 * d + si * h * 3 * d + hi_kv_scatter * 3 * d + 1 * d + - di]; - mqa_qkv_h[bi * s * (h + 
2 * h_kv) * d + si * (h + 2 * h_kv) * d + - (h + h_kv + hi) * d + di] = - qkv_bsh3d_h[bi * s * h * 3 * d + si * h * 3 * d + hi_kv_scatter * 3 * d + 2 * d + - di]; - } - } - } - } - } - } - // if( verbose ) { - // print_tensor(qkv_packed_h.data() + 0 * d, d, total * h, 3 * d, "Packed Q[bs, h, d]"); - // print_tensor(qkv_packed_h.data() + 1 * d, d, total * h, 3 * d, "Packed K[bs, h, d]"); - // print_tensor(qkv_packed_h.data() + 2 * d, d, total * h, 3 * d, "Packed V[bs, h, d]"); - - // print_tensor(mqa_qkv_packed_h.data() + 0 * d, h * d, total, (h + 2 * h_kv) * - // d, "Packed MQA Q[bs, h*d]"); print_tensor(mqa_qkv_packed_h.data() + h * d, h_kv - // * d, total, (h + 2 * h_kv) * d, "Packed MQA K[bs, h_kv*d]"); - // print_tensor(mqa_qkv_packed_h.data() + h * d + h_kv * d, h_kv * d, total, (h + 2 - // * h_kv) * d, "Packed MQA V[bs, h_kv*d]"); - - // print_tensor(qkv_bsh3d_h.data() + 0 * d, d, b * h * s, 3 * d, "Padded Q[b, s, h, d]"); - // print_tensor(qkv_bsh3d_h.data() + 1 * d, d, b * h * s, 3 * d, "Padded K[b, s, h, d]"); - // print_tensor(qkv_bsh3d_h.data() + 2 * d, d, b * h * s, 3 * d, "Padded V[b, s, h, d]"); - - // print_tensor(mqa_qkv_h.data() + 0 * d, h * d, b * s, (h + 2 * h_kv) * d, - // "Padded MQA Q[b, s, h*d]"); print_tensor(mqa_qkv_h.data() + h * d, h_kv * d, b * - // s, (h + 2 * h_kv) * d, "Padded MQA K[b, s, h_kv*d]"); print_tensor(mqa_qkv_h.data() + h * d - // + h_kv * d, h_kv * d, b * s, (h + 2 * h_kv) * d, "Padded MQA V[b, s, h_kv*d]"); - // } - - // Contiguous KV Cache and Separate KV Cache. - store_q_and_contiguous_kv_cache(q_d, k_d, v_d, contiguous_kv_h, contiguous_kv_d, - reinterpret_cast(qkv_packed_h.data()), - reinterpret_cast(cu_seqlens.data()), - reinterpret_cast(cu_q_seqlens.data()), b, s, h, h_kv, - d, dv, data_type); - - // Paged KV Cache. 
- store_paged_kv_cache(kv_cache_ptrs_h, reinterpret_cast(qkv_packed_h.data()), - reinterpret_cast(cu_seqlens.data()), max_blocks_per_seq, - tokens_per_block, b, h, h_kv, d, dv, data_type); - - // Copy packed, padded, mqa packed, mqa padded data buffers - // TODO: use the same buffer for all cases - FMHA_CHECK_CUDA(cuda_memcpy_h2d(qkv_packed_d, qkv_packed_h.data(), qkv_packed_size, data_type)); - FMHA_CHECK_CUDA( - cuda_memcpy_h2d(mqa_qkv_packed_d, mqa_qkv_packed_h.data(), mqa_qkv_packed_size, data_type)); - FMHA_CHECK_CUDA(cuda_memcpy_h2d(mqa_qkv_d, mqa_qkv_h.data(), mqa_qkv_size, data_type)); - FMHA_CHECK_CUDA(cuda_memcpy_h2d(qkv_bsh3d_d, qkv_bsh3d_h.data(), qkv_size, data_type)); - - // Is MTP used? - bool is_mtp = (d == 576 && dv == 512); - - for (size_t so = 0; so < s; ++so) { // s_q - for (size_t bi = 0; bi < b; ++bi) { - int actual_seqlen = seqlens[bi]; - for (size_t si = 0; si < s; ++si) { // s_kv - // Are both the query and the key inside the sequence? - bool valid = (si < actual_seqlen) && (so < actual_seqlen); - // FIXME: add random mask generator. - // attention_mask_type == Attention_mask_type::CUSTOM_MASK - if (attention_mask_type == Attention_mask_type::CUSTOM_MASK || - attention_mask_type == Attention_mask_type::CAUSAL || - attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL) { - valid = valid && (so >= si); - } - if (attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL) { - if (chunked_attention_size > 0) { - int chunk_idx = so / chunked_attention_size; - valid = valid && (si >= (chunk_idx * chunked_attention_size)); - } else { - valid = valid && (si >= std::max(int(so + 1 - sliding_window_size), 0)); - } - } - if (is_mtp) { - // Only the last s_q tokens are used for verifying the results. 
- size_t idx = so - (actual_seqlen - s_q); - size_t num_mtp_tokens = s_q / num_grouped_heads; - size_t mtp_token_idx = idx / num_grouped_heads; - valid = idx >= 0 && si < (actual_seqlen - num_mtp_tokens + 1 + mtp_token_idx) && - (so < actual_seqlen); - } - if (!skip_checks) { - // The mask is stored as floats. - mask_h[so * b * s + bi * s + si] = valid ? 1.f : 0.f; // mask dims [s_q, b, s_kv] - } - } - } - } - - if (verbose) { - printf("Sequence lengths (first 10 batches): "); - for (int bi = 0; bi < seqlens.size() && bi < 10; bi++) { - printf("%d, ", seqlens[bi]); - } - printf("\n"); - } - - if (v1) { - assert(!interleaved && "Interleaved not supported in v1"); - assert(mmas_n <= 4 && "Not supported"); - - FMHA_CHECK_CUDA(cudaMalloc(&packed_mask_d, packed_mask_size_in_bytes)); - if (sm == 70) { - pack_mask_sm70(packed_mask_h, mask_h, s, b, mmas_m, mmas_n, warps_m, warps_n, - threads_per_cta); - } else { - pack_mask(packed_mask_h, mask_h, s, b, mmas_m, mmas_n, warps_m, warps_n, threads_per_cta); - } - - // Copy the packed mask to the device. - if (!skip_checks) { - FMHA_CHECK_CUDA(cudaMemcpy(packed_mask_d, packed_mask_h, packed_mask_size_in_bytes, - cudaMemcpyHostToDevice)); - } - } else if (attention_mask_type == Attention_mask_type::CUSTOM_MASK) { - FMHA_CHECK_CUDA(cudaMalloc(&packed_mask_d, packed_mask_size_in_bytes)); - assert(fmha::FLASH_ATTEN_MASK_MMA_M == warps_m * 16 && "Not supported"); - assert(fmha::FLASH_ATTEN_MASK_MMA_N / 8 == 8 && "Not supported"); - pack_flash_attention_mask(packed_mask_h, mask_h, b, s, warps_m, warps_n, threads_per_cta, - mmas_n, fmha::FLASH_ATTEN_MASK_MMA_N / 8, mask_h_row_offsets.data(), - cu_mask_rows.data()); - - // Copy the packed mask to the device. - FMHA_CHECK_CUDA(cudaMemcpy(packed_mask_d, packed_mask_h, packed_mask_size_in_bytes, - cudaMemcpyHostToDevice)); - } - - // Copy the mask to the device. 
- if (!skip_checks) { - FMHA_CHECK_CUDA(cuda_memcpy_h2d(mask_d, mask_h, mask_size, DATA_TYPE_INT8)); - } - - // non-owning pointer to the IO buffer - void* qkv_d_view = nullptr; - void* o_d_view = nullptr; - int o_view_size = 0; - if (is_s_padded) { - qkv_d_view = multi_query_attention ? mqa_qkv_d : qkv_bsh3d_d; - o_d_view = o_d; - o_view_size = o_size; - } else { - qkv_d_view = multi_query_attention ? mqa_qkv_packed_d : qkv_packed_d; - o_d_view = o_packed_d; - o_view_size = o_packed_size; - } - void* softmax_stats_ptr = save_softmax ? softmax_stats_d : nullptr; - // Set the params. - bert::Fused_multihead_attention_params_v1 params_v1; - printf("=== set_params() arguments ===\n"); - printf("launch_params: ...\n"); // For struct, maybe print pointer or describe - printf("data_type: %d\n", int(data_type)); - printf("acc_type: %d\n", int(acc_type)); - printf("output_dtype: %d\n", int(output_dtype)); - printf("input_layout: %d\n", int(input_layout)); - printf("b: %zu\n", size_t(b)); - printf("s_q: %zu\n", size_t(s_q)); - printf("s: %zu\n", size_t(s)); - printf("h: %zu\n", size_t(h)); - printf("h_kv: %zu\n", size_t(h_kv)); - printf("d: %zu\n", size_t(d)); - printf("dv: %zu\n", size_t(dv)); - printf("total: %zu\n", size_t(total)); - printf("num_grouped_heads: %zu\n", size_t(num_grouped_heads)); - printf("sliding_window_size: %zu\n", size_t(sliding_window_size)); - printf("chunked_attention_size: %zu\n", size_t(chunked_attention_size)); - printf("tokens_per_block: %zu\n", size_t(tokens_per_block)); - printf("qkv_d_view: %p\n", qkv_d_view); - printf("q_d: %p\n", q_d); - printf("k_d: %p\n", k_d); - printf("v_d: %p\n", v_d); - printf("contiguous_kv_d: %p\n", contiguous_kv_d); - printf("kv_cache_pool_ptr: %p\n", kv_cache_pool_ptr); - printf("kv_cache_block_offsets_d: %p\n", kv_cache_block_offsets_d); - printf("packed_mask_d: %p\n", packed_mask_d); - printf("cu_mask_rows_d: %p\n", cu_mask_rows_d); - printf("attention_sinks_d: %p\n", attention_sinks_d); - printf("cu_seqlens_d: 
%p\n", cu_seqlens_d); - printf("cu_q_seqlens_d: %p\n", cu_q_seqlens_d); - printf("o_d_view: %p\n", o_d_view); - printf("p_d: %p\n", p_d); - printf("s_d: %p\n", s_d); - printf("softmax_stats_ptr: %p\n", softmax_stats_ptr); - printf("scale_bmm2_d: %p\n", scale_bmm2_d); - printf("scale_bmm1: %f\n", scale_bmm1); - printf("scale_softmax: %f\n", scale_softmax); - printf("scale_bmm2: %f\n", scale_bmm2); - printf("softcapping_scale_bmm1: %f\n", softcapping_scale_bmm1); - printf("use_int8_scale_max: %d\n", int(use_int8_scale_max)); - printf("interleaved: %d\n", int(interleaved)); - printf("is_s_padded: %d\n", int(is_s_padded)); - printf("has_alibi: %d\n", int(has_alibi)); - printf("=============================\n"); - set_params(params_v1, data_type, acc_type, b, s, h, d, mmas_m * threads_per_cta, qkv_sbh3d_d, - packed_mask_d, o_d, p_d, s_d, scale_bmm1, scale_softmax, scale_bmm2, has_alibi); - - bert::Fused_multihead_attention_params_v2 params_v2; - set_params(params_v2, launch_params, data_type, acc_type, output_dtype, input_layout, b, s_q, s, - h, h_kv, d, dv, total, num_grouped_heads, sliding_window_size, chunked_attention_size, - // Paged kv cache. - tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, - kv_cache_block_offsets_d, packed_mask_d, cu_mask_rows_d, attention_sinks_d, - cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d, softmax_stats_ptr, scale_bmm2_d, - scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, use_int8_scale_max, - interleaved, is_s_padded, has_alibi); - - // total number of tokens is needed to set TMA desc on the host. - launch_params.total_q_seqlen = q_seqlens[b]; - launch_params.total_kv_seqlen = seqlens[b]; - // set enable_attn_logit_softcapping to select the right kernel. - launch_params.enable_attn_logit_softcapping = softcapping_scale_bmm1 != 0.f; - - // Allocate barriers and locks. 
- void* counters_d = nullptr; - if (ctas_per_head > 1) { - size_t sz = heads_per_wave * sizeof(int); - FMHA_CHECK_CUDA(cudaMalloc((void**)&counters_d, 3 * sz)); - } - - // Allocate scratch storage for softmax. - void *max_scratch_d = nullptr, *sum_scratch_d = nullptr; - if (ctas_per_head > 1) { - size_t sz = heads_per_wave * ctas_per_head * threads_per_cta * sizeof(float); - FMHA_CHECK_CUDA(cudaMalloc((void**)&max_scratch_d, sz)); - FMHA_CHECK_CUDA(cudaMalloc((void**)&sum_scratch_d, sz)); - } - - // Allocate temporary storage for the parallel reduction. - void* o_scratch_d = nullptr; - if (ctas_per_head > 1 && data_type != DATA_TYPE_FP16) { - size_t sz = heads_per_wave * threads_per_cta * MAX_STGS_PER_LOOP * sizeof(uint4); - FMHA_CHECK_CUDA(cudaMalloc((void**)&o_scratch_d, sz)); - } - - // Allocate tile id for dynamic scheduling - void* tile_id_counter_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc((void**)&tile_id_counter_d, sizeof(uint32_t))); - - // The number of heads computed per wave. - params_v1.heads_per_wave = heads_per_wave; - params_v2.heads_per_wave = heads_per_wave; - - // Barriers for the global sync in the multi-CTA kernel(s). - params_v1.counters = (int*)counters_d + 0 * heads_per_wave; - params_v2.counters = (int*)counters_d + 0 * heads_per_wave; - params_v1.max_barriers = (int*)counters_d + 0 * heads_per_wave; - params_v2.max_barriers = (int*)counters_d + 0 * heads_per_wave; - params_v1.sum_barriers = (int*)counters_d + 1 * heads_per_wave; - params_v2.sum_barriers = (int*)counters_d + 1 * heads_per_wave; - params_v1.locks = (int*)counters_d + 2 * heads_per_wave; - params_v2.locks = (int*)counters_d + 2 * heads_per_wave; - - // Scratch storage for softmax. - params_v1.max_scratch_ptr = (float*)max_scratch_d; - params_v2.max_scratch_ptr = (float*)max_scratch_d; - params_v1.sum_scratch_ptr = (float*)sum_scratch_d; - params_v2.sum_scratch_ptr = (float*)sum_scratch_d; - - // Scratch storage for output. 
- params_v1.o_scratch_ptr = (int*)o_scratch_d; - params_v2.o_scratch_ptr = (int*)o_scratch_d; - - // Tile id counter for dynamic scheduling - params_v2.tile_id_counter_ptr = (uint32_t*)tile_id_counter_d; - // params_paged_v2.tile_id_counter_ptr = (uint32_t*) tile_id_counter_d; - - if (sage_block_size_q > 0 || sage_block_size_k > 0 || sage_block_size_v > 0) { - assert(input_layout == Attention_input_layout::PACKED_QKV && - "for now this test only supports PACKED_QKV"); - assert(d == dv && "for now SageAttention doesn't support different QKV dims"); - assert(((sm == 90 && !force_non_warp_specialization) || (sm == 89)) && - "only hopper and ada kernels support SageAttention"); - fmha::e4m3_t* quant_qkv; - FMHA_CHECK_CUDA(cudaMalloc((void**)&quant_qkv, qkv_packed_size)); - params_v2.sage.q.block_size = sage_block_size_q; - params_v2.sage.q.max_nblock = (s + sage_block_size_q - 1) / sage_block_size_q; - FMHA_CHECK_CUDA(cudaMalloc((void**)¶ms_v2.sage.q.scales, - params_v2.sage.q.max_nblock * h * b * sizeof(float))); - params_v2.sage.k.block_size = sage_block_size_k; - params_v2.sage.k.max_nblock = (s + sage_block_size_k - 1) / sage_block_size_k; - FMHA_CHECK_CUDA(cudaMalloc((void**)¶ms_v2.sage.k.scales, - params_v2.sage.k.max_nblock * h * b * sizeof(float))); - params_v2.sage.v.block_size = sage_block_size_v; - params_v2.sage.v.max_nblock = (s + sage_block_size_v - 1) / sage_block_size_v; - FMHA_CHECK_CUDA(cudaMalloc((void**)¶ms_v2.sage.v.scales, - params_v2.sage.v.max_nblock * h * b * sizeof(float))); -#if 1 - { - // simple test, all scales are the same - constexpr float const_scale = 0.618f; - fmha::e4m3_t* quant_qkv_h = (fmha::e4m3_t*)malloc(qkv_packed_size); - for (size_t i = 0; i < qkv_packed_size; i++) { - quant_qkv_h[i] = fmha::e4m3_t(qkv_packed_h[i] / const_scale); - } - FMHA_CHECK_CUDA(cudaMemcpy(quant_qkv, quant_qkv_h, qkv_packed_size, cudaMemcpyHostToDevice)); - free(quant_qkv_h); - auto init_scales = 
[&](bert::Fused_multihead_attention_params_v2::SageAttention::Scales& x) { - std::vector scales(x.max_nblock * h * b, const_scale); - FMHA_CHECK_CUDA(cudaMemcpy(x.scales, scales.data(), sizeof(float) * scales.size(), - cudaMemcpyHostToDevice)); - }; - init_scales(params_v2.sage.q); - init_scales(params_v2.sage.k); - init_scales(params_v2.sage.v); - } -#else - { - // use external quant kernel - run_sage_quant(b, h, d, s, params_v2.qkv_ptr, - (char*) params_v2.qkv_ptr + get_size_in_bytes(h * d, data_type), - (char*) params_v2.qkv_ptr + get_size_in_bytes(2 * h * d, data_type, - params_v2.q_stride_in_bytes, - params_v2.k_stride_in_bytes, - params_v2.v_stride_in_bytes, - params_v2.cu_q_seqlens, params_v2.cu_kv_seqlens, sage_block_size_q, sage_block_size_k, - sage_block_size_v, quant_qkv, quant_qkv + h * d, quant_qkv + 2 * h * d, params_v2.sage.q.scales, - params_v2.sage.k.scales, params_v2.sage.v.scales); - } -#endif - // no need to free old params_v2.qkv_ptr, it will be released in the end - params_v2.qkv_ptr = quant_qkv; - params_v2.q_stride_in_bytes = params_v2.k_stride_in_bytes = params_v2.v_stride_in_bytes = - get_size_in_bytes((h + 2 * h_kv) * d, DATA_TYPE_E4M3); - } - -#if defined(DEBUG_HAS_PRINT_BUFFER) - auto& params = params_v2; - constexpr size_t bytes = 32 * 1024; - void* print_ptr = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(¶ms.print_ptr, bytes)); - std::vector print_buffer(bytes / sizeof(float)); -#endif - // Run a few warm-up kernels. - for (int ii = 0; ii < warm_up_runs; ++ii) { - if (v1) { - run_fmha_v1(params_v1, launch_params, data_type, output_dtype, sm, 0); - } else { - run_fmha_v2(params_v2, launch_params, data_type, output_dtype, sm, 0); - } - } - printf("Warm-up kernels done\n"); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - - float non_fused_elapsed = INFINITY; - printf("Running reference kernel\n"); - if (!skip_checks) { - // Run cuBLAS. 
- - RefBMM bmm1(data_type_to_cuda(data_type), // a - data_type_to_cuda(data_type), // b - data_type_to_cuda(acc_type), // d - data_type_to_cublas(acc_type), // compute - data_type_to_cuda(acc_type), // scale - false, // Q - true, // K' - s, // m - s, // n - d, // k - b * h * (2 * d + dv), // ld Q - b * h * (2 * d + dv), // ld K - b * h * s, // ld P - (2 * d + dv), // stride Q - (2 * d + dv), // stride K - s, // stride P - b * h // batch count - ); - - /* - RefBMM bmm2(data_type_to_cuda(data_type), // a - data_type_to_cuda(data_type), // b - data_type_to_cuda(acc_type), // d - data_type_to_cublas(acc_type), //compute - data_type_to_cuda(acc_type), // scale - false, // S - false, // V - s, // m - d, // n - s, // k - b * h * s, // ld S - b * h * 3 * d, // ld V - b * h * d, // ld O - s, // stride S - 3 * d, // stride V - d, // stride O - b * h // batch count - ); - */ - - // WAR fOR MISSING CUBLAS FP8 NN SUPPORT. - // Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V. 
- RefBMM bmm2(data_type_to_cuda(data_type), // a - data_type_to_cuda(data_type), // b - data_type_to_cuda(acc_type), // d - data_type_to_cublas(acc_type), // compute - data_type_to_cuda(acc_type), // scale - false, // S - true, // V' - s, // m - dv, // n - s, // k - b * h * s, // ld S - s, // ld V - b * h * dv, // ld O - s, // stride S - s * dv, // stride V - dv, // stride O - b * h // batch count - ); - timer.start(); - ground_truth(bmm1, bmm2, data_type, acc_type, scale_bmm1, scale_softmax, scale_bmm2, - softcapping_scale_bmm1, qkv_sbh3d_d, - vt_d, // WAR pass in V' - mask_d, attention_sinks_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, - s, h, d, dv, runs, warps_m, warps_n, has_alibi); - timer.stop(); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - FMHA_CHECK_CUDA(cudaDeviceSynchronize()); - non_fused_elapsed = timer.millis(); - -#if defined(STORE_P) - FMHA_CHECK_CUDA(cuda_memcpy_d2h(p_ref_h, p_d, p_size, acc_type)); -#endif // defined(STORE_P) - -#if defined(STORE_S) - FMHA_CHECK_CUDA(cuda_memcpy_d2h(s_ref_h, s_d, p_size, data_type)); -#endif // defined(STORE_S) - - // Read the results. - FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_ref_h, o_d, o_size, data_type)); - FMHA_CHECK_CUDA( - cuda_memcpy_d2h(softmax_stats_ref_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32)); - } - - // Fill-in p/s/o with garbage data. - // WAR: if sequence is padded, we zero-fill the output buffer as kernel will not write to the - // padded area, and the host expects to check the padded area - if (!skip_checks) { - FMHA_CHECK_CUDA(cudaMemset(p_d, 0xdc, p_size_in_bytes)); - FMHA_CHECK_CUDA(cudaMemset(s_d, 0xdc, s_size_in_bytes)); - } - FMHA_CHECK_CUDA(cudaMemset(o_d, 0x00, o_size_in_bytes)); - FMHA_CHECK_CUDA(cudaMemset(softmax_stats_d, 0x00, 2 * b * s * h * sizeof(float))); - - // Run the kernel. 
- timer.start(); - for (int ii = 0; ii < runs; ++ii) { - if (v1) { - run_fmha_v1(params_v1, launch_params, data_type, output_dtype, sm, 0); - } else { - run_fmha_v2(params_v2, launch_params, data_type, output_dtype, sm, 0); - } - } - timer.stop(); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - - FMHA_CHECK_CUDA(cudaDeviceSynchronize()); - float fused_elapsed = timer.millis(); - -#if defined(STORE_P) - FMHA_CHECK_CUDA(cuda_memcpy_d2h(p_h, p_d, p_size, acc_type)); - printf("\nChecking .....: P = norm * K^T * Q\n"); - - // DEBUG. - printf("seqlens[0]=%d\n", seqlens[0]); - // END OF DEBUG. - - // Clear the invalid region of P. - set_mat(p_ref_h, seqlens, s, b, h, s, 0.f, true); - set_mat(p_h, seqlens, s, b, h, s, 0.f, true); - - // Do the check. - check_results(p_h, p_ref_h, s, s * b * h, s, 0.f, true, true); -#endif // defined(STORE_P) - -#if defined(STORE_S) - FMHA_CHECK_CUDA(cuda_memcpy_d2h(s_h, s_d, p_size, data_type)); - printf("\nChecking .....: S = softmax(P)\n"); -#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) - float softmax_epsilon = data_type == DATA_TYPE_FP16 ? 1e-3f : 0.f; -#else - float softmax_epsilon = 1.e-3f; -#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) - - // Clear the invalid region of S. - set_mat(s_ref_h, seqlens, s, b, h, s, 0.f); - set_mat(s_h, seqlens, s, b, h, s, 0.f); - - // Do the check. - check_results(s_h, s_ref_h, s, s * b * h, s, softmax_epsilon, true, true); -#endif // defined(STORE_S) - - // Check the final results. 
- int status = -1; - if (skip_checks) { - status = 0; - printf("\n"); - print_results(true, false); - } else { - if (v1) { - FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_h, o_d, o_size, output_dtype)); - status = check_results(o_h, o_ref_h, d, s * b * h, d, epsilon, verbose, true); - } else { - std::vector o_ref_trans_h(o_size); - - FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_h, o_d_view, o_view_size, output_dtype)); - FMHA_CHECK_CUDA( - cuda_memcpy_d2h(softmax_stats_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32)); - - if (interleaved) { - // revert batch-interleaved format: 3 x h/32 x total x d x 32 => total x - // h x 3 x d - x_vec32(false, o_h, h, is_s_padded ? b * h : total, 1); - } - - // Extract the last s_q tokens from the output. - extract_and_transpose_output(o_ref_trans_h.data(), o_ref_h, seqlens, q_seqlens, s, s_q, - b, h, dv, is_s_padded); - if (verbose) { - printf("\nChecking .....: O = V * S\n"); - } - status = check_results(o_h, o_ref_trans_h.data(), dv, - is_s_padded ? s_q * b * h : cu_q_seqlens.back() * h, dv, epsilon, - verbose, true); - if (save_softmax) { - auto errors = check_softmax_results(softmax_stats_h, softmax_stats_ref_h, b, s, h, seqlens, - cu_seqlens); - status = status | ((errors.first + errors.second) > 0); - } - } - if (status != 0) { // if there was an error, print the config of the run - printf("v1=%d il=%d s=%lu b=%lu h=%lu dv=%lu dtype=%s\n", v1, interleaved, s, b, h, dv, - data_type_to_name(data_type).c_str()); - } - if (!verbose) { // this just prints the SUCCESS/ERROR line - print_results(true, true, status == 0); - } - } - - // accounts for tensor core flops only; excludes flops spent in softmax - size_t total_flops = 0; - // remove last seqlen(total_seqlen) - seqlens.pop_back(); - for (auto& s_ : seqlens) { - size_t s_size = size_t(s_); - total_flops += 2ull * h * (s_q * s_size * d + s_q * dv * s_size); // 1st BMM + 2nd BMM - } - total_flops = attention_mask_type == Attention_mask_type::CAUSAL ? 
total_flops / 2 : total_flops; - - size_t total_bytes = o_packed_size_in_bytes + qkv_packed_size_in_bytes; - if (verbose) { - // Runtimes. - printf("\n"); - if (!skip_checks) { - printf("Non-fused time: %.6f ms\n", non_fused_elapsed / float(runs)); - } - printf("Fused time ...: %.6f us\n", fused_elapsed * 1000 / float(runs)); - printf("Tensor core ..: %.2f Tflop/s\n", total_flops / (fused_elapsed / float(runs) / 1e-9)); - printf("Bandwidth ....: %.2f GB/s\n", total_bytes / (fused_elapsed / float(runs) / 1e-6)); - if (!skip_checks) { - printf("Ratio ........: %.2fx\n", non_fused_elapsed / fused_elapsed); - } - } else { - printf("Elapsed ......: %.6f us (%.2fx), %.2f Tflop/s, %.2f GB/s\n", - fused_elapsed * 1000 / float(runs), non_fused_elapsed / fused_elapsed, - total_flops / (fused_elapsed / float(runs) / 1e-9), - total_bytes / (fused_elapsed / float(runs) / 1e-6)); - } -#if defined(DEBUG_HAS_PRINT_BUFFER) - FMHA_CHECK_CUDA( - cuda_memcpy_d2h(print_buffer.data(), params.print_ptr, print_buffer.size(), DATA_TYPE_FP32)); - - printf("\n====================\n"); - for (int it = 0; it < 16; it++) { - printf("% .4f ", print_buffer[it]); - } - printf("\n====================\n"); - - FMHA_CHECK_CUDA(cudaFree(params.print_ptr)); - -#endif - // Release memory. 
- FMHA_CHECK_CUDA(cudaFree(qkv_sbh3d_d)); - FMHA_CHECK_CUDA(cudaFree(qkv_packed_d)); - FMHA_CHECK_CUDA(cudaFree(scale_bmm2_d)); - FMHA_CHECK_CUDA(cudaFree(mqa_qkv_d)); - FMHA_CHECK_CUDA(cudaFree(mqa_qkv_packed_d)); - FMHA_CHECK_CUDA(cudaFree(qkv_bsh3d_d)); - FMHA_CHECK_CUDA(cudaFree(mask_d)); - FMHA_CHECK_CUDA(cudaFree(packed_mask_d)); - FMHA_CHECK_CUDA(cudaFree(q_d)); - FMHA_CHECK_CUDA(cudaFree(k_d)); - FMHA_CHECK_CUDA(cudaFree(v_d)); - FMHA_CHECK_CUDA(cudaFree(p_d)); - FMHA_CHECK_CUDA(cudaFree(s_d)); - FMHA_CHECK_CUDA(cudaFree(o_d)); - FMHA_CHECK_CUDA(cudaFree(tmp_d)); - FMHA_CHECK_CUDA(cudaFree(cu_seqlens_d)); - FMHA_CHECK_CUDA(cudaFree(cu_mask_rows_d)); - FMHA_CHECK_CUDA(cudaFree(max_scratch_d)); - FMHA_CHECK_CUDA(cudaFree(sum_scratch_d)); - FMHA_CHECK_CUDA(cudaFree(o_scratch_d)); - FMHA_CHECK_CUDA(cudaFree(counters_d)); - FMHA_CHECK_CUDA(cudaFree(tile_id_counter_d)); - FMHA_CHECK_CUDA(cudaFree(kv_cache_pool_ptr)); - FMHA_CHECK_CUDA(cudaFree(kv_cache_block_offsets_d)); - FMHA_CHECK_CUDA(cudaFree(contiguous_kv_d)); - FMHA_CHECK_CUDA(cudaFree(softmax_stats_d)); - - free(qkv_h); - free(mask_h); - free(packed_mask_h); - free(s_h); - free(o_h); - free(o_ref_h); - free(softmax_stats_h); - free(softmax_stats_ref_h); - free(contiguous_kv_h); - free(kv_cache_ptrs_h); - free(kv_cache_block_offsets_h); - - free(p_ref_h); -#if defined(STORE_P) - free(p_h); -#endif // defined(STORE_P) - free(s_ref_h); - - return status; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/fmha_v2/fused_multihead_attention.h b/csrc/fmha_v2/fused_multihead_attention.h index c1653bb5bb..7049103d7f 100644 --- a/csrc/fmha_v2/fused_multihead_attention.h +++ b/csrc/fmha_v2/fused_multihead_attention.h @@ -281,6 +281,16 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba float* scales; } q, k, v; } sage; + + // Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / 
seqlen. + // A positive value means skip-softmax is enabled. + float skip_softmax_threshold_scale_factor = 0; + +#ifdef SKIP_SOFTMAX_STAT + // Statistics of skip-softmax, pointers of device memory for output + uint32_t* skip_softmax_total_blocks; + uint32_t* skip_softmax_skipped_blocks; +#endif }; #endif @@ -319,6 +329,8 @@ struct Fused_multihead_attention_launch_params { // harward properties to determine how to launch blocks int multi_processor_count = 0; int device_l2_cache_size = 0; + // skip softmax attention + bool enable_skip_softmax = false; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h b/csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h index bfe40b720f..62294e2f0a 100644 --- a/csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h +++ b/csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h @@ -168,4 +168,13 @@ struct Fused_multihead_attention_params_v2 { float* scales; } q, k, v; } sage; + + // Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen. + // A positive value means skip-softmax is enabled. + float skip_softmax_threshold_scale_factor = 0; +#ifdef SKIP_SOFTMAX_STAT + // Statistics of skip-softmax, pointers of device memory for output + uint32_t* skip_softmax_total_blocks; + uint32_t* skip_softmax_skipped_blocks; +#endif }; diff --git a/csrc/fmha_v2/fused_multihead_cross_attention.cpp b/csrc/fmha_v2/fused_multihead_cross_attention.cpp deleted file mode 100644 index cc6b7548be..0000000000 --- a/csrc/fmha_v2/fused_multihead_cross_attention.cpp +++ /dev/null @@ -1,939 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. 
SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement - * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. - */ - -#include -#include -#include - -#include -#include -#include -#include -#include - -using Launch_params = bert::Fused_multihead_attention_launch_params; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_seqlens_q_d, int s_inner, int s_outer, int b, - int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_seqlens_q_d, int s_inner, int s_outer, int b, - int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_seqlens_q_d, int s_inner, int s_outer, int b, - int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_seqlens_q_d, int s_inner, int s_outer, int b, - 
int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, - int warps_n, bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, - float scale); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_fp32_to_fp16(void* dst, void const* src, int s, int b, int h, int d); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d, - float scale_o); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type, - float const scale_bmm1, float const scale_softmax, float const scale_bmm2, - void* q_d, void* kv_d, void* vt_d, void* mask_d, void* p_d, void* s_d, - void* tmp_d, void* o_d, void* softmax_sum_d, void* cu_seqlens_q_d, const size_t b, - const size_t s_q, const size_t s_kv, const size_t h, const size_t d, - int const runs, int const warps_m, int const warps_n, bool has_alibi) { - cudaStream_t stream = 0; - // The stride between rows of the QKV matrix. - size_t qkv_stride = get_size_in_bytes(d, data_type); - - // 1st GEMMd. - uint32_t alpha, beta = 0u; - - for (int ii = 0; ii < runs; ++ii) { - // If we run the INT8 kernel, defer the scaling of P to softmax. - set_alpha(alpha, data_type == DATA_TYPE_INT8 ? 1.f : scale_bmm1, acc_type); - - // P = Q x K' - bmm1(static_cast(q_d) + 0 * qkv_stride, static_cast(kv_d) + 0 * qkv_stride, p_d, - &alpha, &beta, stream); - - // Softmax. 
- if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16) { - run_softmax_fp16(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, - 0.f, warps_n, has_alibi); - } else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32) { - run_softmax_fp32(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, - 0.f, warps_n, has_alibi); - } else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32) { - run_softmax_e4m3(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, - scale_softmax, 0.f, warps_n, has_alibi); - } else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32) { - run_softmax_int8(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, - scale_bmm1, scale_softmax, 0.f, warps_n, has_alibi); - } else { - assert(false && "Reference Softmax: Unsupported type config"); - } - - // 2nd GEMM. - set_alpha(alpha, 1.f, acc_type); - - void* out_d = o_d; - - // We may have to do a final conversion. - if (data_type != acc_type) { - out_d = tmp_d; - } - - // O = S x V - bmm2(static_cast(s_d), - static_cast(vt_d), // static_cast(qkv_d) + 2 * qkv_stride, - out_d, &alpha, &beta, stream); - - // Conversion to output type. - if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16) { - // Noop. 
- } else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32) { - run_conversion_fp32_to_fp16(o_d, out_d, s_q, b, h, d); - } else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32) { - run_conversion_fp32_to_e4m3(o_d, out_d, s_q, b, h, d, scale_bmm2); - } else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32) { - // quantize output in second step - run_conversion_int32_to_int8(o_d, out_d, s_q, b, h, d, scale_bmm2); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline void set_params(bert::Fused_multihead_attention_params_mhca& params, - // types - Data_type data_type, Data_type acc_type, - // sizes - const size_t b, const size_t s_q, const size_t s_kv, const size_t h, - const size_t d, const size_t d_padded, const size_t total, - // device pointers - void* q_packed_d, void* kv_packed_d, void* cu_seqlens_q_d, - void* cu_seqlens_kv_d, void* o_packed_d, void* p_d, void* s_d, - // scale factors - float const scale_bmm1, float const scale_softmax, - float const scale_bmm2, - // flags - bool const use_int8_scale_max) { - memset(¶ms, 0, sizeof(params)); - - // Set the pointers. - params.o_ptr = o_packed_d; - params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type); - - // if( interleaved ) { - // params.qkv_stride_in_bytes = total; - // params.o_stride_in_bytes = total; - // } - -#if defined(STORE_P) - params.p_ptr = p_d; - params.p_stride_in_bytes = get_size_in_bytes(b * h * s_kv, acc_type); -#endif // defined(STORE_P) - -#if defined(STORE_S) - params.s_ptr = s_d; - params.s_stride_in_bytes = get_size_in_bytes(b * h * s_kv, data_type); -#endif // defined(STORE_S) - - // Set the dimensions. - params.b = b; - params.h = h; - params.s_q = s_q; - params.s = s_kv; - params.d = d; - params.d_padded = d_padded; - - // Set the different scale values. - Data_type scale_type1 = data_type == DATA_TYPE_FP16 ? 
acc_type : DATA_TYPE_FP32; - Data_type scale_type2 = data_type == DATA_TYPE_FP16 ? DATA_TYPE_FP16 : DATA_TYPE_FP32; - - set_alpha(params.scale_bmm1, scale_bmm1, scale_type1); - set_alpha(params.scale_softmax, scale_softmax, scale_type1); - set_alpha(params.scale_bmm2, scale_bmm2, scale_type2); - - // Set the pointers. - params.gmem_q_params.ptr = q_packed_d; - params.gmem_q_params.stride_in_bytes = get_size_in_bytes(h * d, data_type); - params.gmem_q_params.h = h; - params.gmem_q_params.d = d; - params.gmem_q_params.cu_seqlens = static_cast(cu_seqlens_q_d); - - params.gmem_kv_params.ptr = kv_packed_d; - params.gmem_kv_params.stride_in_bytes = get_size_in_bytes(h * 2 * d, data_type); - params.gmem_kv_params.h = h; - params.gmem_kv_params.d = d; - params.gmem_kv_params.cu_seqlens = static_cast(cu_seqlens_kv_d); - - // Set flags - params.use_int8_scale_max = use_int8_scale_max; - - // Do we enable the trick to replace I2F with FP math in the 2nd GEMM? - if (data_type == DATA_TYPE_INT8) { - params.enable_i2f_trick = -double(1 << 22) * double(scale_bmm2) <= -128.f && - double(1 << 22) * double(scale_bmm2) >= 127.f; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -int main(int argc, char** argv) { - // The device. Reset on destruction - CudaDevice device; - int sm = device.sm; - cudaDeviceProp props = device.props; - - GpuTimer timer; - - // The batch size. - size_t b = 128; - // The number of heads. - size_t h = 16; - // The dimension of the Q, K and V vectors. - size_t d = 64; - // The length of the sequence for query tokens - size_t s_q = 4096; - // The length of the sequence for K/V cross attention tokens - size_t s_kv = 77; - - // The data type of the kernel. - Data_type data_type = DATA_TYPE_FP16; - // The type of the intermediate P matrix. - Data_type acc_type = DATA_TYPE_FP16; - // The scaling factors. - float scale_bmm1 = 0.f, scale_softmax = 0.f, scale_bmm2 = 0.25f; - // The number of runs. 
- int runs = 1, warm_up_runs = 0; - // Do we use 1s for Q, K, V. - bool use_1s_q = false, use_1s_k = false, use_1s_v = false, use_1s_mask = false; - // The range of the different inputs. - int range_q = 5, range_k = 3, range_v = 5; - // The scale. - float scale_q = 0.f, scale_k = 0.f, scale_v = 0.f; - // The threshold for dropout. By default, drop 10%. - float dropout = 0.1f; - // Do we skip the checks. - bool skip_checks = false; - // The tolerance when checking results. - float epsilon = -1.f; // data_type == DATA_TYPE_FP16 ? 0.015f : 0.f; - - // minimum sequence length for sampling variable seqlens - uint32_t min_s = s_q; - - // run interleaved kernels and transpose input and output accordingly - bool interleaved = false; - bool ignore_b1opt = false; - bool force_unroll = true; - bool use_int8_scale_max = false; - bool verbose = true; - - // set all sequence lengths to min(s, min_s) - bool fix_s = true; - - bool v1 = false; - - // use TMA or not. ignored if not in SM90 - bool use_tma = false; - - // Read the parameters from the command-line. 
- for (int ii = 1; ii < argc; ++ii) { - if (!strcmp(argv[ii], "-1s")) { - use_1s_k = use_1s_q = use_1s_v = use_1s_mask = true; - } else if (!strcmp(argv[ii], "-1s-k")) { - use_1s_k = true; - } else if (!strcmp(argv[ii], "-1s-mask")) { - use_1s_mask = true; - } else if (!strcmp(argv[ii], "-1s-q")) { - use_1s_q = true; - } else if (!strcmp(argv[ii], "-1s-v")) { - use_1s_v = true; - } else if (!strcmp(argv[ii], "-b") && ++ii < argc) { - b = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-d") && ++ii < argc) { - d = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-dropout") && ++ii < argc) { - dropout = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-epsilon") && ++ii < argc) { - epsilon = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-h") && ++ii < argc) { - h = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-int8")) { - data_type = DATA_TYPE_INT8; - acc_type = DATA_TYPE_INT32; - } else if (!strcmp(argv[ii], "-fp16")) { - data_type = DATA_TYPE_FP16; - acc_type = DATA_TYPE_FP16; - } else if (!strcmp(argv[ii], "-e4m3")) { - data_type = DATA_TYPE_E4M3; - // Technically not the acc type. 
- acc_type = DATA_TYPE_FP32; - } else if (!strcmp(argv[ii], "-range-k") && ++ii < argc) { - range_k = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-range-q") && ++ii < argc) { - range_q = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-range-v") && ++ii < argc) { - range_v = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-runs") && ++ii < argc) { - runs = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-s-q") && ++ii < argc) { - s_q = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-s-kv") && ++ii < argc) { - s_kv = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-scale-bmm1") && ++ii < argc) { - scale_bmm1 = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-scale-bmm2") && ++ii < argc) { - scale_bmm2 = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-scale-k") && ++ii < argc) { - scale_k = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-scale-softmax") && ++ii < argc) { - scale_softmax = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-scale-q") && ++ii < argc) { - scale_q = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-scale-v") && ++ii < argc) { - scale_v = (float)strtod(argv[ii], nullptr); - } else if (!strcmp(argv[ii], "-skip-checks")) { - skip_checks = true; - } else if (!strcmp(argv[ii], "-warm-up-runs") && ++ii < argc) { - warm_up_runs = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-min-s") && ++ii < argc) { - min_s = strtol(argv[ii], nullptr, 10); - } else if (!strcmp(argv[ii], "-il")) { - interleaved = true; - } else if (!strcmp(argv[ii], "-ignore-b1opt")) { - ignore_b1opt = true; - } else if (!strcmp(argv[ii], "-scale-max")) { - use_int8_scale_max = true; - } else if (!strcmp(argv[ii], "-v") && ++ii < argc) { - int v = strtol(argv[ii], nullptr, 10); - verbose = v != 0; - } else if (!strcmp(argv[ii], "-use-tma")) { - use_tma = true; - } else { - 
fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]); - return -1; - } - } - - if (interleaved) { - throw std::runtime_error("Interleaved layout is not supported!"); - } - - min_s = std::min(s_q, min_s); - - // The padded sizes. - int const s_kv_padded = std::pow(2, std::ceil(std::log(s_kv) / std::log(2))); - int const d_padded = std::pow(2, std::ceil(std::log(d) / std::log(2))); - - // Set the norm. - if (scale_bmm1 == 0.f) { - scale_bmm1 = 1.f / sqrtf((float)d); - } - - // Force the softmax scale to 1.f for the FP16 kernel. - if (data_type == DATA_TYPE_FP16) { - scale_softmax = 1.f; - } else if (data_type == DATA_TYPE_INT8 && scale_softmax == 0.f) { - scale_softmax = std::max(512.f, (float)s_kv); - } - - // Define the scaling factor for the different inputs. - if (scale_q == 0.f) { - scale_q = 1.f; - } - if (scale_k == 0.f) { - scale_k = 1.f; - } - if (scale_v == 0.f) { - scale_v = data_type == DATA_TYPE_FP16 ? 0.125f : 1.f; - } - - // Set the tolerance if not already set by the user. - if (epsilon < 0.f) { - epsilon = data_type == DATA_TYPE_FP16 ? 0.015f : 0.f; - } - - // Debug info -- only in verbose mode. - if (verbose) { - // Running the following command. - printf("Command.......: %s", argv[0]); - for (int ii = 1; ii < argc; ++ii) { - printf(" %s", argv[ii]); - } - printf("\n"); - - // Device info. - printf("Device........: %s\n", props.name); - printf("Arch.(sm).....: %d\n", sm); - printf("#.of.SMs......: %d\n", props.multiProcessorCount); - - // Problem info. - printf("Batch ........: %lu\n", b); - printf("Heads ........: %lu\n", h); - printf("Dimension ....: %lu\n", d); - printf("Seq len Q ....: %lu\n", s_q); - printf("Seq len KV ...: %lu\n", s_kv); - printf("Warm-up runs .: %d\n", warm_up_runs); - printf("Runs..........: %d\n\n", runs); - - // The scaling factors for the 3 operations. 
- printf("Scale bmm1 ...: %.6f\n", scale_bmm1); - printf("Scale softmax.: %.6f\n", scale_softmax); - printf("Scale bmm2 ...: %.6f\n", scale_bmm2); - printf("\n"); - } - - Launch_params launch_params; - // Set launch params to choose kernels - launch_params.interleaved = interleaved; - launch_params.ignore_b1opt = ignore_b1opt; - launch_params.force_unroll = force_unroll; - launch_params.use_tma = use_tma; - - // The Q matrix of size S_Q x B x H x D. - const size_t q_size = s_q * b * h * d; - // The K and V matrices are packed into one big matrix of size S_KV x B x H x 2 x D. - const size_t kv_size = s_kv_padded * b * h * 2 * d; - // Allocate on the host. - float* q_h = (float*)malloc(q_size * sizeof(float)); - // Allocate on the host. - float* kv_h = (float*)malloc(kv_size * sizeof(float)); - // The size in bytes. - const size_t q_size_in_bytes = get_size_in_bytes(q_size, data_type); - // The size in bytes. - const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); - // Allocate on the device. - void* q_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&q_d, q_size_in_bytes)); - // Allocate on the device. - void* kv_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&kv_d, kv_size_in_bytes)); - - // The mask for dropout. - const size_t mask_size = s_q * b * s_kv_padded; - // Allocate on the host. - float* mask_h = (float*)malloc(mask_size * sizeof(float)); - // The size in bytes. - const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); - // Allocate on the device. - void* mask_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&mask_d, mask_size_in_bytes)); - - // The decomposition of threads and warps for BMM1. - size_t warps_m, warps_n, warps_k; - std::tie(warps_m, warps_n, warps_k) = - get_warps(launch_params, sm, data_type, s_kv_padded, b, d_padded, v1 ? 1 : 2); - - // For multi-CTA cases, determine the size of the CTA wave. 
- int heads_per_wave, ctas_per_head; - get_grid_size(heads_per_wave, ctas_per_head, sm, data_type, b, s_kv_padded, h, d, - false, // disable multi-cta kernels by default - v1 ? 1 : 2); - - // The number of threads per CTA. - const size_t threads_per_cta = warps_m * warps_n * warps_k * 32; - // The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension. - const size_t mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m); - // The number of mmas in the N dimension. - const size_t mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n); - // We do not support more than 4 MMAS in the N dimension (as each MMA needs 8 bits in the mask). - assert(!v1 || mmas_n <= 4); - // The packed mask for dropout (in the fused kernel). Layout is B * MMAS_M * THREADS_PER_CTA. - const size_t packed_mask_size = b * mmas_m * threads_per_cta; - // The size in bytes. - const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); - // Allocate on the host. - uint32_t* packed_mask_h = (uint32_t*)malloc(packed_mask_size_in_bytes); - // Allocate on the device. - void* packed_mask_d = nullptr; - - // The O matrix is packed as S_Q * B * H * D. - const size_t o_size = s_q * b * h * d; - // Allocate on the host. - float* o_h = (float*)malloc(o_size * sizeof(float)); - // The size in bytes. - const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type); - // Allocate on the device. - void* o_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes)); - void* softmax_sum_d; - FMHA_CHECK_CUDA(cudaMalloc(&softmax_sum_d, sizeof(float) * b * s_q * h)); - FMHA_CHECK_CUDA(cudaMemset(softmax_sum_d, 0x00, sizeof(float) * b * s_q * h)); - void* softmax_max_d; - FMHA_CHECK_CUDA(cudaMalloc(&softmax_max_d, sizeof(float) * b * s_q * h)); - FMHA_CHECK_CUDA(cudaMemset(softmax_max_d, 0x00, sizeof(float) * b * s_q * h)); - - // The size in bytes. - const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); - // Allocate on the device. 
- void* tmp_d = nullptr; - if (data_type != acc_type) { - FMHA_CHECK_CUDA(cudaMalloc(&tmp_d, tmp_size_in_bytes)); - } - - // Allocate the reference on the host. - float* o_ref_h = (float*)malloc(o_size * sizeof(float)); - - // The P matrix is stored as one big matrix of size S_Q x B x H x S_KV. - const size_t p_size = s_q * b * h * s_kv_padded; - // The size in bytes. - const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type); - // Allocate on the device. - void* p_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&p_d, p_size_in_bytes)); - - // Allocate the reference on the host. - float* p_ref_h = (float*)malloc(p_size * sizeof(float)); -#if defined(STORE_P) - // Allocate on the host. - float* p_h = (float*)malloc(p_size * sizeof(float)); -#endif // defined(STORE_P) - - // The size in bytes of the S matrix (the data type may be different from P for int8). - const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type); - // Allocate on the device. - void* s_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&s_d, s_size_in_bytes)); - - // Allocate the reference on the host. - float* s_ref_h = (float*)malloc(p_size * sizeof(float)); - - // Allocate on the host. - float* s_h = (float*)malloc(p_size * sizeof(float)); - // Make sure we set the seed for reproducible results. - srand(1234UL); - - // Set the Q, K and V matrices. - random_init("Q", q_h, d, s_q * b * h, d, use_1s_q, range_q, scale_q, verbose); - random_init("K", kv_h + 0 * d, d, s_kv_padded * b * h, 2 * d, use_1s_k, range_k, scale_k, - verbose); - random_init("V", kv_h + 1 * d, d, s_kv_padded * b * h, 2 * d, use_1s_v, range_v, scale_v, - verbose); - - // WAR fOR MISSING CUBLAS FP8 NN SUPPORT. - // Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V. - const size_t v_size = s_kv_padded * b * h * d; - // The size in bytes. 
- const size_t v_size_in_bytes = get_size_in_bytes(v_size, data_type); - float* vt_h = (float*)malloc(v_size * sizeof(float)); - void* vt_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&vt_d, v_size_in_bytes)); - for (size_t it = 0; it < v_size; it++) { - // vt is B x H x D x S_KV - size_t si = it % s_kv_padded; - size_t di = (it / s_kv_padded) % d; - size_t hi = ((it / s_kv_padded) / d) % h; - size_t bi = (((it / s_kv_padded) / d) / h) % b; - // kv is S_KV x B x H x 2 x D - size_t kv_idx = si * b * h * 2 * d + bi * h * 2 * d + hi * 2 * d + 1 * d // index V here - + di; - vt_h[it] = kv_h[kv_idx]; - } - FMHA_CHECK_CUDA(cuda_memcpy_h2d(vt_d, vt_h, v_size, data_type)); - - // // DEBUG. - // float sum = 0.f; - // for( size_t si = 0; si < s; ++si ) { - // float v = qkv_h[si*b*h*3*d + 2*d]; - // printf("V[%3d]=%8.3f\n", si, v); - // sum += v; - // } - // printf("Sum of V = %8.3f\n", sum); - // // END OF DEBUG. - - // Copy from the host to the device. - FMHA_CHECK_CUDA(cuda_memcpy_h2d(q_d, q_h, q_size, data_type)); - FMHA_CHECK_CUDA(cuda_memcpy_h2d(kv_d, kv_h, kv_size, data_type)); - - // Create the buffer of mask. - // if(verbose) {printf("Init .........: mask\n"); } - // random_init_with_zeroes_or_ones(mask_h, b*s, use_1s_mask, 1.f - dropout, verbose); - - auto const create_seqlen = [min_s, fix_s, b](int s, std::vector& seqlens, - std::vector& cu_seqlens, void** cu_seqlens_d) { - std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(), [=](const uint32_t) { - if (fix_s) { - return std::min(uint32_t(s), min_s); - } - throw std::runtime_error("Not supported"); - // if( s_q == min_s ) { - // return min_s; - // } - // uint32_t s_ = s_q - min_s + 1; - // uint32_t ret = min_s + (rand() % s_); - // assert(ret <= s_q); - // return ret; - }); - - // Compute the prefix sum of the sequence lengths. 
- for (int it = 0; it < b; it++) { - cu_seqlens[it + 1] = cu_seqlens[it] + seqlens[it]; - } - - FMHA_CHECK_CUDA(cudaMalloc(cu_seqlens_d, sizeof(int) * cu_seqlens.size())); - FMHA_CHECK_CUDA(cudaMemcpy(*cu_seqlens_d, cu_seqlens.data(), sizeof(int) * cu_seqlens.size(), - cudaMemcpyHostToDevice)); - }; - - std::vector seqlens_q(b, 0); // randomly draw a batch of sequence lengths >= min_s - std::vector cu_seqlens_q(b + 1, 0); - // transfer to device - void* cu_seqlens_q_d; - - std::vector seqlens_kv(b, 0); // randomly draw a batch of sequence lengths >= min_s - std::vector cu_seqlens_kv(b + 1, 0); - // transfer to device - void* cu_seqlens_kv_d; - - create_seqlen(s_q, seqlens_q, cu_seqlens_q, &cu_seqlens_q_d); - int total_q = cu_seqlens_q.back(); - create_seqlen(s_kv, seqlens_kv, cu_seqlens_kv, &cu_seqlens_kv_d); - int total_kv = cu_seqlens_kv.back(); - - size_t q_packed_size = cu_seqlens_q.back() * h * d; - size_t kv_packed_size = cu_seqlens_kv.back() * h * 2 * d; - size_t q_packed_size_in_bytes = get_size_in_bytes(q_packed_size, data_type); - size_t kv_packed_size_in_bytes = get_size_in_bytes(kv_packed_size, data_type); - void* q_packed_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&q_packed_d, q_packed_size_in_bytes)); - - void* kv_packed_d = nullptr; - FMHA_CHECK_CUDA(cudaMalloc(&kv_packed_d, kv_packed_size_in_bytes)); - - const size_t o_packed_size = cu_seqlens_q.back() * h * d; - // Allocate on the host. 
- float* o_packed_h = (float*)malloc(o_packed_size * sizeof(float)); - float* o_ref_packed_h = (float*)malloc(o_packed_size * sizeof(float)); - void* o_packed_d = nullptr; - - size_t o_packed_size_in_bytes = get_size_in_bytes(o_packed_size, data_type); - FMHA_CHECK_CUDA(cudaMalloc(&o_packed_d, o_packed_size_in_bytes)); - - std::vector kv_packed_h(kv_packed_size); - extract_and_transpose_input(kv_packed_h.data(), kv_h, seqlens_kv, s_kv_padded, b, h, d, 2); - if (interleaved) { - x_vec32(true, kv_packed_h.data(), h, total_kv, 2); - } - - std::vector q_packed_h(q_packed_size); - extract_and_transpose_input(q_packed_h.data(), q_h, seqlens_q, s_q, b, h, d, 1); - if (interleaved) { - x_vec32(true, q_packed_h.data(), h, total_q, 1); - } - - // printf("%f %f\n", qkv_packed_h[0], qkv_h[0]); - FMHA_CHECK_CUDA(cuda_memcpy_h2d(q_packed_d, q_packed_h.data(), q_packed_size, data_type)); - FMHA_CHECK_CUDA(cuda_memcpy_h2d(kv_packed_d, kv_packed_h.data(), kv_packed_size, data_type)); - - for (size_t so = 0; so < s_q; ++so) { - for (size_t bi = 0; bi < b; ++bi) { - int actual_seqlen_q = seqlens_q[bi]; - int actual_seqlen_kv = seqlens_kv[bi]; - for (size_t si = 0; si < s_kv_padded; ++si) { - // Are both the query and the key inside the sequence? - bool valid = si < actual_seqlen_kv && so < actual_seqlen_q; - // The mask is stored as floats. - mask_h[so * b * s_kv_padded + bi * s_kv_padded + si] = valid ? 1.f : 0.f; - } - } - } - - // Copy the mask to the device. - FMHA_CHECK_CUDA(cuda_memcpy_h2d(mask_d, mask_h, mask_size, DATA_TYPE_INT8)); - - // Set the params. - bert::Fused_multihead_attention_params_mhca params; - set_params(params, data_type, acc_type, b, s_q, s_kv_padded, h, d, d_padded, total_kv, q_packed_d, - kv_packed_d, cu_seqlens_q_d, cu_seqlens_kv_d, o_packed_d, p_d, s_d, scale_bmm1, - scale_softmax, scale_bmm2, use_int8_scale_max); - - // Allocate barriers and locks. 
- void* counters_d = nullptr; - if (ctas_per_head > 1) { - size_t sz = heads_per_wave * sizeof(int); - FMHA_CHECK_CUDA(cudaMalloc((void**)&counters_d, 3 * sz)); - } - - // Allocate scratch storage for softmax. - void *max_scratch_d = nullptr, *sum_scratch_d = nullptr; - if (ctas_per_head > 1) { - size_t sz = heads_per_wave * ctas_per_head * threads_per_cta * sizeof(float); - FMHA_CHECK_CUDA(cudaMalloc((void**)&max_scratch_d, sz)); - FMHA_CHECK_CUDA(cudaMalloc((void**)&sum_scratch_d, sz)); - } - - // Allocate temporary storage for the parallel reduction. - void* o_scratch_d = nullptr; - if (ctas_per_head > 1 && data_type != DATA_TYPE_FP16) { - size_t sz = heads_per_wave * threads_per_cta * MAX_STGS_PER_LOOP * sizeof(uint4); - FMHA_CHECK_CUDA(cudaMalloc((void**)&o_scratch_d, sz)); - } - - // The number of heads computed per wave. - params.heads_per_wave = heads_per_wave; - - // Barriers for the global sync in the multi-CTA kernel(s). - params.counters = (int*)counters_d + 0 * heads_per_wave; - params.max_barriers = (int*)counters_d + 0 * heads_per_wave; - params.sum_barriers = (int*)counters_d + 1 * heads_per_wave; - params.locks = (int*)counters_d + 2 * heads_per_wave; - - // Scratch storage for softmax. - params.max_scratch_ptr = (float*)max_scratch_d; - params.sum_scratch_ptr = (float*)sum_scratch_d; - - // Scratch storage for output. - params.o_scratch_ptr = (int*)o_scratch_d; - - // Run a few warm-up kernels. - for (int ii = 0; ii < warm_up_runs; ++ii) { - run_fmhca(params, launch_params, data_type, sm, 0); - } - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - - float non_fused_elapsed = INFINITY; - if (!skip_checks) { - // Run cuBLAS. 
- - RefBMM bmm1(data_type_to_cuda(data_type), // a - data_type_to_cuda(data_type), // b - data_type_to_cuda(acc_type), // d - data_type_to_cublas(acc_type), // compute - data_type_to_cuda(acc_type), // scale - false, // Q - true, // K' - s_q, // m - s_kv_padded, // n - d, // k - b * h * d, // ld Q - b * h * 2 * d, // ld K - b * h * s_kv_padded, // ld P - d, // stride Q - 2 * d, // stride K - s_kv_padded, // stride P - b * h // batch count - ); - - // WAR fOR MISSING CUBLAS FP8 NN SUPPORT. - // Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V. - RefBMM bmm2(data_type_to_cuda(data_type), // a - data_type_to_cuda(data_type), // b - data_type_to_cuda(acc_type), // d - data_type_to_cublas(acc_type), // compute - data_type_to_cuda(acc_type), // scale - false, // S - true, // V' - s_q, // m - d, // n - s_kv_padded, // k - b * h * s_kv_padded, // ld S - s_kv_padded, // ld V - b * h * d, // ld O - s_kv_padded, // stride S - s_kv_padded * d, // stride V - d, // stride O - b * h // batch count - ); - - timer.start(); - ground_truth(bmm1, bmm2, data_type, acc_type, scale_bmm1, scale_softmax, scale_bmm2, q_d, kv_d, - vt_d, // WAR pass in V' - mask_d, p_d, s_d, tmp_d, o_d, softmax_sum_d, cu_seqlens_q_d, b, s_q, s_kv_padded, - h, d, runs, warps_m, warps_n, false); - timer.stop(); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - FMHA_CHECK_CUDA(cudaDeviceSynchronize()); - non_fused_elapsed = timer.millis(); - -#if defined(STORE_P) - FMHA_CHECK_CUDA(cuda_memcpy_d2h(p_ref_h, p_d, p_size, acc_type)); -#endif // defined(STORE_P) - -#if defined(STORE_S) - FMHA_CHECK_CUDA(cuda_memcpy_d2h(s_ref_h, s_d, p_size, data_type)); -#endif // defined(STORE_S) - - // Read the results. - FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_ref_h, o_d, o_size, data_type)); - } - - // Fill-in p/s/o with garbage data. 
- FMHA_CHECK_CUDA(cudaMemset(p_d, 0xdc, p_size_in_bytes)); - FMHA_CHECK_CUDA(cudaMemset(s_d, 0xdc, s_size_in_bytes)); - FMHA_CHECK_CUDA(cudaMemset(o_d, 0xdc, o_size_in_bytes)); - - // Run the kernel. - timer.start(); - for (int ii = 0; ii < runs; ++ii) { - run_fmhca(params, launch_params, data_type, sm, 0); - } - timer.stop(); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - - FMHA_CHECK_CUDA(cudaDeviceSynchronize()); - float fused_elapsed = timer.millis(); - -#if defined(STORE_P) - FMHA_CHECK_CUDA(cuda_memcpy_d2h(p_h, p_d, p_size, acc_type)); - printf("\nChecking .....: P = norm * K^T * Q\n"); - - // Clear the invalid region of P. - set_mat(p_ref_h, seqlens_q, seqlens_kv, s_q, b, h, s_kv_padded, 0.f, true); - set_mat(p_h, seqlens_q, seqlens_kv, s_q, b, h, s_kv_padded, 0.f, true); - - // Do the check. - check_results(p_h, p_ref_h, s_kv_padded, - cu_seqlens_q.back() /*not needed: * b -- already counted */ * h, s_kv_padded, 0.f, - true, true); -#endif // defined(STORE_P) - -#if defined(STORE_S) - FMHA_CHECK_CUDA(cuda_memcpy_d2h(s_h, s_d, p_size, data_type)); - printf("\nChecking .....: S = softmax(P)\n"); -#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) - float softmax_epsilon = data_type == DATA_TYPE_FP16 ? 1e-3f : 0.f; -#else - float softmax_epsilon = 1.e-3f; -#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) - - // Clear the invalid region of S. - set_mat(s_ref_h, seqlens_q, s_q, b, h, s_kv_padded, 0.f); - set_mat(s_h, seqlens_q, s_q, b, h, s_kv_padded, 0.f); - - // Do the check. - check_results(s_h, s_ref_h, s_kv_padded, cu_seqlens_q.back() * h, s_kv_padded, softmax_epsilon, - true, true); -#endif // defined(STORE_S) - - // Check the final results. 
- int status = -1; - if (skip_checks) { - printf("\n"); - print_results(true, false); - status = 0; - } else { - FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_packed_h, o_packed_d, o_packed_size, data_type)); - - if (interleaved) { - // revert batch-interleaved format: 3 x h/32 x total x d x 32 => total x - // h x 3 x d - x_vec32(false, o_packed_h, h, total_q, 1); - } - - extract_and_transpose_output(o_ref_packed_h, o_ref_h, seqlens_q, s_q, b, h, d); - - if (verbose) { - printf("\nChecking .....: O = V * S\n"); - } - - status = check_results(o_packed_h, o_ref_packed_h, d, cu_seqlens_q.back() * h, d, epsilon, - verbose, true); - - expand_and_transpose_output(o_h, o_packed_h, seqlens_q, s_q, b, h, d); - eval(o_ref_h, o_h, seqlens_q, b, s_q, h, d, verbose); - // printf("%f %f\n", o_packed_h[0], o_ref_h[0]); - - if (status != 0) { // if there was an error, print the config of the run - printf("v1=%d il=%d s_q=%lu s_kv=%lu b=%lu h=%lu d=%lu dtype=%s\n", v1, interleaved, s_q, - s_kv, b, h, d, data_type_to_name(data_type).c_str()); - } - - if (!verbose) { // this just prints the SUCCESS/ERROR line - print_results(true, true, status == 0); - } - } - - if (verbose) { - // Runtimes. - printf("\n"); - if (skip_checks) { - printf("Non-fused time: %.6fms\n", non_fused_elapsed / float(runs)); - } - printf("Fused time ...: %.6fms\n", fused_elapsed / float(runs)); - if (!skip_checks) { - printf("Ratio ........: %.2fx\n", non_fused_elapsed / fused_elapsed); - } - } else { - printf("Elapsed ......: %.6f (%.2fx)\n", fused_elapsed, non_fused_elapsed / fused_elapsed); - } -#if defined(DEBUG_HAS_PRINT_BUFFER) - FMHA_CHECK_CUDA( - cuda_memcpy_d2h(print_buffer.data(), params.print_ptr, print_buffer.size(), DATA_TYPE_FP32)); - - printf("\n====================\n"); - for (int it = 0; it < 16; it++) { - printf("% .4f ", print_buffer[it]); - } - printf("\n====================\n"); - - FMHA_CHECK_CUDA(cudaFree(params.print_ptr)); - -#endif - - // Release memory. 
- FMHA_CHECK_CUDA(cudaFree(q_d)); - FMHA_CHECK_CUDA(cudaFree(kv_d)); - FMHA_CHECK_CUDA(cudaFree(mask_d)); - FMHA_CHECK_CUDA(cudaFree(packed_mask_d)); - FMHA_CHECK_CUDA(cudaFree(p_d)); - FMHA_CHECK_CUDA(cudaFree(s_d)); - FMHA_CHECK_CUDA(cudaFree(o_d)); - FMHA_CHECK_CUDA(cudaFree(tmp_d)); - FMHA_CHECK_CUDA(cudaFree(cu_seqlens_q_d)); - FMHA_CHECK_CUDA(cudaFree(cu_seqlens_kv_d)); - FMHA_CHECK_CUDA(cudaFree(max_scratch_d)); - FMHA_CHECK_CUDA(cudaFree(sum_scratch_d)); - FMHA_CHECK_CUDA(cudaFree(o_scratch_d)); - FMHA_CHECK_CUDA(cudaFree(counters_d)); - FMHA_CHECK_CUDA(cudaFree(softmax_sum_d)); - FMHA_CHECK_CUDA(cudaFree(softmax_max_d)); - - free(q_h); - free(kv_h); - free(mask_h); - free(packed_mask_h); - free(s_h); - free(o_h); - free(o_ref_h); - - free(p_ref_h); -#if defined(STORE_P) - free(p_h); -#endif // defined(STORE_P) - free(s_ref_h); - - return status; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/fmha_v2/softmax_bf16.cu b/csrc/fmha_v2/softmax_bf16.cu deleted file mode 100644 index 5687c4d70c..0000000000 --- a/csrc/fmha_v2/softmax_bf16.cu +++ /dev/null @@ -1,21 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement - * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. 
- */ - -#include "softmax_impl.h" - -void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_q_seqlens_d, int s_inner, int s_outer, int b, - int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) { - run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, - s_inner, s_outer, b, h, 0.f, 0.f, softcapping_scale_bmm1, - warps_n, has_alibi); -} diff --git a/csrc/fmha_v2/softmax_fp16.cu b/csrc/fmha_v2/softmax_fp16.cu deleted file mode 100644 index 63ce2898a5..0000000000 --- a/csrc/fmha_v2/softmax_fp16.cu +++ /dev/null @@ -1,21 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement - * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. - */ - -#include "softmax_impl.h" - -void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_q_seqlens_d, int s_inner, int s_outer, int b, - int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) { - run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, - s_inner, s_outer, b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, - has_alibi); -} diff --git a/csrc/fmha_v2/softmax_fp32.cu b/csrc/fmha_v2/softmax_fp32.cu deleted file mode 100644 index deef93a4ee..0000000000 --- a/csrc/fmha_v2/softmax_fp32.cu +++ /dev/null @@ -1,21 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. 
SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement - * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. - */ - -#include "softmax_impl.h" - -void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_q_seqlens_d, int s_inner, int s_outer, int b, - int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) { - run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, - s_inner, s_outer, b, h, 0.f, 0.f, softcapping_scale_bmm1, - warps_n, has_alibi); -} diff --git a/csrc/fmha_v2/softmax_fp8.cu b/csrc/fmha_v2/softmax_fp8.cu deleted file mode 100644 index e7fcd91526..0000000000 --- a/csrc/fmha_v2/softmax_fp8.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement - * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. 
- */ - -#include "softmax_impl.h" - -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_q_seqlens_d, int s_inner, int s_outer, int b, - int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi) { - run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, - s_inner, s_outer, b, h, 0.f, scale_softmax, - softcapping_scale_bmm1, warps_n, has_alibi); -} diff --git a/csrc/fmha_v2/softmax_impl.h b/csrc/fmha_v2/softmax_impl.h deleted file mode 100644 index 0d991b704c..0000000000 --- a/csrc/fmha_v2/softmax_impl.h +++ /dev/null @@ -1,1004 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement - * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. - */ - -#include -#include - -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// The number of threads per warp. -enum { THREADS_PER_WARP = 32 }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Softmax_params { - // Output pointer. - Dst_type* dst; - // Source pointer. - Src_type const* src; - // Masks. - int8_t const* mask; - // Attention sinks (per head). - float const* attention_sinks; - // Softmax sum pointer. - float* softmax_sum; - // ALiBi - bool has_alibi; - // Dimensions of the problem. - size_t b, h; - // Precomputed constants. 
- size_t bhs, hs, bs; - // The scaling factors to apply when we convert to/from float. - float scale_bmm1, softcapping_scale_bmm1, scale_softmax; - // The number of reduction warps used by the fused kernel. - int warps_n; - int* cu_q_seqlens; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float to_float(uint16_t const& src, float) { - return fmha::half_to_float(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Disable warning #177-D because this function has not been used elsewhere -#pragma nv_diag_suppress 177 - -static inline __device__ float to_float(fmha::bf16_t const& src, float) { - return __bfloat162float(src); -} - -#pragma nv_diag_default 177 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Disable warning #177-D because this function has not been used elsewhere -#pragma nv_diag_suppress 177 - -static inline __device__ float to_float(fmha::e4m3_t const& src, float) { return float(src); } - -#pragma nv_diag_default 177 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float to_float(float const& src, float) { return src; } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float to_float(int const& src, float scale) { - float dst; - - // Convert from int to float. - dst = static_cast(src); - - // Scale. 
- dst *= scale; - - return dst; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ void from_float(uint16_t& dst, float const& src, float) { - dst = fmha::float_to_half(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ void from_float(fmha::bf16_t& dst, float const& src, float) { - dst = fmha::float_to_bf16(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ int8_t float_to_int8_rn(float x) { - uint32_t dst; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); - return reinterpret_cast(dst); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ void from_float(int8_t& dst, float const& src, float scale) { - dst = float_to_int8_rn(src * scale); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ void from_float(fmha::e4m3_t& dst, float const& src, float scale) { - dst = fmha::e4m3_t(src * scale); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float apply_exp_(float x, float max) { - return isinf(x) ? 0.f : __expf(x - max); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], - int warps_n, float& sum_fp32, float& max_fp32, - float const attention_sink) { -// Apply the masks. -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - data_fp32[ii][0] = mask[ii][0] ? data_fp32[ii][0] : -HUGE_VALF; - } - - // Compute the max inside the thread. 
-#pragma unroll - for (int ii = 0; ii < N; ++ii) { - max_fp32 = fmaxf(max_fp32, data_fp32[ii][0]); - } - -// Compute inside the warp. -#pragma unroll - for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2) { - max_fp32 = fmaxf(max_fp32, __shfl_xor_sync(uint32_t(-1), max_fp32, xor_mask)); - } - -// Transform the elements. -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - data_fp32[ii][0] = apply_exp_(data_fp32[ii][0], max_fp32); - } - - // Compute the max inside the thread. -#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) - -#pragma unroll - for (int ii = 0; ii < N; ii++) { - sum_fp32 += data_fp32[ii][0]; //+0 +64 +128 - } - - // Emulate tmp[0] + tmp[1] - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 4); - __syncwarp(); - - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp(); - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 1); - __syncwarp(); - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp(); - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 2); - __syncwarp(); - - // Emulate final reduction - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 8); - __syncwarp(); - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 16); - __syncwarp(); - -#else -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - sum_fp32 += data_fp32[ii][0]; - } - -// Compute inside the warp. -#pragma unroll - for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2) { - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, xor_mask); - } -#endif - - // // DEBUG. - // if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) { - // printf("elt=%12.8f sum_fp32=%12.8f\n", data_fp32[0].x, sum_fp32); - // } - - // Fix the sum if needed. - if (sum_fp32 == 0.f || sum_fp32 != sum_fp32) { - sum_fp32 = 1.f; - } - - // Normalize. 
- float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - data_fp32[ii][0] *= inv_sum_fp32; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], - int warps_n, float& sum_fp32, float& max_fp32, - float const attention_sink) { -// Apply the masks. -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - data_fp32[ii][0] = mask[ii][0] ? data_fp32[ii][0] : -HUGE_VALF; - data_fp32[ii][1] = mask[ii][1] ? data_fp32[ii][1] : -HUGE_VALF; - } - - // Compute the max inside the thread. -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - max_fp32 = fmaxf(max_fp32, data_fp32[ii][0]); - max_fp32 = fmaxf(max_fp32, data_fp32[ii][1]); - } - -// Compute inside the warp. -#pragma unroll - for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2) { - max_fp32 = fmaxf(max_fp32, __shfl_xor_sync(uint32_t(-1), max_fp32, xor_mask)); - } - -// // DEBUG. -// if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) { -// printf("elt=%12.8f max_fp32=%12.8f\n", data_fp32[0][0], max_fp32); -// } -// // END OF DEBUG. - -// Transform the elements. -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - data_fp32[ii][0] = apply_exp_(data_fp32[ii][0], max_fp32); - data_fp32[ii][1] = apply_exp_(data_fp32[ii][1], max_fp32); - } - - // Compute the max inside the thread. - // float sum_fp32 = 0.f; -#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) - if (warps_n == 1) { - // TODO not sure if we can improve this on the gmma side without using additional regs. - - // this is intentionally o(n) instead of o(log n) - // lanes 0 and 1 here represent the first quad. - - // need to account for offset of l0 when addressing absolute lanes. 
- int const ti = threadIdx.x % 4; - float tmp = 0.f; - - for (int ni = 0; ni < N; ni++) { - float x = data_fp32[ni][0] + data_fp32[ni][1]; - tmp += x; - - for (int it = 1; it < 8; it++) { - tmp += __shfl_sync(uint32_t(-1), x, 4 * it + ti); - __syncwarp(); - } - } - - // emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp(); - tmp += __shfl_xor_sync(uint32_t(-1), tmp, 1); - __syncwarp(); - // emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp(); - tmp += __shfl_xor_sync(uint32_t(-1), tmp, 2); - __syncwarp(); - sum_fp32 = __shfl_sync(uint32_t(-1), tmp, 0); - } else if (warps_n == 8) { - // Accumulate warp 0 and warp 4 - float tmp[2] = {0.f, 0.f}; -#pragma unroll - for (int ii = 0; ii < N; ii += 2) { - tmp[0] += data_fp32[ii + 0][0]; - tmp[0] += data_fp32[ii + 0][1]; - tmp[1] += data_fp32[ii + 1][0]; - tmp[1] += data_fp32[ii + 1][1]; - } - - // Emulate tmp[0] + tmp[1] - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 4); - - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 1); - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2); - // Emulate final reduction - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 8); - - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16); - sum_fp32 = tmp[0] + tmp[1]; - - sum_fp32 = __shfl_sync(uint32_t(-1), sum_fp32, 0); - } else { -#pragma unroll - for (int ii = 0; ii < N; ii++) { - sum_fp32 += data_fp32[ii][0] + data_fp32[ii][1]; - } - - // Emulate tmp[0] + tmp[1] - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 4); - // Emulate dst[mi] += 
__shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp(); - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 1); - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp(); - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 2); - // Emulate final reduction - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 8); - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 16); - - sum_fp32 = __shfl_sync(uint32_t(-1), sum_fp32, 0); - } - -#else -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - sum_fp32 += data_fp32[ii][0]; - sum_fp32 += data_fp32[ii][1]; - } - -// Compute inside the warp. -#pragma unroll - for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2) { - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, xor_mask); - } -#endif - - // // DEBUG. - // if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) { - // printf("elt=%12.8f sum_fp32=%12.8f\n", data_fp32[0][0], sum_fp32); - // } - - // Fix the sum if needed. - if (sum_fp32 == 0.f || sum_fp32 != sum_fp32) { - sum_fp32 = 1.f; - } - - // Normalize. - float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - data_fp32[ii][0] *= inv_sum_fp32; - data_fp32[ii][1] *= inv_sum_fp32; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], - int warps_n, float& sum_fp32, float& max_fp32, - float const attention_sink) { -// Apply the masks. -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - data_fp32[ii][0] = mask[ii][0] ? data_fp32[ii][0] : -HUGE_VALF; - data_fp32[ii][1] = mask[ii][1] ? data_fp32[ii][1] : -HUGE_VALF; - data_fp32[ii][2] = mask[ii][2] ? data_fp32[ii][2] : -HUGE_VALF; - data_fp32[ii][3] = mask[ii][3] ? data_fp32[ii][3] : -HUGE_VALF; - } - - // Compute the max inside the thread. 
-#pragma unroll - for (int ii = 0; ii < N; ++ii) { - max_fp32 = fmaxf(max_fp32, data_fp32[ii][0]); - max_fp32 = fmaxf(max_fp32, data_fp32[ii][1]); - max_fp32 = fmaxf(max_fp32, data_fp32[ii][2]); - max_fp32 = fmaxf(max_fp32, data_fp32[ii][3]); - } - -// Compute inside the warp. -#pragma unroll - for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2) { - max_fp32 = fmaxf(max_fp32, __shfl_xor_sync(uint32_t(-1), max_fp32, xor_mask)); - } - -// // DEBUG. -// if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) { -// printf("elt=%12.8f max_fp32=%12.8f\n", data_fp32[0][0], max_fp32); -// } - -// Transform the elements. -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - data_fp32[ii][0] = apply_exp_(data_fp32[ii][0], max_fp32); - data_fp32[ii][1] = apply_exp_(data_fp32[ii][1], max_fp32); - data_fp32[ii][2] = apply_exp_(data_fp32[ii][2], max_fp32); - data_fp32[ii][3] = apply_exp_(data_fp32[ii][3], max_fp32); - } - - // Compute the max inside the thread. - // float sum_fp32 = 0.f; - - // TODO needs refactoring... - -#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) - // Within a thread it should correspond to the operation done in the tmp[0]/[1] loop. - - if (warps_n == 1) { // E.g. 4x1: 4 threads iterate over all cores. - // TODO not sure if we can improve this on the gmma side without using additional regs. - - // this is intentionally o(n) instead of o(log n) - // lanes 0 and 1 here represent the first quad. - - // need to account for offset of l0 when addressing absolute lanes. 
- int const ti = threadIdx.x % 2; - float tmp[2] = {0.f, 0.f}; - - for (int ni = 0; ni < N; ni++) { - // +1 - float x = data_fp32[ni][0] + data_fp32[ni][1]; - float y = data_fp32[ni][2] + data_fp32[ni][3]; - tmp[0] += x; - tmp[1] += y; - - for (int it = 1; it < 16; it++) { - tmp[0] += __shfl_sync(uint32_t(-1), x, 2 * it + ti); - __syncwarp(); - tmp[1] += __shfl_sync(uint32_t(-1), y, 2 * it + ti); - __syncwarp(); - } - } - - // emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp(); - tmp[0] += tmp[1]; - // emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1); - __syncwarp(); - sum_fp32 = __shfl_sync(uint32_t(-1), tmp[0], 0); - } else { - // SEQLEN == 128. - if (N == 1) { - float tmp[2] = {0.f, 0.f}; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 700 // GV100 - // The thread local reduction. - tmp[0] += data_fp32[0][0]; - tmp[0] += data_fp32[0][1]; - tmp[0] += data_fp32[0][2]; - tmp[0] += data_fp32[0][3]; - - // Add threads 0 and 2. Inside a thread in the impl. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2); - __syncwarp(); - // Add threads 0 and 8. Inside the thread. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8); - __syncwarp(); - // Add threads 0 and 16. Inside the thread. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16); - __syncwarp(); - - // Add threads 0 and 1. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1); - __syncwarp(); - - // Add threads 0 and 4. Inter-warp in the code. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4); - __syncwarp(); -#else - if (warps_n == 2) { // 2x2 - tmp[0] += data_fp32[0][0] + data_fp32[0][1]; - tmp[1] += data_fp32[0][2] + data_fp32[0][3]; - - // Emulate a_01 += a_23... - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2); - __syncwarp(); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2); - __syncwarp(); - - // Emulate a_01 += a_45... 
- tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8); - __syncwarp(); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 8); - __syncwarp(); - - // Emulate a_01 += a_89... - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16); - __syncwarp(); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16); - __syncwarp(); - - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp(); - tmp[0] += tmp[1]; - - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1); - __syncwarp(); - - // Emulate the final reduction in smem. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4); - __syncwarp(); - } else { // 1x4 - tmp[0] += data_fp32[0][0] + data_fp32[0][1]; - tmp[1] += data_fp32[0][2] + data_fp32[0][3]; - - // Add +64. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16); - __syncwarp(); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16); - __syncwarp(); - - // T0: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1]; - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2); - __syncwarp(); - // T1: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1]; - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2); - __syncwarp(); - - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp(); - tmp[0] += tmp[1]; - - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1); - __syncwarp(); - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 4); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4); - __syncwarp(); - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 8); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8); - __syncwarp(); - } - -#endif // ! GV100 - - // Don't forget to put the value in sum_fp32 :) - // sum_fp32 = tmp[0]; - sum_fp32 = __shfl_sync(uint32_t(-1), tmp[0], 0); - - // SEQLEN == 256 - compare with 1x4. 
- } else if (N == 2 || N == 8) { -#pragma unroll - for (int step = 0; step < N; step += 2) { - float tmp[2] = {0.f, 0.f}; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 700 // GV100 - - // The thread local reduction. - tmp[0] += data_fp32[step + 0][0]; - tmp[0] += data_fp32[step + 0][1]; - tmp[0] += data_fp32[step + 0][2]; - tmp[0] += data_fp32[step + 0][3]; - - tmp[1] += data_fp32[step + 1][0]; - tmp[1] += data_fp32[step + 1][1]; - tmp[1] += data_fp32[step + 1][2]; - tmp[1] += data_fp32[step + 1][3]; - - // Sum offset 0 and 128 (and so on). - tmp[0] += tmp[1]; - - // Add threads 0 and 2. Inside a thread in the impl. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2); - __syncwarp(); - // Add threads 0 and 16. Inside the thread. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16); - __syncwarp(); - - // Add threads 0 and 1. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1); - __syncwarp(); - - // Add threads 0 and 4. Inter-warp in the code. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4); - __syncwarp(); - // Add threads 0 and 8. Inter-warp in the code. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8); - __syncwarp(); -#else - // 0. - tmp[0] += data_fp32[step + 0][0] + data_fp32[step + 0][1]; - tmp[1] += data_fp32[step + 0][2] + data_fp32[step + 0][3]; - - // Add +64. - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16); - __syncwarp(); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16); - __syncwarp(); - - // Add +128 but use temp storage due to the next round of shfl. - float xy = data_fp32[step + 1][0] + data_fp32[step + 1][1]; - float zw = data_fp32[step + 1][2] + data_fp32[step + 1][3]; - - // Add +128. - tmp[0] += xy; - tmp[1] += zw; - - // Add +192. 
- tmp[0] += __shfl_xor_sync(uint32_t(-1), xy, 16); - __syncwarp(); - tmp[1] += __shfl_xor_sync(uint32_t(-1), zw, 16); - __syncwarp(); - - // T0: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1]; - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2); - __syncwarp(); - // T1: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1]; - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2); - __syncwarp(); - - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp(); - tmp[0] += tmp[1]; - - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1); - __syncwarp(); - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 4); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4); - __syncwarp(); - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 8); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8); - __syncwarp(); -#endif // ! GV100 - - // Don't forget to put the value in sum_fp32 :) - sum_fp32 += tmp[0]; - } - // Emulate taking warp results from position 0, 16, 32, 48, etc. - sum_fp32 = __shfl_sync(uint32_t(-1), sum_fp32, 0); - - // SEQLEN == 384. - } else if (N == 3) { - float tmp[2] = {0.f, 0.f}; - -// The reduction inside the thread. -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - tmp[0] += data_fp32[ii][0]; - tmp[0] += data_fp32[ii][1]; - tmp[1] += data_fp32[ii][2]; - tmp[1] += data_fp32[ii][3]; - } - - // Emulate dst[mi] = tmp[mi][0] + tmp[mi][1]; - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2); - __syncwarp(); - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2); - __syncwarp(); - - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp(); - tmp[0] += tmp[1]; - - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1); - __syncwarp(); - - // Emulate the final summation. 
- tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4); - __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8); - __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16); - __syncwarp(); - - // Don't forget to put the value in sum_fp32 :) - sum_fp32 += tmp[0]; - // SEQLEN == 512 - compare with 1x8. - } else if (N >= 4) { - // Emulate thread local - float tmp[2] = {0.f, 0.f}; // T0, T1 -#pragma unroll - for (int step = 0; step < N; step++) { - tmp[0] += data_fp32[step][0]; // + 0 - tmp[0] += data_fp32[step][1]; // + 1 - tmp[1] += data_fp32[step][2]; // + 2 - tmp[1] += data_fp32[step][3]; // + 3 - } - - // T0: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1]; - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2); - __syncwarp(); - // T1: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1]; - tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2); - __syncwarp(); - - // Emulate intra-thread - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp(); - tmp[0] += tmp[1]; - // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1); - __syncwarp(); - - // Emulate inter-thread - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4); - __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8); - __syncwarp(); - tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16); - __syncwarp(); - - // Don't forget to put the value in sum_fp32 :) - // sum_fp32 = tmp[0]; - - // Emulate taking warp results from position 0, 16, 32, 48, etc. - sum_fp32 = __shfl_sync(uint32_t(-1), tmp[0], 0); - // Not supported. - } else { - assert(false); - } - } // warps_n == 1 -#else -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - sum_fp32 += data_fp32[ii][0]; - sum_fp32 += data_fp32[ii][1]; - sum_fp32 += data_fp32[ii][2]; - sum_fp32 += data_fp32[ii][3]; - } - -// Compute inside the warp. 
-#pragma unroll - for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2) { - sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, xor_mask); - } -#endif - - // // DEBUG. - // if( blockIdx.x == 0 && threadIdx.y == 0 && threadIdx.x == 0 ) { - // printf("elt=%12.8f sum_fp32=%12.8f\n", data_fp32[0][0], sum_fp32); - // } - // // END OF DEBUG. - - // Fix the sum if needed. - if (sum_fp32 == 0.f || sum_fp32 != sum_fp32) { - sum_fp32 = 1.f; - } - - // Normalize. - float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); -#pragma unroll - for (int ii = 0; ii < N; ++ii) { - data_fp32[ii][0] *= inv_sum_fp32; - data_fp32[ii][1] *= inv_sum_fp32; - data_fp32[ii][2] *= inv_sum_fp32; - data_fp32[ii][3] *= inv_sum_fp32; - } -} - -template -struct VecX { - using Type = typename fmha::Uint_from_size_in_bytes::Type; - static_assert(sizeof(Type) == X * sizeof(Data_type)); - - union Alias { - Type raw; - Data_type elt[X]; - }; - - static __device__ inline void to_floatX(float (&dst)[X], Type const& src, float const scale, - float const attn_logit_softcapping_scale) { - Alias tmp; - tmp.raw = src; -#pragma unroll - for (int it = 0; it < X; it++) { - dst[it] = to_float(tmp.elt[it], scale); - if (attn_logit_softcapping_scale != 0.f) { - dst[it] = - attn_logit_softcapping_scale * fmha::__tanhf(dst[it] / attn_logit_softcapping_scale); - } - } - } - - static __device__ inline void from_floatX(Type& dst, float const (&src)[X], float const scale) { - Alias tmp; -#pragma unroll - for (int it = 0; it < X; it++) { - from_float(tmp.elt[it], src[it], scale); - } - dst = tmp.raw; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float get_alibi_head_scaling_factor(int const head_id, int const num_heads) { - // Round down to power of 2 - int const num_heads_pow2 = (1u << (31 - __clz(num_heads))); - if (head_id < num_heads_pow2) { - return exp2f((head_id + 1) * -8.0f / num_heads_pow2); - } 
else { - float const adjusted_head_id = 2 * (head_id - num_heads_pow2) + 1; - return exp2f(adjusted_head_id * -4.0f / num_heads_pow2); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -static __global__ void softmax_kernel(Softmax_params params) { - // By default, use LDG.64 for the loads and STG.64 for the stores. - enum { ELEMENTS_PER_LDG = X, ELEMENTS_PER_STG = X }; - - // The number of Vec_type per thread. - enum { VECs_PER_THREAD = SEQLEN / THREADS_PER_WARP / ELEMENTS_PER_LDG }; - - // DEBUG. - static_assert(VECs_PER_THREAD * THREADS_PER_WARP * ELEMENTS_PER_LDG == SEQLEN, ""); - // END OF DEBUG. - - using VecO = VecX; - using VecI = VecX; - using VecM = VecX; - // The vector types. - using DstX_type = typename VecO::Type; - using SrcX_type = typename VecI::Type; - - // Make sure the sizes match our expectations. - static_assert(sizeof(DstX_type) == X * sizeof(Dst_type)); - static_assert(sizeof(SrcX_type) == X * sizeof(Src_type)); - - // The type of the mask. - using MaskX_type = typename VecM::Type; - - // One warp per sequence. - size_t hi = blockIdx.y * WARPS_PER_CTA + threadIdx.y; - size_t bi = blockIdx.z; - size_t si = blockIdx.x; - - // The data offset. Layout is S * B * H * S. - size_t src_offset = - si * params.bhs + bi * params.hs + hi * SEQLEN + threadIdx.x * ELEMENTS_PER_LDG; - - // Load the input elements. - SrcX_type const* src_ptr = reinterpret_cast(¶ms.src[src_offset]); - SrcX_type data_src[VECs_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < VECs_PER_THREAD; ++ii) { - if (hi < params.h) { - data_src[ii] = src_ptr[ii * THREADS_PER_WARP]; - } - } - - // The mask offset. Layout is S * B * S. - size_t mask_offset = si * params.bs + bi * SEQLEN + threadIdx.x * ELEMENTS_PER_LDG; - - // Load the masks. 
- MaskX_type const* mask_ptr = reinterpret_cast(¶ms.mask[mask_offset]); - MaskX_type mask[VECs_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < VECs_PER_THREAD; ++ii) { - mask[ii] = mask_ptr[ii * THREADS_PER_WARP]; - } - - // Convert the data to float. - float data_fp32[VECs_PER_THREAD][X]; - int8_t mask_[VECs_PER_THREAD][X]; -#pragma unroll - for (int ii = 0; ii < VECs_PER_THREAD; ++ii) { - VecI::to_floatX(data_fp32[ii], data_src[ii], params.scale_bmm1, params.softcapping_scale_bmm1); - - typename VecM::Alias tmp; - tmp.raw = mask[ii]; -#pragma unroll - for (int it = 0; it < X; it++) { - mask_[ii][it] = tmp.elt[it]; - } - } - - if (params.has_alibi) { - float const alibi_factor = get_alibi_head_scaling_factor(hi, params.h); -#pragma unroll - for (int ii = 0; ii < VECs_PER_THREAD; ii++) { -#pragma unroll - for (int jj = 0; jj < X; jj++) { - int col = ii * THREADS_PER_WARP * X + threadIdx.x * X + jj; - data_fp32[ii][jj] += alibi_factor * col; - } - } - } - - // The attention sink value. - float attention_sink = -FLT_MAX; - if (params.attention_sinks != nullptr) { - attention_sink = params.attention_sinks[hi]; - } - - // Do the reduction. - float sum_fp32 = 0.f; - float max_fp32 = -HUGE_VALF; - reduce(data_fp32, mask_, params.warps_n, sum_fp32, max_fp32, attention_sink); - if (threadIdx.x == 0) { - int sum_s = params.cu_q_seqlens[bi]; - // [B, S, H, 2] {max, sum} float - if (hi < params.h) { - params.softmax_sum[(sum_s + si) * params.h * 2 + hi * 2] = max_fp32; - params.softmax_sum[(sum_s + si) * params.h * 2 + hi * 2 + 1] = sum_fp32; - } - } - // Reconvert to half. - DstX_type data_dst[VECs_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < VECs_PER_THREAD; ++ii) { - VecO::from_floatX(data_dst[ii], data_fp32[ii], params.scale_softmax); - } - - // Store the output elements. 
- DstX_type* dst_ptr = reinterpret_cast(¶ms.dst[src_offset]); -#pragma unroll - for (int ii = 0; ii < VECs_PER_THREAD; ++ii) { - if (hi < params.h) { - dst_ptr[ii * THREADS_PER_WARP] = data_dst[ii]; - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void run_softmax(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum, void* cu_q_seqlens, int s_inner, int s_outer, int b, int h, - float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi) { - printf("Softmax impl\n"); - Softmax_params params; - memset(¶ms, 0, sizeof(params)); - - // The different pointers. - params.dst = reinterpret_cast(dst); - params.src = reinterpret_cast(src); - params.softmax_sum = reinterpret_cast(softmax_sum); - params.cu_q_seqlens = reinterpret_cast(cu_q_seqlens); - params.mask = reinterpret_cast(mask); - params.attention_sinks = reinterpret_cast(attention_sinks); - params.has_alibi = has_alibi; - - // The dimensions and precomputed values. - params.b = b; - params.h = h; - params.bhs = b * h * s_inner; - params.hs = h * s_inner; - params.bs = b * s_inner; - - // The scaling factors for the int8 version to convert to/from float. - params.scale_bmm1 = scale_bmm1; - params.softcapping_scale_bmm1 = softcapping_scale_bmm1; - params.scale_softmax = scale_softmax; - // The number of warps_n used to identify the reduction strategy. - params.warps_n = warps_n; - - // Compute the grid size. - enum { WARPS_PER_CTA = 4 }; - - dim3 grid(s_outer, (h + WARPS_PER_CTA - 1) / WARPS_PER_CTA, b); - dim3 threads_per_cta(THREADS_PER_WARP, WARPS_PER_CTA); - - // Launch the kernel. 
- if (s_inner == 32) { - softmax_kernel<<>>(params); - } else if (s_inner == 64) { - softmax_kernel<<>>(params); - } else if (s_inner == 96) { - softmax_kernel<<>>(params); - } else if (s_inner == 128) { - softmax_kernel<<>>(params); - } else if (s_inner == 192) { - softmax_kernel<<>>(params); - } else if (s_inner == 256) { - softmax_kernel<<>>(params); - } else if (s_inner == 384) { - softmax_kernel<<>>(params); - } else if (s_inner == 512) { - softmax_kernel<<>>(params); - } else if (s_inner == 1024) { - softmax_kernel<<>>(params); - } else if (s_inner == 2048) { - softmax_kernel<<>>(params); - } else if (s_inner == 4096) { - softmax_kernel<<>>(params); - } else if (s_inner == 8192) { - softmax_kernel<<>>(params); - } else if (s_inner == 16384) { - softmax_kernel<<>>(params); - } else if (s_inner == 32768) { - softmax_kernel<<>>(params); - } else if (s_inner == 65536) { - softmax_kernel<<>>(params); - } else { - assert(false); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/fmha_v2/softmax_int8.cu b/csrc/fmha_v2/softmax_int8.cu deleted file mode 100644 index a0146338e0..0000000000 --- a/csrc/fmha_v2/softmax_int8.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement - * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. 
- */ - -#include "softmax_impl.h" - -void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, - void* softmax_sum_d, void* cu_q_seqlens_d, int s_inner, int s_outer, int b, - int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, - int warps_n, bool has_alibi) { - run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, - s_inner, s_outer, b, h, scale_bmm1, scale_softmax, - softcapping_scale_bmm1, warps_n, has_alibi); -} diff --git a/csrc/fmha_v2/templates/fa_kernel.jinja b/csrc/fmha_v2/templates/fa_kernel.jinja new file mode 100644 index 0000000000..715d5ea85c --- /dev/null +++ b/csrc/fmha_v2/templates/fa_kernel.jinja @@ -0,0 +1,508 @@ +{{ copyright }} + +//We can disable the FADD trick for archs with F2IP +{% if disable_fadd_trick %} // disable_fadd_trick +#ifdef USE_I2F_EMULATION_TRICK +#undef USE_I2F_EMULATION_TRICK +#endif // USE_I2F_EMULATION_TRICK + +#ifdef USE_F2I_EMULATION_TRICK +#undef USE_F2I_EMULATION_TRICK +#endif // USE_F2I_EMULATION_TRICK +{% endif %} // disable_fadd_trick + +#include +#include + +#if CUDA_VERSION >= {{ min_cuda_version }} + + +{% if not use_multi_cta %} // !use_multi_cta +{% if kernel_variant == "flash_attention" %} +#include +{% else %} +#include +{% endif %} +{% endif %} // !use_multi_cta + +{% if not use_multi_cta and has_noloop %} // !use_multi_cta && has_noloop +{% if kernel_variant == "flash_attention" %} +#include +{% if reload_q %} // reload_q: D > CTA_P_TILE_K, need tiled noloop kernel +#include +{% endif %} // reload_q +{% else %} +#include +{% endif %} +{% endif %} // !use_multi_cta && has_noloop + +{% if cross_mha %} // cross_mha +{% if has_noloop %} // has_noloop +#include +{% endif %} // has_noloop +#include +{% endif %} // cross_mha + +{% if use_multi_cta %} // use_multi_cta +#include +{% endif %} + +using Attention_mask_type = fmha::Attention_mask_type; +using Launch_params = bert::Fused_multihead_attention_launch_params; + +{% if not 
cross_mha and not has_noloop %} // !cross_mha && !has_noloop (looped kernel) +using Kernel_traits = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ kv_loop_step }}, + {{ head_size }}, + {{ head_size_v }}, + {{ loop_step }}, + {{ warps_m }}, + {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }}>; + +using Kernel_traits_causal = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ kv_loop_step }}, + {{ head_size }}, + {{ head_size_v }}, + {{ loop_step }}, + {{ warps_m }}, + {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }}, + /*causal mask*/ 3>; +{% endif %} // !cross_mha && !has_noloop + +{% if not use_multi_cta and not cross_mha and not has_noloop %} // !use_multi_cta && !cross_mha && !has_noloop + +extern "C" +__global__ +void {{ kernel_name }}({{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}(params); +} + +extern "C" +__global__ +void {{ causal_kernel_name }}({{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}(params); +} + +void {{ launcher_name }}( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + + constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM; + if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + dim3 grid(params.h, params.b); + {{ causal_kernel_name }}<<>>(params); + } else { + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + dim3 grid(params.h, params.b); + {{ kernel_name }}<<>>(params); + } +} + +{% endif %} // !use_multi_cta && !cross_mha && !has_noloop + +{% if has_noloop and not cross_mha %} // has_noloop && !cross_mha (looped launcher stub — flash attention uses noloop only) + +void {{ launcher_name 
}}( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + assert(false && "Flash attention uses noloop launchers"); +} + +{% endif %} // has_noloop && !cross_mha + +{% if not use_multi_cta and has_noloop and not cross_mha %} // !use_multi_cta && has_noloop && !cross_mha + +using Kernel_traits_nl = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ kv_loop_step }}, + {{ head_size }}, + {{ head_size_v }}, + {{ noloop_step }}, + 1, + {{ warps_m }} * {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }} | 0x200 /* no_loop flag */ >; + +static_assert(Kernel_traits_nl::CTAS_PER_HEAD == 1, ""); + +using Kernel_traits_nl_causal = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ kv_loop_step }}, + {{ head_size }}, + {{ head_size_v }}, + {{ noloop_step }}, + 1, + {{ warps_m }} * {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }} | 0x200 /* no_loop flag */, + /*causal mask*/ 3>; + +static_assert(Kernel_traits_nl_causal::CTAS_PER_HEAD == 1, ""); +static_assert(Kernel_traits_nl_causal::MASK_VERSION == 3, ""); + +extern "C" +__global__ +void {{ kernel_name }}_nl({{ params_type }} params){ +{% if kernel_variant == "flash_attention" %} + fused_multihead_attention::device_flash_attention_nl(params); +{% else %} + fused_multihead_attention::device_1xN_nl(params); +{% endif %} +} + +extern "C" +__global__ +void {{ causal_kernel_name }}_nl({{ params_type }} params){ +{% if kernel_variant == "flash_attention" %} + fused_multihead_attention::device_flash_attention_nl(params); +{% else %} + fused_multihead_attention::device_1xN_nl(params); +{% endif %} +} + +void {{ launcher_name }}_nl( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + + // Runtime loop_iters: flash attention supports variable sequence lengths. 
+ int loop_iters = ( params.s + {{ noloop_step }} - 1 ) / {{ noloop_step }}; + dim3 grid(loop_iters, params.h, params.b); // better locality + if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { + constexpr int smem_size = Kernel_traits_nl_causal::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}_nl, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ causal_kernel_name }}_nl<<>>(params); + } else { + constexpr int smem_size = Kernel_traits_nl::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}_nl, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ kernel_name }}_nl<<>>(params); + } +} + +{% endif %} // !use_multi_cta && has_noloop && !cross_mha + +{% if not use_multi_cta and tiled and not cross_mha %} // !use_multi_cta && tiled && !cross_mha +{% if reload_q %} // reload_q +// When D > CTA_P_TILE_K (RELOAD_Q=true), use the tiled noloop kernel which iterates +// over the K dimension in BMM1. This kernel requires ldgsts (cp.async) and +// WARPS_N=1 (no shared memory for softmax), so we define separate Kernel_traits +// with WARPS_M={{ warps_m }}, WARPS_N={{ warps_n }} instead of the nl layout (1, total). 
+ +using Kernel_traits_nl_tiled = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ kv_loop_step }}, + {{ head_size }}, + {{ head_size_v }}, + {{ noloop_step }}, + {{ warps_m }}, + {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }} | 0x200 /* no_loop flag */ >; + +static_assert(Kernel_traits_nl_tiled::CTAS_PER_HEAD == 1, ""); + +using Kernel_traits_nl_tiled_causal = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ kv_loop_step }}, + {{ head_size }}, + {{ head_size_v }}, + {{ noloop_step }}, + {{ warps_m }}, + {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }} | 0x200 /* no_loop flag */, + /*causal mask*/ 3>; + +static_assert(Kernel_traits_nl_tiled_causal::CTAS_PER_HEAD == 1, ""); +static_assert(Kernel_traits_nl_tiled_causal::MASK_VERSION == 3, ""); + +{% else %} // !reload_q +// Tiled (granular) variant: uses the same Kernel_traits_nl with USE_GRANULAR_TILING +// in the flags for correct tile decomposition (CTA_O_TILE_K computed from K_PER_MMA +// instead of S, which is 0 for flash attention). Routes to the regular noloop kernel +// which supports granular tiling without requiring ldgsts. 
+{% endif %} + +extern "C" +__global__ +void {{ kernel_name }}_nl_tiled({{ params_type }} params){ +{% if kernel_variant == "flash_attention" %} +{% if reload_q %} + fused_multihead_attention::device_flash_attention_nl_tiled(params); +{% else %} + fused_multihead_attention::device_flash_attention_nl(params); +{% endif %} +{% else %} + fused_multihead_attention::device_1xN_nl(params); +{% endif %} +} + +extern "C" +__global__ +void {{ causal_kernel_name }}_nl_tiled({{ params_type }} params){ +{% if kernel_variant == "flash_attention" %} +{% if reload_q %} + fused_multihead_attention::device_flash_attention_nl_tiled(params); +{% else %} + fused_multihead_attention::device_flash_attention_nl(params); +{% endif %} +{% else %} + fused_multihead_attention::device_1xN_nl(params); +{% endif %} +} + +void {{ launcher_name }}_nl_tiled( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + + // Runtime loop_iters: flash attention supports variable sequence lengths. 
+ int loop_iters = ( params.s + {{ noloop_step }} - 1 ) / {{ noloop_step }}; + dim3 grid(loop_iters, params.h, params.b); // better locality +{% if reload_q %} + if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { + constexpr int smem_size = Kernel_traits_nl_tiled_causal::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}_nl_tiled, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ causal_kernel_name }}_nl_tiled<<>>(params); + } else { + constexpr int smem_size = Kernel_traits_nl_tiled::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}_nl_tiled, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ kernel_name }}_nl_tiled<<>>(params); + } +{% else %} + if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { + constexpr int smem_size = Kernel_traits_nl_causal::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}_nl_tiled, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ causal_kernel_name }}_nl_tiled<<>>(params); + } else { + constexpr int smem_size = Kernel_traits_nl::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}_nl_tiled, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ kernel_name }}_nl_tiled<<>>(params); + } +{% endif %} +} + +{% endif %} // !use_multi_cta && tiled && !cross_mha + +{% if cross_mha %} // cross_mha +{% if not use_multi_cta and has_noloop %} // !use_multi_cta && has_noloop + +using Kernel_traits_nl = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ seq_len }}, + {{ head_size }}, + {{ head_size_v }}, + {{ noloop_step }}, + 1, + {{ warps_m }} * {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }}>; + +static_assert(Kernel_traits_nl::CTAS_PER_HEAD == 1, ""); + +extern "C" +__global__ 
+void {{ kernel_name }}_nl({{ params_type }} params){ + fused_multihead_attention::device_mhca_1xN_nl(params); +} + +void {{ launcher_name }}_nl( + const {{ params_type }} ¶ms, + // const Launch_params &launch_params, // TODO + cudaStream_t stream){ + + constexpr int smem_size = Kernel_traits_nl::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}_nl, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + const int loop_iters = (params.s_q + {{ noloop_step }}-1) / {{ noloop_step }}; + // if (loop_iters * {{ noloop_step }} != params.s_q) { + // throw std::runtime_error("Incorrect seq len -- loop_iters * noloop_step != params.s_q"); + // } + assert(loop_iters * {{ noloop_step }} >= params.s_q); + dim3 grid(params.h, params.b, loop_iters); + {{ kernel_name }}_nl<<>>(params); +} + +{% endif %} // !use_multi_cta && has_noloop + +{% if not use_multi_cta %} // !use_multi_cta + +using Kernel_traits = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ seq_len }}, + {{ head_size }}, + {{ head_size_v }}, + {{ loop_step }}, + {{ warps_m }}, + {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }}>; + +extern "C" +__global__ +void {{ kernel_name }}({{ params_type }} params){ + fused_multihead_attention::device_mhca_1xN(params); +} + +void {{ launcher_name }}( + const {{ params_type }} ¶ms, + // const Launch_params &launch_params, // TODO + cudaStream_t stream){ + + constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + dim3 grid(params.h, params.b); + {{ kernel_name }}<<>>(params); +} + +{% endif %} // !use_multi_cta + +{% endif %} // cross_mha + +{% if use_multi_cta %} // use_multi_cta + +// If that assert gets triggered - increase the value of MAX_STGS_PER_LOOP in "setup.py". 
+static_assert(Kernel_traits::Gmem_tile_o::STGS_PER_LOOP <= {{ MAX_STGS_PER_LOOP }}, ""); + +extern "C" +__global__ +void {{ kernel_name }}({{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}_multi_cta(params); +} + +extern "C" +__global__ +void {{ causal_kernel_name }}({{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}_multi_cta(params); +} + +void {{ launcher_name }}( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + + assert(params.heads_per_wave != 0 && "Heads per wave is not set, but multi cta is requested"); + + // Clear the barriers and locks. + cudaMemsetAsync(params.counters, 0, 3*params.heads_per_wave*sizeof(int), stream); + + // We may use more than 48kB of shared memory. + if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { + constexpr int smem_size = Kernel_traits_causal::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + // Launch one wave. + dim3 grid(Kernel_traits_causal::CTAS_PER_HEAD, params.heads_per_wave), block(Kernel_traits_causal::THREADS); + void *params_ = (void*) ¶ms; + FMHA_CHECK_CUDA(cudaLaunchCooperativeKernel((void*) &{{ causal_kernel_name }}, grid, block, (void**) ¶ms_, smem_size, stream)); + } else { + constexpr size_t smem_size = Kernel_traits::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + // Launch one wave. 
+ dim3 grid(Kernel_traits::CTAS_PER_HEAD, params.heads_per_wave), block(Kernel_traits::THREADS); + void *params_ = (void*) ¶ms; + FMHA_CHECK_CUDA(cudaLaunchCooperativeKernel((void*) &{{ kernel_name }}, grid, block, (void**) ¶ms_, smem_size, stream)); + } +} + +{% endif %} // use_multi_cta + +void {{ launcher_name }}_get_max_heads_per_wave(int *heads_per_wave) { +{% if use_multi_cta %} // use_multi_cta + // Determine the number of SMs and CTAs. + int dev; + cudaGetDevice(&dev); + cudaDeviceProp props; + FMHA_CHECK_CUDA(cudaGetDeviceProperties(&props, dev)); + + // The number of CTAs per SM. + constexpr size_t smem_size = Kernel_traits::BYTES_PER_SMEM; + int ctas_per_sm; + FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, + &{{ kernel_name }}, + Kernel_traits::THREADS, + smem_size)); + + // The number of heads per wave. + *heads_per_wave = props.multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_HEAD; +{% else %} // use_multi_cta + *heads_per_wave = 0; +{% endif %} // use_multi_cta +} + +#else // CUDA_VERSION >= {{ min_cuda_version }} + +void {{ launcher_name }}( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + assert(false && "Unsupported CUDA version"); +} + +{% if has_noloop %} // has_noloop + +void {{ launcher_name }}_nl( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + assert(false && "Unsupported CUDA version"); +} + +{% endif %} // has_noloop + +#endif // CUDA_VERSION >= {{ min_cuda_version }} diff --git a/csrc/fmha_v2/templates/kernel.jinja b/csrc/fmha_v2/templates/kernel.jinja new file mode 100644 index 0000000000..f0a29e9923 --- /dev/null +++ b/csrc/fmha_v2/templates/kernel.jinja @@ -0,0 +1,360 @@ +{{ copyright }} + +//We can disable the FADD trick for archs with F2IP +{% if disable_fadd_trick %} // disable_fadd_trick +#ifdef USE_I2F_EMULATION_TRICK +#undef USE_I2F_EMULATION_TRICK +#endif // USE_I2F_EMULATION_TRICK + +#ifdef 
USE_F2I_EMULATION_TRICK +#undef USE_F2I_EMULATION_TRICK +#endif // USE_F2I_EMULATION_TRICK +{% endif %} // disable_fadd_trick + +#include +#include + +#if CUDA_VERSION >= {{ min_cuda_version }} + + +{% if not use_multi_cta %} // !use_multi_cta +#include +{% endif %} // !use_multi_cta + +{% if not use_multi_cta and has_noloop %} // !use_multi_cta && has_noloop +#include +{% endif %} // !use_multi_cta && has_noloop + +{% if cross_mha %} // cross_mha +{% if has_noloop %} // has_noloop +#include +{% endif %} // has_noloop +#include +{% endif %} // cross_mha + +{% if use_multi_cta %} // use_multi_cta +#include +{% endif %} + +using Attention_mask_type = fmha::Attention_mask_type; +using Launch_params = bert::Fused_multihead_attention_launch_params; + +{% if not cross_mha %} // !cross_mha +using Kernel_traits = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ seq_len }}, + {{ head_size }}, + {{ head_size_v }}, + {{ loop_step }}, + {{ warps_m }}, + {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }}>; + +using Kernel_traits_causal = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ seq_len }}, + {{ head_size }}, + {{ head_size_v }}, + {{ loop_step }}, + {{ warps_m }}, + {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }}, + /*causal mask*/ 3>; +{% endif %} // Not cross attention + +{% if not use_multi_cta and not cross_mha %} // !use_multi_cta && !cross_mha + +extern "C" +__global__ +void {{ kernel_name }}({{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}(params); +} + +extern "C" +__global__ +void {{ causal_kernel_name }}({{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}(params); +} + +void {{ launcher_name }}( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + + constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM; + if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { + if( smem_size >= 
48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + dim3 grid(params.h, params.b); + {{ causal_kernel_name }}<<>>(params); + } else { + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + dim3 grid(params.h, params.b); + {{ kernel_name }}<<>>(params); + } +} + +{% endif %} // !use_multi_cta && !cross_mha + +{% if not use_multi_cta and has_noloop and not cross_mha %} // !use_multi_cta && has_noloop && !cross_mha + +using Kernel_traits_nl = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ seq_len }}, + {{ head_size }}, + {{ head_size_v }}, + {{ noloop_step }}, + 1, + {{ warps_m }} * {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }} | 0x200 /* no_loop flag */ >; + +static_assert(Kernel_traits_nl::CTAS_PER_HEAD == 1, ""); + +using Kernel_traits_nl_causal = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ seq_len }}, + {{ head_size }}, + {{ head_size_v }}, + {{ noloop_step }}, + 1, + {{ warps_m }} * {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }} | 0x200 /* no_loop flag */, + /*causal mask*/ 3>; + +static_assert(Kernel_traits_nl_causal::CTAS_PER_HEAD == 1, ""); +static_assert(Kernel_traits_nl_causal::MASK_VERSION == 3, ""); + +extern "C" +__global__ +void {{ kernel_name }}_nl({{ params_type }} params){ + fused_multihead_attention::device_1xN_nl(params); +} + +extern "C" +__global__ +void {{ causal_kernel_name }}_nl({{ params_type }} params){ + fused_multihead_attention::device_1xN_nl(params); +} + +void {{ launcher_name }}_nl( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + + + constexpr int loop_iters = ({{ seq_len }} + {{ noloop_step }}-1) / {{ noloop_step }}; + static_assert(loop_iters * {{ noloop_step }} >= {{ seq_len }}, ""); + dim3 grid(params.h, params.b, loop_iters); + 
if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { + constexpr int smem_size = Kernel_traits_nl_causal::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}_nl, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ causal_kernel_name }}_nl<<>>(params); + } else { + constexpr int smem_size = Kernel_traits_nl::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}_nl, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ kernel_name }}_nl<<>>(params); + } +} + +{% endif %} // !use_multi_cta && has_noloop && !cross_mha + +{% if cross_mha %} // cross_mha +{% if not use_multi_cta and has_noloop %} // !use_multi_cta && has_noloop + +using Kernel_traits_nl = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ seq_len }}, + {{ head_size }}, + {{ head_size_v }}, + {{ noloop_step }}, + 1, + {{ warps_m }} * {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }}>; + +static_assert(Kernel_traits_nl::CTAS_PER_HEAD == 1, ""); + +extern "C" +__global__ +void {{ kernel_name }}_nl({{ params_type }} params){ + fused_multihead_attention::device_mhca_1xN_nl(params); +} + +void {{ launcher_name }}_nl( + const {{ params_type }} ¶ms, + // const Launch_params &launch_params, // TODO + cudaStream_t stream){ + + constexpr int smem_size = Kernel_traits_nl::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}_nl, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + const int loop_iters = (params.s_q + {{ noloop_step }}-1) / {{ noloop_step }}; + // if (loop_iters * {{ noloop_step }} != params.s_q) { + // throw std::runtime_error("Incorrect seq len -- loop_iters * noloop_step != params.s_q"); + // } + assert(loop_iters * {{ noloop_step }} >= params.s_q); + dim3 grid(params.h, params.b, loop_iters); + {{ kernel_name }}_nl<<>>(params); +} + +{% endif %} 
// !use_multi_cta && has_noloop + +{% if not use_multi_cta %} // !use_multi_cta + +using Kernel_traits = fmha::{{ kernel_traits }}< + fmha::{{ instruction_traits }}, + {{ seq_len }}, + {{ head_size }}, + {{ head_size_v }}, + {{ loop_step }}, + {{ warps_m }}, + {{ warps_n }}, + {{ ctas_per_head }}, + {{ kernel_flags }}>; + +extern "C" +__global__ +void {{ kernel_name }}({{ params_type }} params){ + fused_multihead_attention::device_mhca_1xN(params); +} + +void {{ launcher_name }}( + const {{ params_type }} ¶ms, + // const Launch_params &launch_params, // TODO + cudaStream_t stream){ + + constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + dim3 grid(params.h, params.b); + {{ kernel_name }}<<>>(params); +} + +{% endif %} // !use_multi_cta + +{% endif %} // cross_mha + +{% if use_multi_cta %} // use_multi_cta + +// If that assert gets triggered - increase the value of MAX_STGS_PER_LOOP in "setup.py". +static_assert(Kernel_traits::Gmem_tile_o::STGS_PER_LOOP <= {{ MAX_STGS_PER_LOOP }}, ""); + +extern "C" +__global__ +void {{ kernel_name }}({{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}_multi_cta(params); +} + +extern "C" +__global__ +void {{ causal_kernel_name }}({{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}_multi_cta(params); +} + +void {{ launcher_name }}( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + + assert(params.heads_per_wave != 0 && "Heads per wave is not set, but multi cta is requested"); + + // Clear the barriers and locks. + cudaMemsetAsync(params.counters, 0, 3*params.heads_per_wave*sizeof(int), stream); + + // We may use more than 48kB of shared memory. 
+ if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { + constexpr int smem_size = Kernel_traits_causal::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + // Launch one wave. + dim3 grid(Kernel_traits_causal::CTAS_PER_HEAD, params.heads_per_wave), block(Kernel_traits_causal::THREADS); + void *params_ = (void*) ¶ms; + FMHA_CHECK_CUDA(cudaLaunchCooperativeKernel((void*) &{{ causal_kernel_name }}, grid, block, (void**) ¶ms_, smem_size, stream)); + } else { + constexpr size_t smem_size = Kernel_traits::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + // Launch one wave. + dim3 grid(Kernel_traits::CTAS_PER_HEAD, params.heads_per_wave), block(Kernel_traits::THREADS); + void *params_ = (void*) ¶ms; + FMHA_CHECK_CUDA(cudaLaunchCooperativeKernel((void*) &{{ kernel_name }}, grid, block, (void**) ¶ms_, smem_size, stream)); + } +} + +{% endif %} // use_multi_cta + +void {{ launcher_name }}_get_max_heads_per_wave(int *heads_per_wave) { +{% if use_multi_cta %} // use_multi_cta + // Determine the number of SMs and CTAs. + int dev; + cudaGetDevice(&dev); + cudaDeviceProp props; + FMHA_CHECK_CUDA(cudaGetDeviceProperties(&props, dev)); + + // The number of CTAs per SM. + constexpr size_t smem_size = Kernel_traits::BYTES_PER_SMEM; + int ctas_per_sm; + FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, + &{{ kernel_name }}, + Kernel_traits::THREADS, + smem_size)); + + // The number of heads per wave. 
+ *heads_per_wave = props.multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_HEAD; +{% else %} // use_multi_cta + *heads_per_wave = 0; +{% endif %} // use_multi_cta +} + +#else // CUDA_VERSION >= {{ min_cuda_version }} + +void {{ launcher_name }}( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + assert(false && "Unsupported CUDA version"); +} + +{% if has_noloop %} // has_noloop + +void {{ launcher_name }}_nl( + const {{ params_type }} ¶ms, + const Launch_params &launch_params, + cudaStream_t stream){ + assert(false && "Unsupported CUDA version"); +} + +{% endif %} // has_noloop + +#endif // CUDA_VERSION >= {{ min_cuda_version }} diff --git a/csrc/fmha_v2/templates/kernel_hopper.jinja b/csrc/fmha_v2/templates/kernel_hopper.jinja new file mode 100644 index 0000000000..b02434adfb --- /dev/null +++ b/csrc/fmha_v2/templates/kernel_hopper.jinja @@ -0,0 +1,379 @@ +{{ copyright }} + +//We can disable the FADD trick for archs with F2IP +{% if disable_fadd_trick %} +#ifdef USE_I2F_EMULATION_TRICK +#undef USE_I2F_EMULATION_TRICK +#endif + +#ifdef USE_F2I_EMULATION_TRICK +#undef USE_F2I_EMULATION_TRICK +#endif +{% endif %} + +#include + +#if CUDA_VERSION >= {{ min_cuda_version }} + +#include +{% if has_noloop %} +#include +{% endif %} + +{% if use_tma %} +// only included if tma is used. 
+#include +{% endif %} //use_tma + +{{ include_str }} +{{ local_ns_open }} +{{ bert_launch_params }} +{{ attn_mask_type_str }} + +using Traits_p = fmha::{{ instruction_traits_p }}; +using Traits_o = fmha::{{ instruction_traits_o }}; + +using Kernel_traits = {{ kernel_traits }}< + Traits_p, + Traits_o, + {{ seq_len }}, + {{ head_size }}, + {{ loop_step }}, + {{ warps_m }}, + {{ warps_n }}, + 2, + {{ kernel_flags }}>; + +using Kernel_traits_causal = {{ kernel_traits }}< + Traits_p, + Traits_o, + {{ seq_len }}, + {{ head_size }}, + {{ loop_step }}, + {{ warps_m }}, + {{ warps_n }}, + 3, + {{ kernel_flags }}>; + +using Kernel_traits_sliding_or_chunked_causal = {{ kernel_traits }}< + Traits_p, + Traits_o, + {{ seq_len }}, + {{ head_size }}, + {{ loop_step }}, + {{ warps_m }}, + {{ warps_n }}, + 4, + {{ kernel_flags }}>; + +{% if use_tma %} // use_tma + +{% if padding_mask %} // padding_mask + +extern "C" +__global__ +void {{ kernel_name }}(const __grid_constant__ {{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}_tma(params); +} + +{% endif %} // padding_mask + +{% if causal_mask %} // causal_mask + +extern "C" +__global__ +void {{ causal_kernel_name }}(const __grid_constant__ {{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}_tma(params); +} + +{% endif %} // causal mask + +{% if sliding_or_chunked_causal_mask %} // sliding_or_chunked_causal_mask + +extern "C" +__global__ +void {{ sliding_or_chunked_causal_kernel_name }}(const __grid_constant__ {{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}_tma(params); +} + +{% endif %} // sliding_or_chunked_causal_mask + +{% else %} + +{% if padding_mask %} + +extern "C" +__global__ +void {{ kernel_name }}(const __grid_constant__ {{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}(params); +} + +{% endif %} // padding_mask + +{% if causal_mask %} // causal_mask + +extern "C" +__global__ +void {{ 
causal_kernel_name }}(const __grid_constant__ {{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}(params); +} + +{% endif %} // causal mask + +{% if sliding_or_chunked_causal_mask %} // sliding_or_chunked_causal_mask + +extern "C" +__global__ +void {{ sliding_or_chunked_causal_kernel_name }}(const __grid_constant__ {{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}(params); +} +{% endif %} + +{% endif %} // sliding_or_chunked_causal_mask + +void {{ launcher_name }}({{ fused_multihead_attention_params_v2_str }} ¶ms, + const Launch_params &launch_params, cudaStream_t stream){ + // setting TMA descriptors if needed. + // use_tma = {{ use_tma }} +{% if use_tma %} + // declare TMA desc for Q, K, V + typename fmha::Multiple_tma_descriptor<3> tma_desc_QKV; + + // GMEM pointers, the offset between each batch is d*3*h*seqlen + // qkv pointer + char *qkv_ptr = reinterpret_cast(params.qkv_ptr); + + // tensor size + uint32_t tensor_size_qkv[3]; + tensor_size_qkv[2] = 1; + tensor_size_qkv[1] = params.is_s_padded ? params.s * params.b : launch_params.seqlens[params.b]; + tensor_size_qkv[0] = (params.h + 2 * params.h_kv) * params.d; + + // box size for Q + uint32_t box_size_q[3]; + box_size_q[2] = 1; + box_size_q[1] = {{ loop_step }}; // STEP size + box_size_q[0] = {{ head_size }}; // head_size + + // box size for k and v + uint32_t box_size_kv[3]; + box_size_kv[2] = 1; + box_size_kv[1] = params.s; // S, should not be actual_s, OOB will be filled with zeros. 
+ box_size_kv[0] = {{ head_size }}; // head_size + + // stride size + uint64_t tensor_stride_qkv[2]; + tensor_stride_qkv[0] = tensor_size_qkv[0] * Traits_p::BITS_PER_ELEMENT_A / 8; + tensor_stride_qkv[1] = tensor_size_qkv[1] * tensor_stride_qkv[0]; + + // traversal stride + uint32_t traversal_stride_qkv[3] = {1, 1, 1}; + + // OOB fill zeros + uint32_t oob_fill = 0; + + // FP32 to TF32 conversion disabled + uint32_t fp32_to_tf32 = 0; + + //setup the descriptors + + //setup the descriptor for Q + tma_desc_QKV.set_tma_desctriptor(reinterpret_cast(qkv_ptr), + fmha::cudaTmaDescFormat::F16_RN, // tma format (data type). For now hardcode to fp16 + fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, + fmha::cudaTmaDescSwizzle::SWIZZLE_128B, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, + tensor_size_qkv, + tensor_stride_qkv, + traversal_stride_qkv, + box_size_q, + oob_fill, + fp32_to_tf32, + ¶ms.tma_desc_q); + + // setup the descriptor for K + tma_desc_QKV.set_tma_desctriptor(reinterpret_cast(qkv_ptr), + fmha::cudaTmaDescFormat::F16_RN, // tma format (data type). For now hardcode to fp16 + fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, + fmha::cudaTmaDescSwizzle::SWIZZLE_128B, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, + tensor_size_qkv, + tensor_stride_qkv, + traversal_stride_qkv, + box_size_kv, + oob_fill, + fp32_to_tf32, + ¶ms.tma_desc_k); + + // setup the descriptor for V + tma_desc_QKV.set_tma_desctriptor(reinterpret_cast(qkv_ptr), + fmha::cudaTmaDescFormat::F16_RN, // tma format (data type). For now hardcode to fp16 + fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, + fmha::cudaTmaDescSwizzle::SWIZZLE_128B, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, + tensor_size_qkv, + tensor_stride_qkv, + traversal_stride_qkv, + box_size_kv, + oob_fill, + fp32_to_tf32, + ¶ms.tma_desc_v); + + +{% endif %} // use_tma + dim3 grid(params.h, params.b); + // Use the same smem_size for all traits. 
+ constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM; + if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { +{% if causal_mask %} // causal_mask + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ causal_kernel_name }}<<>>({{ params_str }}); +{% endif %} // causal mask + } else if( launch_params.attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL ) { +{% if sliding_or_chunked_causal_mask %} // sliding_or_chunked_causal_mask + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ sliding_or_chunked_causal_kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ sliding_or_chunked_causal_kernel_name }}<<>>({{ params_str }}); +{% endif %} // sliding_or_chunked_causal_mask + } else { +{% if padding_mask %} // padding_mask + constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM; + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ kernel_name }}<<>>({{ params_str }}); +{% endif %} // padding_mask + } +} + +{% if has_noloop %} + + +using Kernel_traits_nl = {{ kernel_traits }}< + Traits_p, + Traits_o, + {{ seq_len }}, + {{ head_size }}, + {{ noloop_step }}, + {{ warps_m }}, + {{ warps_n }}, + 2, + {{ kernel_flags }}>; + +using Kernel_traits_causal_nl = {{ kernel_traits }}< + Traits_p, + Traits_o, + {{ seq_len }}, + {{ head_size }}, + {{ noloop_step }}, + {{ warps_m }}, + {{ warps_n }}, + 3, + {{ kernel_flags }}>; + +using Kernel_traits_sliding_or_chunked_causal_nl = {{ kernel_traits }}< + Traits_p, + Traits_o, + {{ seq_len }}, + {{ head_size }}, + {{ noloop_step }}, + {{ warps_m }}, + {{ warps_n }}, + 4, + {{ kernel_flags }}>; + +{% if padding_mask %} // padding_mask + +extern "C" +__global__ +void {{ kernel_name }}_nl({{ params_type }} params){ + 
fused_multihead_attention::device_{{ kernel_variant }}_nl(params); +} + +{% endif %} // padding_mask + +{% if causal_mask %} // causal_mask + +extern "C" +__global__ +void {{ causal_kernel_name }}_nl({{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}_nl(params); +} + +{% endif %} // causal mask + +{% if sliding_or_chunked_causal_mask %} // sliding_or_chunked_causal_mask + +extern "C" +__global__ +void {{ sliding_or_chunked_causal_kernel_name }}_nl({{ params_type }} params){ + fused_multihead_attention::device_{{ kernel_variant }}_nl(params); +} + +{% endif %} // sliding_or_chunked_causal_mask + +void {{ launcher_name }}_nl({{ fused_multihead_attention_params_v2_str }} ¶ms, + const Launch_params& launch_params, cudaStream_t stream){ + constexpr int loop_iters = {{ seq_len }} / {{ noloop_step }}; + static_assert(loop_iters * {{ noloop_step }} == {{ seq_len }}, ""); + dim3 grid(params.h, params.b, loop_iters); + + // Use the same smem_size for all traits. + constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM; + if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { +{% if causal_mask %} // causal_mask + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}_nl, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ causal_kernel_name }}_nl<<>>({{ params_str }}); +{% endif %} // causal mask + } else if( launch_params.attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL ) { +{% if sliding_or_chunked_causal_mask %} // sliding_or_chunked_causal_mask + if( smem_size >= 48*1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ sliding_or_chunked_causal_kernel_name }}_nl, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ sliding_or_chunked_causal_kernel_name }}_nl<<>>({{ params_str }}); +{% endif %} // sliding_or_chunked_causal_mask + } else { +{% if padding_mask %} // padding_mask + if( smem_size >= 48*1024 ) { + 
FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}_nl, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } + {{ kernel_name }}_nl<<>>({{ params_str }}); +{% endif %} // padding_mask + } +} + +{% endif %} + +#else + +void {{ launcher_name }}(const {{ params_type }} ¶ms, cudaStream_t stream){ + assert(false && "Unsupported CUDA version"); +} + +{% if has_noloop %} + +void {{ launcher_name }}_nl(const {{ params_type }} ¶ms, cudaStream_t stream){ + assert(false && "Unsupported CUDA version"); +} + +{% endif %} + +#endif +{{ local_ns_close }} diff --git a/csrc/fmha_v2/templates/kernel_hopper_ws.jinja b/csrc/fmha_v2/templates/kernel_hopper_ws.jinja new file mode 100644 index 0000000000..4f7d532bc7 --- /dev/null +++ b/csrc/fmha_v2/templates/kernel_hopper_ws.jinja @@ -0,0 +1,402 @@ +{{ copyright }} + +#include +#include +#include +#include +#include + +#include +#include +#include + +{{ include_str }} +//////////////////////////////////////////////////////////////////////////////////////////////////// +{{ local_ns_open }} +#if CUDA_VERSION >= {{ min_cuda_version }} + +static constexpr int DMA2COMPUTE_DEPTH = 1; +{{ num_compute_groups_str }} +static constexpr bool USE_TMA_STORE = {{ use_tma_store_flag }}; + +{{ bert_launch_params }} +{{ attn_mask_type_str }} + +using Ktraits = {{ kernel_traits_header }} + {{ loop_step }}, + {{ kv_loop_step }}, + {{ head_size }}, + {{ head_size_v }}, + {{ q_tile_buffers }}, + {{ kv_tile_buffers }}, + NUM_COMPUTE_GROUPS, + DMA2COMPUTE_DEPTH, + 0, + {{ heads_interleaved_flag }}, + false, + {{ enable_mutex_flag }}, + {{ scheduling_mode }}, + {{ input_layout_flag }}, + USE_TMA_STORE, + {{ enable_attn_logit_softcapping_flag }}, + {{ return_softmax_stats_flag }}, + {{ enable_skip_softmax_flag }}, + {{ output_dtype_ }}, + {{ sage_block_size_q }}, + {{ sage_block_size_k }}, + {{ sage_block_size_v }}>; + +using Ktraits_causal = {{ kernel_traits_header }} + {{ loop_step }}, + {{ kv_loop_step }}, + {{ head_size }}, + {{ head_size_v 
}}, + {{ q_tile_buffers }}, + {{ kv_tile_buffers }}, + NUM_COMPUTE_GROUPS, + DMA2COMPUTE_DEPTH, + 1, + {{ heads_interleaved_flag }}, + {{ has_alibi }}, + {{ enable_mutex_flag }}, + {{ scheduling_mode }}, + {{ input_layout_flag }}, + USE_TMA_STORE, + {{ enable_attn_logit_softcapping_flag }}, + {{ return_softmax_stats_flag }}, + {{ enable_skip_softmax_flag }}, + {{ output_dtype_ }}>; + +using Ktraits_sliding_or_chunked_causal = {{ kernel_traits_header }} + {{ loop_step }}, + {{ kv_loop_step }}, + {{ head_size }}, + {{ head_size_v }}, + {{ q_tile_buffers }}, + {{ kv_tile_buffers }}, + NUM_COMPUTE_GROUPS, + DMA2COMPUTE_DEPTH, + 2, + {{ heads_interleaved_flag }}, + {{ has_alibi }}, + {{ enable_mutex_flag }}, + {{ scheduling_mode }}, + {{ input_layout_flag }}, + USE_TMA_STORE && false, + {{ enable_attn_logit_softcapping_flag }}, + {{ return_softmax_stats_flag }}, + {{ enable_skip_softmax_flag }}, + {{ output_dtype_ }}>; + +using Ktraits_custom_mask = {{ kernel_traits_header }} + {{ loop_step }}, + {{ kv_loop_step }}, + {{ head_size }}, + {{ head_size_v }}, + {{ q_tile_buffers }}, + {{ kv_tile_buffers }}, + NUM_COMPUTE_GROUPS, + DMA2COMPUTE_DEPTH, + 3, + {{ heads_interleaved_flag }}, + {{ has_alibi }}, + {{ enable_mutex_flag }}, + {{ scheduling_mode }}, + {{ input_layout_flag }}, + USE_TMA_STORE && false, + {{ enable_attn_logit_softcapping_flag }}, + {{ return_softmax_stats_flag }}, + {{ enable_skip_softmax_flag }}, + {{ output_dtype_ }}>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +{% if padding_mask %} // padding_mask + +using Shared = typename Ktraits::Shared; + +extern "C" +__global__ __launch_bounds__(Ktraits::THREADS, 1) +void {{ kernel_name }}( + const __grid_constant__ {{ params_type }} params){ + + extern __shared__ char smem_[]; + char *smem_aligned = fmha::align_1024(smem_); + + Shared *shared = reinterpret_cast(&smem_aligned[0]); + shared->init(threadIdx.x == 0); + __syncthreads(); + + // special 
trick to avoid wrap_sync (leads to illegal instruction) + int warp_group = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + int tidx = threadIdx.x % 128; + + if( warp_group == NUM_COMPUTE_GROUPS ) { // dma + sched + + {{ setmaxnreg_dma_str }} + uint32_t elect_one = tidx == 0; + + // Need all threads involved when the dam group needs to transpose the v tile explicltly. + if constexpr ( Ktraits::DMA_GROUP_TRANSPOSE_V ) { + fmha::ws::DMA::Device dma_device(elect_one); + dma_device.{{ run_fct_name }}(params, shared); + } else { + fmha::ws::DMA::Device dma_device(elect_one); + if( tidx < 32 ) { + dma_device.{{ run_fct_name }}(params, shared); + } + } + + } else { // math + + {{ setmaxnreg_compute_str }} + + fmha::ws::Compute compute; + compute.run(warp_group, tidx, shared, params); + } +} + +{% endif %} // padding mask + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +{% if causal_mask %} // causal_mask + +using Shared_causal = typename Ktraits_causal::Shared; + +extern "C" +__global__ __launch_bounds__(Ktraits_causal::THREADS, 1) +void {{ causal_kernel_name }}( + const __grid_constant__ {{ params_type }} params){ + + extern __shared__ char smem_[]; + char *smem_aligned = fmha::align_1024(smem_); + + Shared_causal *shared = reinterpret_cast(&smem_aligned[0]); + shared->init(threadIdx.x == 0); + __syncthreads(); + + // special trick to avoid wrap_sync (leads to illegal instruction) + int warp_group = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + int tidx = threadIdx.x % 128; + + if( warp_group == NUM_COMPUTE_GROUPS ) { // dma + sched + + {{ setmaxnreg_dma_str }} + uint32_t elect_one = tidx == 0; + + // Need all threads involved when the dam group needs to transpose the v tile explicltly. 
+ if constexpr ( Ktraits_causal::DMA_GROUP_TRANSPOSE_V ) { + fmha::ws::DMA::Device dma_device(elect_one); + dma_device.{{ run_fct_name }}(params, shared); + } else { + fmha::ws::DMA::Device dma_device(elect_one); + if( tidx < 32 ) { + dma_device.{{ run_fct_name }}(params, shared); + } + } + + } else { // math + + {{ setmaxnreg_compute_str }} + + fmha::ws::Compute compute; + compute.run(warp_group, tidx, shared, params); + } +} + +{% endif %} // causal mask + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +{% if sliding_or_chunked_causal_mask %} // sliding_or_chunked_causal_mask + +using Shared_sliding_or_chunked_causal = typename Ktraits_sliding_or_chunked_causal::Shared; + +extern "C" +__global__ __launch_bounds__(Ktraits_sliding_or_chunked_causal::THREADS, 1) +void {{ sliding_or_chunked_causal_kernel_name }}( + const __grid_constant__ {{ params_type }} params){ + + extern __shared__ char smem_[]; + char *smem_aligned = fmha::align_1024(smem_); + + Shared_sliding_or_chunked_causal *shared = + reinterpret_cast(&smem_aligned[0]); + shared->init(threadIdx.x == 0); + __syncthreads(); + + // special trick to avoid wrap_sync (leads to illegal instruction) + int warp_group = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + int tidx = threadIdx.x % 128; + + if( warp_group == NUM_COMPUTE_GROUPS ) { // dma + sched + + {{ setmaxnreg_dma_str }} + uint32_t elect_one = tidx == 0; + + // Need all threads involved when the dam group needs to transpose the v tile explicltly. 
+ if constexpr ( Ktraits_sliding_or_chunked_causal::DMA_GROUP_TRANSPOSE_V ) { + fmha::ws::DMA::Device dma_device(elect_one); + dma_device.{{ run_fct_name }}(params, shared); + } else { + fmha::ws::DMA::Device dma_device(elect_one); + if( tidx < 32 ) { + dma_device.{{ run_fct_name }}(params, shared); + } + } + + } else { // math + + {{ setmaxnreg_compute_str }} + + fmha::ws::Compute compute; + compute.run(warp_group, tidx, shared, params); + } +} + +{% endif %} // sliding_or_chunked_causal_mask + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +{% if custom_mask %} // custom_mask + +using Shared_custom_mask = typename Ktraits_custom_mask::Shared; + +extern "C" +__global__ __launch_bounds__(Ktraits_custom_mask::THREADS, 1) +void {{ custom_mask_kernel_name }}( + const __grid_constant__ {{ params_type }} params){ + + extern __shared__ char smem_[]; + char *smem_aligned = fmha::align_1024(smem_); + + Shared_custom_mask *shared = + reinterpret_cast(&smem_aligned[0]); + shared->init(threadIdx.x == 0); + __syncthreads(); + + // special trick to avoid wrap_sync (leads to illegal instruction) + int warp_group = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + int tidx = threadIdx.x % 128; + + if( warp_group == NUM_COMPUTE_GROUPS ) { // dma + sched + + {{ setmaxnreg_dma_str }} + uint32_t elect_one = tidx == 0; + + // Need all threads involved when the dam group needs to transpose the v tile explicltly. 
+ if constexpr ( Ktraits_custom_mask::DMA_GROUP_TRANSPOSE_V ) { + fmha::ws::DMA::Device dma_device(elect_one); + dma_device.{{ run_fct_name }}(params, shared); + } else { + fmha::ws::DMA::Device dma_device(elect_one); + if( tidx < 32 ) { + dma_device.{{ run_fct_name }}(params, shared); + } + } + + } else { // math + + {{ setmaxnreg_compute_str }} + + fmha::ws::Compute compute; + compute.run(warp_group, tidx, shared, params); + } +} + +{% endif %} // custom_mask + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void {{ launcher_name }}( + {{ fused_multihead_attention_params_v2_str }} ¶ms, + const Launch_params &launch_params, cudaStream_t stream){ + + {{ TMA_config }} + if( Ktraits::SCHEDULING_MODE > 0 ) { + FMHA_CHECK_CUDA(cudaMemsetAsync(params.tile_id_counter_ptr, 0, sizeof(uint32_t), stream)); + } + + dim3 block_size; + + if( Ktraits::SCHEDULING_MODE == 0 ) { + block_size.y = std::min(params.b * params.h, launch_params.multi_processor_count); + // distribute m steps to multiple blocks (fully utilize SMs) + // block.x = blocks that handle single head, block.y = blocks that handle different heads + size_t sms_per_head = (launch_params.multi_processor_count) / block_size.y; + // Take multiple compute groups into consideration. + size_t m_steps = size_t((params.s + {{ loop_step }} * NUM_COMPUTE_GROUPS - 1) / ({{ loop_step }} * NUM_COMPUTE_GROUPS)); + + // 2 * {{ bytes_per_elt }} stands for kv cache and {{ bytes_per_elt }} bytes per element. 
+ size_t size_in_bytes = block_size.y * params.s * params.d * 2 * {{ bytes_per_elt }}; + if( size_in_bytes <= launch_params.device_l2_cache_size ) { + // strategy 1: limit to only 1 wave + block_size.x = std::min(m_steps, sms_per_head); + } else { + // strategy 2: fully unroll the q loops (contiguous blocks handle all q loops) + block_size.x = m_steps; + } + params.num_tiles = params.b * params.h; + } else if( Ktraits::SCHEDULING_MODE == 1 ) { + // Get the max total M steps + // Take multiple compute groups into consideration. + size_t m_steps = size_t((params.s + {{ loop_step }} * NUM_COMPUTE_GROUPS - 1) / ({{ loop_step }} * NUM_COMPUTE_GROUPS)); + params.num_tiles_per_head = static_cast(m_steps); + params.num_tiles = static_cast(m_steps * params.b * params.h); + if (launch_params.attention_mask_type == Attention_mask_type::CAUSAL) { + // 2 * {{ bytes_per_elt }} stands for kv cache and {{ bytes_per_elt }} bytes per element. + size_t size_in_bytes = params.b * params.h * params.s * params.d * 2 * {{ bytes_per_elt }}; + params.use_balanced_scheduling = (size_in_bytes <= launch_params.device_l2_cache_size); + } + + block_size.x = 1; + block_size.y = std::min(static_cast(params.num_tiles), launch_params.multi_processor_count); + } else { + assert(false && "Invalid SCHEDULING_MODE"); + } + + // Reuse the same bytes_per_smem for launching kernels. 
+ constexpr int SMEM_BYTES = Ktraits::BYTES_PER_SMEM; + if( launch_params.attention_mask_type == Attention_mask_type::PADDING ) { +{% if padding_mask %} // padding_mask + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + SMEM_BYTES)); + + {{ kernel_name }} + <<>>({{ params_str }}); +{% endif %} // padding_mask + } else if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) { +{% if causal_mask %} // causal_mask + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ causal_kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + SMEM_BYTES)); + + {{ causal_kernel_name }} + <<>>({{ params_str }}); +{% endif %} // causal mask + } else if( launch_params.attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL ) { +{% if sliding_or_chunked_causal_mask %} // sliding_or_chunked_causal_mask + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ sliding_or_chunked_causal_kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + SMEM_BYTES)); + + {{ sliding_or_chunked_causal_kernel_name }} + <<>>({{ params_str }}); +{% endif %} // sliding_or_chunked_causal_mask + } else if( launch_params.attention_mask_type == Attention_mask_type::CUSTOM_MASK ) { +{% if custom_mask %} // custom_mask + FMHA_CHECK_CUDA(cudaFuncSetAttribute({{ custom_mask_kernel_name }}, + cudaFuncAttributeMaxDynamicSharedMemorySize, + SMEM_BYTES)); + + {{ custom_mask_kernel_name }} + <<>>({{ params_str }}); +{% endif %} // custom mask + } + +} + +#endif +{{ local_ns_close }} diff --git a/csrc/fmha_v2_jit_binding.cu b/csrc/fmha_v2_jit_binding.cu new file mode 100644 index 0000000000..f0eb500eec --- /dev/null +++ b/csrc/fmha_v2_jit_binding.cu @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// FMHAv2 JIT Binding +// This file exports the fmha_v2_run function via TVM FFI + +#include + +#include "tvm_ffi_utils.h" + +using tvm::ffi::Optional; +using Attention_input_layout = fmha::Attention_input_layout; + +void fmha_v2_run(ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, ffi::TensorView o, + ffi::TensorView workspace_buffer, size_t workspace_buffer_size_in_bytes, + Optional maybe_block_tables, int page_size, + ffi::TensorView seq_lens, ffi::TensorView cum_seq_lens_q, + ffi::TensorView cum_seq_lens_kv, const std::string& input_layout_str, + int max_q_len, int max_kv_len, int batch_size, const std::string& mask_mode_str, + float scale_softmax, float scale_bmm1, float scale_bmm2, int window_left, + int chunked_attention_size, bool has_alibi, float softcapping_scale, + float skip_softmax_threshold_scale_factor, ffi::TensorView scale_bmm2_d, + Optional softmax_stats, Optional sinks); + +// FMHAv2 attention operator +TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, fmha_v2_run); diff --git a/csrc/fmha_v2_run.cu b/csrc/fmha_v2_run.cu new file mode 100644 index 0000000000..3dfff9b967 --- /dev/null +++ b/csrc/fmha_v2_run.cu @@ -0,0 +1,655 @@ +/* + * Copyright (c) 2023-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fmha_v2_api.h" + +// #include "fmha_v2_dispatcher.h" +#include "tvm_ffi_utils.h" + +using tvm::ffi::Optional; +namespace ffi = tvm::ffi; + +using Launch_params = bert::Fused_multihead_attention_launch_params; +using Attention_mask_type = fmha::Attention_mask_type; +using Attention_input_layout = fmha::Attention_input_layout; +using Kv_block_array = fmha::Kv_block_array; +using AlignedAllocator = flashinfer::AlignedAllocator; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// set_params - copied exactly from fused_multihead_attention.cpp +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline void set_params( + bert::Fused_multihead_attention_params_v2& params, const Launch_params launch_params, + // types + Data_type data_type, Data_type acc_type, Data_type output_dtype, + // attention input layout + Attention_input_layout input_layout, + // sizes + const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv, + const size_t d, const size_t dv, const size_t total, const size_t num_grouped_heads, + const size_t sliding_window_size, const size_t chunked_attention_size, + // paged kv cache block size. + const size_t tokens_per_block, + // device pointers + void* qkv_packed_d, + // contiguous q. + void* q_d, + // separate k. + void* k_d, + // separate v. 
+ void* v_d, + // contiguous kv. + void* kv_d, + // start address of the paged kv pool. + void* paged_kv_pool_ptr, + // offsets for different blocks in terms of the start address. + int32_t* paged_block_offsets, + // mask input. + void* packed_mask_d, void* cu_mask_rows_d, + // attention sinks. + void* attention_sinks_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, + void* p_d, void* s_d, void* softmax_stats_d, void* scale_bmm2_d, + // scale factors + float const scale_bmm1, float const scale_softmax, float const scale_bmm2, + float const softcapping_scale_bmm1, + // flags + bool const use_int8_scale_max, bool const interleaved, bool const is_s_padded, + bool const has_alibi, float const skip_softmax_threshold_scale_factor) { + memset(¶ms, 0, sizeof(params)); + + params.o_ptr = o_packed_d; + params.o_stride_in_bytes = get_size_in_bytes(h * dv, output_dtype); + + if (interleaved) { + params.q_stride_in_bytes = total; + params.o_stride_in_bytes = total; + } + + if (input_layout == Attention_input_layout::PACKED_QKV) { + // For grouped- or multi-query attention (h denotes num_q_heads; h' denotes h_kv): + // qkv_layout = [b, s, [q_hd, k_h'd, v_h'd]] + // qkv_stride = (h+2*h')d * bytes_per_elt + // Otherwise: + // qkv_layout = [b, s, 3, h, d] or [b, s, h, 3, d] + // qkv_stride = 3hd * bytes_per_elt + params.qkv_ptr = qkv_packed_d; + params.q_stride_in_bytes = params.k_stride_in_bytes = params.v_stride_in_bytes = + get_size_in_bytes(h * d + h_kv * d + h_kv * dv, data_type); + } else { + // Layout [B, S, H, D]. + params.q_ptr = q_d; + params.q_stride_in_bytes = get_size_in_bytes(h * d, data_type); + + if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV) { + // Layout [B, S, 2, H, D]. 
+ params.kv_ptr = kv_d; + params.k_stride_in_bytes = params.v_stride_in_bytes = + get_size_in_bytes(h_kv * (d + dv), data_type); + } else if (input_layout == Attention_input_layout::Q_PAGED_KV) { + int max_blocks_per_sequence = (s_kv + tokens_per_block - 1) / tokens_per_block; + params.paged_kv_cache = + Kv_block_array(b, max_blocks_per_sequence, tokens_per_block, + get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type), + paged_kv_pool_ptr); + params.paged_kv_cache.mBlockOffsets = paged_block_offsets; + params.k_stride_in_bytes = get_size_in_bytes(tokens_per_block * d, data_type); + params.v_stride_in_bytes = get_size_in_bytes(tokens_per_block * dv, data_type); + } else if (input_layout == Attention_input_layout::SEPARATE_Q_K_V) { + // Layout [B, S, H_kv, D]. + params.k_ptr = k_d; + // Layout [B, S, H_kv, Dv]. + params.v_ptr = v_d; + params.k_stride_in_bytes = get_size_in_bytes(h_kv * d, data_type); + params.v_stride_in_bytes = get_size_in_bytes(h_kv * dv, data_type); + } + } + + // Packed mask. + params.packed_mask_ptr = packed_mask_d; + // The N dimension has to be aligned. + params.packed_mask_stride_in_bytes = + (align_to(int64_t(s_kv), int64_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT))) / 8; + + // Attention sinks. + params.attention_sinks = reinterpret_cast(attention_sinks_d); + +#if defined(STORE_P) + params.p_ptr = p_d; + params.p_stride_in_bytes = get_size_in_bytes(b * h * s_kv, acc_type); +#endif // defined(STORE_P) + +#if defined(STORE_S) + params.s_ptr = s_d; + params.s_stride_in_bytes = get_size_in_bytes(b * h * s_kv, data_type); +#endif // defined(STORE_S) + + params.softmax_stats_ptr = softmax_stats_d; + params.softmax_stats_stride_in_bytes = get_size_in_bytes(h * 2, DATA_TYPE_FP32); + + // Set the dimensions. 
+ params.b = b; + params.h = h; + params.s = s_q; + params.d = d; + params.dv = dv; + params.num_grouped_heads = num_grouped_heads; + params.sliding_window_size = sliding_window_size; + assert((chunked_attention_size == 0 || + (chunked_attention_size & (chunked_attention_size - 1)) == 0) && + "chunked_attention_size has to be a power of 2"); + params.log2_chunked_attention_size = + chunked_attention_size > 0 ? std::log2(chunked_attention_size) : 0; + + // cumulative q or kv sequence lengths. + params.cu_q_seqlens = static_cast(cu_q_seqlens_d); + params.cu_kv_seqlens = static_cast(cu_kv_seqlens_d); + // cumulative mask sequence lengths. + params.cu_mask_rows = static_cast(cu_mask_rows_d); + + // Set the different scale values. + Data_type scale_type1 = + (data_type == DATA_TYPE_FP16) || (data_type == DATA_TYPE_BF16) ? acc_type : DATA_TYPE_FP32; + Data_type scale_softmax_type = scale_type1; + Data_type scale_type2 = + (data_type == DATA_TYPE_FP16) || (data_type == DATA_TYPE_BF16) ? data_type : DATA_TYPE_FP32; + if (data_type == DATA_TYPE_E4M3) { + scale_type1 = acc_type; + scale_type2 = acc_type; + } + + // Fuse 1.0f / softcapping_scale into scale_bmm1. + bool const enable_attn_logit_softcapping = softcapping_scale_bmm1 != 0.f; + float fused_scale_bmm1 = + enable_attn_logit_softcapping ? scale_bmm1 / softcapping_scale_bmm1 : scale_bmm1; + + // use specialized hopper kernels without alibi support. + // alibi or softcapping_scale cannot utilize the exp2f with fused_scale optimization. + if (launch_params.warp_specialization && !has_alibi && !enable_attn_logit_softcapping) { + set_alpha(params.scale_bmm1, fused_scale_bmm1 * float(M_LOG2E), DATA_TYPE_FP32); + } else { + set_alpha(params.scale_bmm1, fused_scale_bmm1, scale_type1); + } + set_alpha(params.scale_softmax, scale_softmax, scale_softmax_type); + set_alpha(params.scale_bmm2, scale_bmm2, scale_type2); + // NOTE: scale_bmm2_d is now pre-populated from Python to avoid cudaMemcpy synchronization. 
+ // The Python side calls create_scale_bmm2_d_tensor() which replicates set_alpha logic. + params.scale_bmm2_d = reinterpret_cast(scale_bmm2_d); + params.softcapping_scale_bmm1 = softcapping_scale_bmm1; + + // attention type, h_kv < h if MQA or GQA + params.h_kv = h_kv; + assert(h % h_kv == 0 && "MQA/GQA needs h to be divisible by h_kv!"); + params.h_q_per_kv = h / h_kv; + params.has_alibi = has_alibi; + params.alibi_params = fmha::AlibiParams(h); + + // Set flags + params.is_s_padded = is_s_padded; + params.use_int8_scale_max = use_int8_scale_max; + + // Do we enable the trick to replace I2F with FP math in the 2nd GEMM? + if (data_type == DATA_TYPE_INT8) { + params.enable_i2f_trick = -double(1 << 22) * double(scale_bmm2) <= -128.f && + double(1 << 22) * double(scale_bmm2) >= 127.f; + } + + // Skip-softmax attention + params.skip_softmax_threshold_scale_factor = skip_softmax_threshold_scale_factor; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// determine_launch_params - copied exactly from fused_multihead_attention.cpp +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline void determine_launch_params( + Launch_params& launch_params, Data_type data_type, int sm, const size_t s, const size_t d, + const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout, + bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma, + bool const force_non_flash_attention, bool const force_non_warp_specialization, + bool const force_non_granular_tiling, bool const force_fp32_acc, + float const skip_softmax_threshold_scale_factor, + // device props + const cudaDeviceProp props) { + // Set launch params to choose kernels + launch_params.ignore_b1opt = ignore_b1opt; + launch_params.force_unroll = force_unroll; + launch_params.force_fp32_acc = force_fp32_acc; + launch_params.interleaved = 
interleaved;
+  launch_params.attention_mask_type = attention_mask_type;
+  launch_params.attention_input_layout = input_layout;
+
+  // Set SM count and L2 cache size (used to determine launch blocks/grids to maximum performance)
+  launch_params.multi_processor_count = props.multiProcessorCount;
+  launch_params.device_l2_cache_size = props.l2CacheSize;
+
+  // threshold for adopting flash attention or warp_specialized kernels.
+  launch_params.flash_attention =
+      (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) &&
+      (s >= 16 && d >= 16) && !force_non_flash_attention;
+
+  // enable warp_specialized kernels when s >= 512 on hopper
+  // note that warp_specialized kernels need flash attention + tma
+  launch_params.warp_specialization =
+      (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) &&
+      sm == 90 && launch_params.flash_attention && !force_non_warp_specialization;
+  // warp specialization kernels on hopper need tma
+  launch_params.use_tma = use_tma || launch_params.warp_specialization;
+
+  // use granular tiling on Ampere-style flash attention
+  launch_params.use_granular_tiling = !force_non_granular_tiling && launch_params.flash_attention &&
+                                      !launch_params.warp_specialization && sm >= 80;
+
+  if (launch_params.use_granular_tiling && (data_type == DATA_TYPE_E4M3 && sm == 80)) {
+    printf(
+        "Fallback to non-granular-tiling kernels as tiled e4m3 kernels"
+        "are not supported on Ada currently.\n");
+    launch_params.use_granular_tiling = false;
+  }
+
+  // Enable skip softmax attention or not.
+ launch_params.enable_skip_softmax = skip_softmax_threshold_scale_factor > 0.f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helper function to convert DLDataType to Data_type enum +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline Data_type dltype_to_data_type(DLDataType dtype) { + if (dtype.code == kDLFloat && dtype.bits == 16) { + return DATA_TYPE_FP16; + } else if (dtype.code == kDLBfloat && dtype.bits == 16) { + return DATA_TYPE_BF16; + } else if (dtype.code == kDLFloat8_e4m3fn && dtype.bits == 8) { + return DATA_TYPE_E4M3; + } else if (dtype.code == kDLFloat && dtype.bits == 32) { + return DATA_TYPE_FP32; + } else if (dtype.code == kDLInt && dtype.bits == 8) { + return DATA_TYPE_INT8; + } + assert(false && "Unsupported data type"); + return DATA_TYPE_FP16; +} + +static inline Attention_mask_type string_to_mask_type(const std::string& s) { + if (s == "padding") return Attention_mask_type::PADDING; + if (s == "causal") return Attention_mask_type::CAUSAL; + if (s == "sliding_window" || s == "chunked") + return Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL; + if (s == "custom") return Attention_mask_type::CUSTOM_MASK; + return Attention_mask_type::CAUSAL; // default +} + +static inline Attention_input_layout string_to_input_layout(const std::string& s) { + if (s == "packed_qkv") return Attention_input_layout::PACKED_QKV; + if (s == "contiguous_q_kv") return Attention_input_layout::CONTIGUOUS_Q_KV; + if (s == "q_paged_kv_nhd") return Attention_input_layout::Q_PAGED_KV; + if (s == "q_paged_kv_hnd") return Attention_input_layout::Q_PAGED_KV; + if (s == "separate_q_k_v") return Attention_input_layout::SEPARATE_Q_K_V; + throw std::invalid_argument("Unsupported input_layout: " + s); +} + +void fmha_v2_run( + ffi::TensorView q, // [batch, s_q, num_heads, head_dim] + ffi::TensorView k, // [batch, s_kv, num_kv_heads, head_dim] + 
ffi::TensorView v, // [batch, s_kv, num_kv_heads, head_dim_v] + ffi::TensorView o, // [batch, s_q, num_heads, head_dim_v] + ffi::TensorView workspace_buffer, size_t workspace_buffer_size_in_bytes, + Optional maybe_block_tables, // [batch, num_pages] + int page_size, + ffi::TensorView seq_lens, // [batch] + ffi::TensorView cum_seq_lens_q, // [batch + 1] + ffi::TensorView cum_seq_lens_kv, // [batch + 1] + const std::string& input_layout_str, int max_q_len, int max_kv_len, int batch_size, + const std::string& mask_mode_str, float scale_softmax, float scale_bmm1, float scale_bmm2, + int window_left, int chunked_attention_size, bool has_alibi, float softcapping_scale, + float skip_softmax_threshold_scale_factor, + ffi::TensorView scale_bmm2_d, // Pre-populated scale_bmm2 on device [1] int32 + Optional softmax_stats, // Optional [batch, s_q, num_heads, 2] for (max, sum) + Optional sinks) { + Attention_input_layout input_layout = string_to_input_layout(input_layout_str); + Attention_mask_type attention_mask_type = string_to_mask_type(mask_mode_str); + Data_type output_dtype = dltype_to_data_type(o.dtype()); + // Get device properties + CudaDevice device; + int sm = device.sm; + // Map SM12x variants (e.g. SM121 on DGX Spark) to base SM120 for kernel dispatch. + // CudaDevice computes sm = major*10 + minor, but all SM12x share the same Ampere-era + // MMA instructions and dispatch entries are generated for sm==120. 
+ if (sm > 120 && sm < 130) { + sm = 120; + } + cudaDeviceProp props = device.props; + + cudaStream_t stream = static_cast(get_stream(q.device())); + + // Extract dimensions based on input_layout: + // - PACKED_QKV: q is 4D [total_tokens, 3, num_heads, head_dim], k/v are same as q + // - Q_PAGED_KV: q is 3D [total_tokens, num_heads, head_dim], k/v are 4D paged + // - SEPARATE_Q_K_V: q/k/v are all 3D [total_tokens, num_heads, head_dim] + // - CONTIGUOUS_Q_KV: q is 3D [total_tokens, H, D], k is 4D [total_tokens, 2, H_kv, D] + const size_t b = batch_size; + size_t h, h_kv, d, dv; + if (input_layout == Attention_input_layout::PACKED_QKV) { + // q is 4D: [total_tokens, 3, H, D] + h = q.shape()[2]; // num_heads + h_kv = q.shape()[2]; // same as h for packed QKV (MHA) + d = q.shape()[3]; // head_dim_qk + dv = q.shape()[3]; // head_dim_v (same as d for standard attention) + } else if (input_layout == Attention_input_layout::Q_PAGED_KV) { + // q is 3D: [total_tokens, H, D], k/v are 4D paged: [num_pages, H_kv, page_size, D] + h = q.shape()[1]; + h_kv = k.shape()[1]; + d = q.shape()[2]; + dv = v.shape()[3]; + } else if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV) { + // q is 3D: [total_tokens, H, D], k is 4D: [total_tokens, 2, H_kv, D] + // k holds the combined KV tensor where dim 1 = 2 (K and V interleaved) + h = q.shape()[1]; + h_kv = k.shape()[2]; // KV shape is [tokens, 2, H_kv, D] + d = q.shape()[2]; + dv = k.shape()[3]; // D from KV tensor + } else { + // SEPARATE_Q_K_V: all 3D ragged [total_tokens, H, D] + h = q.shape()[1]; + h_kv = k.shape()[1]; + d = q.shape()[2]; + dv = v.shape()[2]; + } + + const size_t s_q = max_q_len; + const size_t s_kv = max_kv_len; + const size_t s = s_kv; // For compatibility with existing code + + // Determine data types from input tensors + Data_type data_type = dltype_to_data_type(q.dtype()); + Data_type acc_type = + (data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) ? 
DATA_TYPE_FP32 : data_type; + + int tokens_per_block = page_size; + float softcapping_scale_bmm1 = softcapping_scale; + + // BF16 and E4M3 require FP32 accumulation, but FP16 kernels use FP16 accumulation. + // The generated kernel dispatch expects: + // - FP16 kernels: !force_fp32_acc (force_fp32_acc = false) + // - BF16 kernels: force_fp32_acc (force_fp32_acc = true) + // - E4M3 kernels: force_fp32_acc (force_fp32_acc = true) + bool force_fp32_acc = (data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3); + + // Sliding window attention parameters + if (window_left > 0 && window_left < static_cast(s)) { + assert(chunked_attention_size == 0 && + "chunked_attention_size should not be used when sliding_window_size is set"); + attention_mask_type = Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL; + } + // Chunked attention. + if (chunked_attention_size > 0) { + assert((chunked_attention_size & (chunked_attention_size - 1)) == 0 && + "chunked_attention_size has to be a power of 2"); + attention_mask_type = Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL; + } + size_t sliding_window_size = size_t(INT_MAX); + if (attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL) { + if (window_left != -1) { + // Adjust sliding_window_size so that FMHA v2 matches FlashInfer's window_left semantics. + // Set sliding_window_size = window_left + 1. + sliding_window_size = size_t(window_left + 1); + } + } + + // total_q_tokens is inferred from q.shape()[0], which is always the total Q token count + // across all ragged layouts (PACKED_QKV, CONTIGUOUS_Q_KV, SEPARATE_Q_K_V, Q_PAGED_KV). 
+ uint32_t total_q_tokens = static_cast(q.shape()[0]); + uint32_t total = total_q_tokens; // Used for stride calculations in interleaved mode + + AlignedAllocator allocator(workspace_buffer.data_ptr(), workspace_buffer_size_in_bytes); + + // Validation for softmax save with MLA + if (softmax_stats.has_value()) { + bool is_MLA = (d == 192 && dv == 128); + if (((!is_MLA) && input_layout != Attention_input_layout::CONTIGUOUS_Q_KV) || + (is_MLA && input_layout != Attention_input_layout::SEPARATE_Q_K_V)) { + fprintf(stderr, + "For normal attention, only CONTIGUOUS_Q_KV layout supports saving softmax stats. " + "For MLA only SEPARATE_Q_K_V layout supports saving softmax stats.\n"); + exit(1); + } + } + + // Validate different q and kv lengths + if (s_q != s_kv) { + assert(input_layout != Attention_input_layout::PACKED_QKV && + "Packed QKV input layout is not supported with different q and kv lengths."); + assert(s_kv >= s_q && "q seqlen has to be smaller than or equal to the kv seqlen!"); + } + + // Set the attention scale (default: 1/sqrt(d)) + if (scale_bmm1 == 0.f) { + scale_bmm1 = 1.f / sqrtf(static_cast(d)); + } + + // Adjust softmax scale for different data types + if (data_type == DATA_TYPE_FP16 && scale_softmax == 0.f) { + scale_softmax = 1.f; + } else if (data_type == DATA_TYPE_INT8 && scale_softmax == 0.f) { + scale_softmax = std::max(512.f, static_cast(s)); + } else if (data_type == DATA_TYPE_E4M3 && scale_softmax == 0.f) { + scale_softmax = 1.f; + } + + // Enable causal mask if using alibi + if (has_alibi && attention_mask_type == Attention_mask_type::PADDING) { + attention_mask_type = Attention_mask_type::CAUSAL; + } + + // BF16 only supports FP32 accumulation + if (data_type == DATA_TYPE_BF16 && acc_type != DATA_TYPE_FP32) { + fprintf(stderr, "Only FP32 accumulation is supported for BF16 I/O\n"); + exit(1); + } + + // Determine the launch params to select kernels + Launch_params launch_params; + determine_launch_params(launch_params, data_type, sm, s, d, 
attention_mask_type, input_layout, + false, false, false, false, false, false, false, force_fp32_acc, + skip_softmax_threshold_scale_factor, props); + + // The decomposition of threads and warps for BMM1. + size_t warps_m, warps_n, warps_k; + std::tie(warps_m, warps_n, warps_k) = get_warps(launch_params, sm, data_type, s, b, d, 2); + + // For multi-CTA cases, determine the size of the CTA wave. + int heads_per_wave, ctas_per_head; + get_grid_size(heads_per_wave, ctas_per_head, sm, data_type, b, s, h, d, + false, // disable multi-cta kernels by default + 2); + + // The number of threads per CTA. + const size_t threads_per_cta = warps_m * warps_n * warps_k * 32; + // The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension. + size_t mmas_m = (s + 16 * warps_m - 1) / (16 * warps_m); + // The number of mmas in the N dimension. + size_t mmas_n = (s + 16 * warps_n - 1) / (16 * warps_n); + // The packed mask for dropout (in the fused kernel). Layout is B * MMAS_M * THREADS_PER_CTA. + size_t packed_mask_size = b * mmas_m * threads_per_cta; + + // Flash attention on Ampere and Hopper, which supports multiple mmas_n + if (attention_mask_type == Attention_mask_type::CUSTOM_MASK) { + // We need to align q and k sequence lengths. + size_t rounded_q_s = align_to(s, size_t(fmha::FLASH_ATTEN_MASK_M_ALIGNMENT)); + size_t rounded_k_s = align_to(s, size_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT)); + // The number of mmas in the M dimension (MMA_M = 64). + mmas_m = rounded_q_s / fmha::FLASH_ATTEN_MASK_MMA_M; + // The number of mmas in the N dimension (MMA_N = 64). + mmas_n = rounded_k_s / fmha::FLASH_ATTEN_MASK_MMA_N; + // Each thread holds 32 bit (2 rows, 16 cols -> 8 core MMAs) in one MMA here. + packed_mask_size = b * mmas_m * mmas_n * threads_per_cta; + } + // The size in bytes. 
+ const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); + + // Packed mask (allocated conditionally for CUSTOM_MASK) + void* packed_mask_d = + (attention_mask_type == Attention_mask_type::CUSTOM_MASK) + ? allocator.aligned_alloc(packed_mask_size_in_bytes, 128, "packed_mask_d") + : nullptr; + + // NOTE: scale_bmm2_d is now passed as a pre-populated tensor from Python + // to avoid cudaMemcpy synchronization in set_params(). + + // Softmax stats: stores (max, sum) per token, 2 floats per (b, s_q, h) + // Write directly to user-provided tensor when available, otherwise use workspace. + void* softmax_stats_ptr; + if (softmax_stats.has_value()) { + softmax_stats_ptr = softmax_stats.value().data_ptr(); + } else { + const size_t softmax_stats_size = 2 * sizeof(float) * b * s_q * h; + softmax_stats_ptr = allocator.aligned_alloc(softmax_stats_size, 128, "softmax_stats_d"); + } + void* attention_sinks_d = sinks.has_value() ? sinks.value().data_ptr() : nullptr; + + // Initialize pointers for different input layouts + void* qkv_packed_d = nullptr; + void* q_d = nullptr; + void* k_d = nullptr; + void* v_d = nullptr; + void* contiguous_kv_d = nullptr; + void* kv_cache_pool_ptr = nullptr; + int32_t* kv_cache_block_offsets_d = nullptr; + + // For Q_PAGED_KV layout, block_tables is pre-expanded on the Python side from [B, M] to [B, 2, M] + // where [:, 0, :] contains K offsets and [:, 1, :] contains V offsets. 
+ int block_table_max_blocks = 0; + + switch (input_layout) { + case Attention_input_layout::PACKED_QKV: + qkv_packed_d = q.data_ptr(); + break; + case Attention_input_layout::CONTIGUOUS_Q_KV: + q_d = q.data_ptr(); + contiguous_kv_d = k.data_ptr(); + break; + case Attention_input_layout::SEPARATE_Q_K_V: + q_d = q.data_ptr(); + k_d = k.data_ptr(); + v_d = v.data_ptr(); + break; + case Attention_input_layout::Q_PAGED_KV: { + q_d = q.data_ptr(); + kv_cache_pool_ptr = k.data_ptr(); + + if (maybe_block_tables.has_value()) { + // block_tables is pre-expanded on Python side with shape [B, 2, M] + // where M is max_blocks_per_sequence + ffi::TensorView block_tables = maybe_block_tables.value(); + block_table_max_blocks = block_tables.shape()[2]; // shape is [B, 2, M] + kv_cache_block_offsets_d = static_cast(block_tables.data_ptr()); + } + } break; + default: + assert(false && "Invalid input layout"); + break; + } + + bert::Fused_multihead_attention_params_v2 params_v2; + set_params(params_v2, launch_params, data_type, acc_type, output_dtype, input_layout, b, s_q, s, + h, h_kv, d, dv, total, 1, sliding_window_size, chunked_attention_size, + // Paged kv cache. 
+ tokens_per_block, qkv_packed_d, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, + kv_cache_block_offsets_d, packed_mask_d, nullptr, attention_sinks_d, + static_cast(cum_seq_lens_kv.data_ptr()), + static_cast(cum_seq_lens_q.data_ptr()), o.data_ptr(), nullptr, nullptr, + softmax_stats_ptr, scale_bmm2_d.data_ptr(), scale_bmm1, scale_softmax, scale_bmm2, + softcapping_scale_bmm1, false, false, false, has_alibi, + skip_softmax_threshold_scale_factor); + + // For Q_PAGED_KV layout, override mMaxBlocksPerSeq to match the actual block_tables stride + // that we used when expanding the block offsets from [B, M] to [B, 2, M] + if (input_layout == Attention_input_layout::Q_PAGED_KV && block_table_max_blocks > 0) { + params_v2.paged_kv_cache.mMaxBlocksPerSeq = block_table_max_blocks; + } + + // Total number of Q tokens is needed to set TMA desc on the host. + launch_params.total_q_seqlen = total_q_tokens; + // set enable_attn_logit_softcapping to select the right kernel. + launch_params.enable_attn_logit_softcapping = softcapping_scale_bmm1 != 0.f; + + // Compute sizes for conditional allocations + size_t counters_sz = (ctas_per_head > 1) ? heads_per_wave * sizeof(int) : 0; + size_t softmax_scratch_sz = + (ctas_per_head > 1) ? heads_per_wave * ctas_per_head * threads_per_cta * sizeof(float) : 0; + size_t o_scratch_sz = (ctas_per_head > 1 && data_type != DATA_TYPE_FP16) + ? heads_per_wave * threads_per_cta * MAX_STGS_PER_LOOP * sizeof(uint4) + : 0; + + // Allocate barriers and locks + void* counters_d = (counters_sz > 0) + ? allocator.aligned_alloc(3 * counters_sz, 16, "counters_d") + : nullptr; + // Allocate scratch storage for softmax + void* max_scratch_d = (softmax_scratch_sz > 0) ? allocator.aligned_alloc( + softmax_scratch_sz, 128, "max_scratch_d") + : nullptr; + void* sum_scratch_d = (softmax_scratch_sz > 0) ? 
allocator.aligned_alloc( + softmax_scratch_sz, 128, "sum_scratch_d") + : nullptr; + // Allocate temporary storage for the parallel reduction + void* o_scratch_d = (o_scratch_sz > 0) + ? allocator.aligned_alloc(o_scratch_sz, 128, "o_scratch_d") + : nullptr; + // Allocate tile id for dynamic scheduling + void* tile_id_counter_d = + allocator.aligned_alloc(sizeof(uint32_t), 16, "tile_id_counter_d"); + + // The number of heads computed per wave. + params_v2.heads_per_wave = heads_per_wave; + + // Barriers for the global sync in the multi-CTA kernel(s). + params_v2.counters = (int*)counters_d + 0 * heads_per_wave; + params_v2.max_barriers = (int*)counters_d + 0 * heads_per_wave; + params_v2.sum_barriers = (int*)counters_d + 1 * heads_per_wave; + params_v2.locks = (int*)counters_d + 2 * heads_per_wave; + + // Scratch storage for softmax. + params_v2.max_scratch_ptr = (float*)max_scratch_d; + params_v2.sum_scratch_ptr = (float*)sum_scratch_d; + + // Scratch storage for output. + params_v2.o_scratch_ptr = (int*)o_scratch_d; + + // Tile id counter for dynamic scheduling + params_v2.tile_id_counter_ptr = (uint32_t*)tile_id_counter_d; + + // V2 Custom Mask Packing (only if using CUSTOM_MASK) + // Note: You need to populate packed_mask_d with your custom mask data here + // using pack_flash_attention_mask() or provide pre-packed mask + + // Run the V2 kernel with runtime dispatch based on dtype and head dimensions + run_fmha_v2(params_v2, launch_params, data_type, output_dtype, sm, stream); +} diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 26b5f97894..8fa98adb62 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -130,6 +130,7 @@ from .prefill import ( single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse, ) +from .prefill import trtllm_fmha_v2_prefill as trtllm_fmha_v2_prefill from .quantization import packbits as packbits from .quantization import segment_packbits as segment_packbits from .rope import 
apply_llama31_rope as apply_llama31_rope diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 2f6829ee0f..5f154122ed 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -55,7 +55,10 @@ from .attention import get_single_decode_uri as get_single_decode_uri from .attention import get_single_prefill_uri as get_single_prefill_uri from .attention import gen_trtllm_gen_fmha_module as gen_trtllm_gen_fmha_module -from .attention import get_trtllm_fmha_v2_module as get_trtllm_fmha_v2_module +from .attention import gen_fmha_v2_module as gen_fmha_v2_module +from .attention import ( + gen_trtllm_fmha_v2_sm120_module as gen_trtllm_fmha_v2_sm120_module, +) from .core import JitSpec as JitSpec from .core import JitSpecStatus as JitSpecStatus from .core import JitSpecRegistry as JitSpecRegistry diff --git a/flashinfer/jit/attention/__init__.py b/flashinfer/jit/attention/__init__.py index 5f77725973..82af8ba97f 100644 --- a/flashinfer/jit/attention/__init__.py +++ b/flashinfer/jit/attention/__init__.py @@ -45,9 +45,9 @@ from .modules import get_pod_uri as get_pod_uri from .modules import get_single_decode_uri as get_single_decode_uri from .modules import get_single_prefill_uri as get_single_prefill_uri -from .modules import get_trtllm_fmha_v2_module as get_trtllm_fmha_v2_module from .modules import gen_trtllm_gen_fmha_module as gen_trtllm_gen_fmha_module -from .modules import gen_trtllm_fmha_v2_module as gen_trtllm_fmha_v2_module +from .modules import gen_trtllm_fmha_v2_sm120_module as gen_trtllm_fmha_v2_sm120_module +from .modules import gen_fmha_v2_module as gen_fmha_v2_module from .modules import ( gen_batch_prefill_attention_sink_module as gen_batch_prefill_attention_sink_module, get_batch_prefill_attention_sink_uri as get_batch_prefill_attention_sink_uri, diff --git a/flashinfer/jit/attention/fmha_v2/fmha_library.py b/flashinfer/jit/attention/fmha_v2/fmha_library.py new file mode 100644 index 0000000000..4b422fedfb --- /dev/null +++ 
b/flashinfer/jit/attention/fmha_v2/fmha_library.py @@ -0,0 +1,1337 @@ +import itertools +import pathlib +from typing import Any, Optional, Tuple +from ... import env as jit_env +from ....compilation_context import CompilationContext +from dataclasses import dataclass, asdict +from .utils import ( + get_effective_sm_and_name, + get_hopper_instruction_traits, + get_reg_count, + enable_mutex, + enable_tma_store, + selected_mask_types, + pythonBoolean2cpp, + dtype2bytes, + dtype2traits, + hopper_dtype2traits, + MAX_STGS_PER_LOOP, + dtype2OutputType, + InputLayout, + encode_name, + dtype2typename, + copyright, +) + +from ...utils import write_if_different + +import jinja2 + + +def select_kv_loop_step(head_size: int) -> int: + """ + Select the KV loop step based on head size. + + For warp-specialized Hopper kernels: + - Small heads (32-64): 256 step for better occupancy + - Medium heads (72-128): 128 step + - Large heads (160-256): 64 step to fit in registers + """ + if head_size <= 64: + return 256 + elif head_size <= 128: + return 128 + else: + return 64 + + +@dataclass(frozen=True) +class FMHAv2KernelSpec: + sm: int + dtype: str + seq_len: int + head_size: int + warps_m: int + warps_n: int + version: int + interleaved: bool + ldgsts_q: int + ldgsts_k: int + ldgsts_v: int + share_smem_k_v: bool + loop_step: int + has_noloop: bool + noloop_step: int + unroll_threshold: int + has_scale_max: bool + ctas_per_head: int = 1 + sm_mma: int = 1 + head_interleaved: bool = True + flash_attention: bool = False + kv_loop_step: int = 64 + limit_qk_fragments: bool = False + limit_v_fragments: bool = False + tiled: int = 0 + warp_specialization: bool = False + q_tile_buffers: int = 1 + kv_tile_buffers: int = 1 + scheduling_mode: int = 0 + input_layout: InputLayout = InputLayout.PACKED_QKV + cross_mha: int = 0 + alibi: bool = True + enable_attn_logit_softcapping: bool = False + return_softmax_stats: bool = False + enable_skip_softmax: bool = False + disabled_mask_types: 
Optional[Tuple[int]] = None + head_size_v: int = 0 + sage_block_sizes: Optional[Tuple[int, int, int]] = None + output_dtype: Optional[str] = None + is_mtp: bool = False + + +# BF16-QKV+BF16-out and BF16-Q + FP8-KV + BF16-out (or FP8-QKV+BF16-out) + + +def select_ldgsts( + sm: int, warp_specialization: bool, head_size: int, dtype: str +) -> Tuple[bool, bool, bool]: + if warp_specialization: + return (False, False, False) + elif sm == 120: + if dtype in ["fp16", "bf16"]: + # Need ldgsts (cp.async) for head_size > 64 to enable the tiled noloop + # kernel which handles RELOAD_Q (D > CTA_P_TILE_K=64). + if head_size <= 64: + return (False, False, False) + ldgsts_q = True + ldgsts_k = True + ldgsts_v = True + if head_size >= 256: + ldgsts_k = False + ldgsts_v = False + if head_size > 256: + ldgsts_q = False + return (ldgsts_q, ldgsts_k, ldgsts_v) + elif dtype == "e4m3": + return (False, False, False) + return (False, False, False) + + +def generate_kernel_spec( + sm: int, + head_size: int, + dtype: str, + enable_skip_softmax: Optional[bool] = False, + return_softmax: Optional[bool] = False, + enable_attn_logit_softcapping: Optional[bool] = False, + alibi: Optional[bool] = True, + is_mla: Optional[bool] = False, + head_size_v: Optional[int] = 0, + input_layout: Optional[InputLayout] = InputLayout.Q_PAGED_KV, + output_dtype: Optional[str] = None, +) -> FMHAv2KernelSpec: + """ + Generate a kernel spec for FMHAv2. 
+ + Args: + sm: GPU SM version (90, 120) + head_size: Q/K head dimension + dtype: Data type ("fp16", "bf16", "e4m3", "e4m3_fp32") + return_softmax: Return softmax statistics + enable_attn_logit_softcapping: Enable logit softcapping + enable_skip_softmax: Enable Skip-Softmax (Sparse Attention) + alibi: Enable ALiBi positional encoding + is_mla: MLA mode (different head sizes for Q/K and V) + head_size_v: V head dimension (0 = same as head_size) + input_layout: Input layout enum + output_dtype: Output dtype string + """ + # Initialize spec with required fields (no class defaults) + # and user-provided optional fields + spec: dict[str, Any] = { + # Required fields + "sm": sm, + "dtype": dtype, + "seq_len": 0, + "head_size": head_size, + "warps_m": 4, + "warps_n": 1, + "version": 2, + "interleaved": False, + "share_smem_k_v": False, + "unroll_threshold": 1, + "has_scale_max": False, + # head_interleaved=False means input layout [tokens, 3, H, D] (not [tokens, H, 3, D]) + # This matches the Python API's expected format and TRT-LLM convention + "head_interleaved": False, + # User-provided values (override class defaults if different) + "input_layout": input_layout, + "alibi": alibi, + "enable_attn_logit_softcapping": enable_attn_logit_softcapping, + "return_softmax_stats": return_softmax, + "enable_skip_softmax": enable_skip_softmax, + "head_size_v": head_size_v, + "output_dtype": output_dtype, + "is_mtp": is_mla, + } + + # Compute ldgsts flags + warp_specialization = sm == 90 and head_size >= 32 + ldgsts_q, ldgsts_k, ldgsts_v = select_ldgsts( + sm, warp_specialization, head_size, dtype + ) + spec["ldgsts_q"] = ldgsts_q + spec["ldgsts_k"] = ldgsts_k + spec["ldgsts_v"] = ldgsts_v + + # Override class defaults that always differ + spec["flash_attention"] = True # Class default is False + spec["scheduling_mode"] = 1 # Class default is 0 + + # # SM-specific configuration + # if warp_specialization: + # spec["warp_specialization"] = True + # spec["sm_mma"] = 90 + # 
spec["loop_step"] = 64 + # spec["has_noloop"] = 0 + # spec["noloop_step"] = 64 + # spec["kv_tile_buffers"] = 2 # Class default is 1 + + # if dtype in ["fp16", "bf16"]: + # if head_size <= 64: + # spec["kv_loop_step"] = 256 + # elif head_size <= 128: + # spec["kv_loop_step"] = 128 + # # else: use class default 64 + # elif dtype == "e4m3": + # if head_size <= 64: + # spec["kv_tile_buffers"] = 4 + # if head_size <= 128: + # spec["kv_loop_step"] = 256 + # else: + # spec["kv_loop_step"] = 128 + # else: + # raise ValueError(f"Unsupported dtype: {dtype}") + + # SM-specific configuration + # + # Shared memory budget for H100: 228KB (232,448 bytes). + # Kernel uses __launch_bounds__(THREADS, 1) to claim max smem. + # + # FP16/BF16 smem layout (kernel_traits.h): + # smem_q[2] = 2 * D * STEP_Q * Q_BUF * 2B + # smem_k = D * STEP_KV * KV_BUF * 2B + # smem_v = DV * STEP_KV * KV_BUF * 2B + # + # FP8 smem layout (adds V scratch for QGMMA transpose): + # smem_q[2] = 2 * D * STEP_Q * Q_BUF * 1B + # smem_k = D * STEP_KV * KV_BUF * 1B + # smem_v = DV * STEP_KV * KV_BUF * 1B + # smem_v_scratch = DV * STEP_KV * 1B + # + # D is padded: D = min(round_up(head_size, 128/elem_bytes), next_pow2(head_size)) + # FP16: round_up to multiples of 64 -> D in {32, 64, 128, 192, 256} + # FP8: round_up to multiples of 128 -> D in {32, 64, 128, 256} + # (head_size=160 pads to D=256 for FP8 vs D=192 for FP16) + # + if warp_specialization: + spec["warp_specialization"] = True + spec["sm_mma"] = 90 + spec["loop_step"] = 64 + spec["has_noloop"] = 0 + spec["noloop_step"] = 64 + spec["kv_tile_buffers"] = 2 + + if dtype in ["fp16", "bf16"]: + if head_size <= 64: + # D<=64: smem = 16 + 96 + 96 = ~208KB with KV_BUF=3 (fits 228KB) + spec["kv_loop_step"] = 256 + spec["kv_tile_buffers"] = 3 + elif head_size <= 128: + # D=128: smem = 32 + 64 + 64 = ~160KB with KV_BUF=2 + spec["kv_loop_step"] = 128 + else: + # D>=192: smem = 48 + 48 + 48 = ~144KB (D=192) + # smem = 64 + 64 + 64 = ~192KB (D=256) + spec["kv_loop_step"] 
= 64 + elif dtype == "e4m3": + if head_size <= 64: + # D<=64: smem = 8 + 64 + 64 + 16 = ~152KB with KV_BUF=4 + # Deep pipeline hides FP8 V transpose latency (dma.h:598-672) + spec["kv_tile_buffers"] = 4 + if head_size <= 128: + # D<=128: smem = 16 + 64 + 64 + 32 = ~176KB with KV_BUF=2 + # Note: STEP_KV=256 causes BMM2_K_GROUPS=2 (kernel_traits.h:241) + # and V transpose unroll drops to 1 (dma.h:102), but fewer KV + # loop iterations outweighs per-iteration overhead for long seqs + spec["kv_loop_step"] = 256 + else: + # D=256 (FP8 pads head_size>128 to 256 due to 128-byte alignment): + # smem = 32 + 64 + 64 + 32 = ~192KB with KV_BUF=2 + spec["kv_loop_step"] = 128 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + elif sm == 120: + spec["sm_mma"] = 80 + spec["has_noloop"] = 1 + spec["noloop_step"] = 64 + spec["loop_step"] = 64 + + if dtype in ["fp16", "bf16"]: + if head_size <= 256: + q_loop_step = 64 + spec["kv_loop_step"] = 64 + elif head_size <= 512: + q_loop_step = 64 + # kv_loop_step uses class default 64 + else: + raise ValueError(f"Unsupported head size: {head_size}") + spec["noloop_step"] = q_loop_step + spec["loop_step"] = q_loop_step + # Granular tiling: runtime sets use_granular_tiling=true for SM>=80 + # flash attention, so we must generate _nl_tiled dispatch entries. 
+ spec["tiled"] = 1 + elif dtype == "e4m3": + if is_mla: + # MLA kernels (TODO) + pass + else: + if head_size <= 64: + q_loop_step = 64 + spec["kv_loop_step"] = 64 + elif head_size <= 256: + q_loop_step = 64 + spec["kv_loop_step"] = 32 + else: + q_loop_step = 64 + spec["loop_step"] = q_loop_step + spec["noloop_step"] = q_loop_step + + elif sm == 90: + raise ValueError("(jimmyzho): Only Warp Specialization is supported for SM 90") + + return FMHAv2KernelSpec(**spec) + + +def is_kernel_spec_valid(kspec: FMHAv2KernelSpec) -> bool: + if kspec.alibi and kspec.enable_attn_logit_softcapping: + return False + + # Standard flash attention support + flash_valid: bool = ( + kspec.sm in [80, 86, 89, 90, 120] + and kspec.dtype in ["fp16", "bf16", "fp16_fp32", "e4m3", "e4m3_fp32"] + and kspec.head_size <= 256 + and kspec.head_size_v == 0 + and kspec.sage_block_sizes is None + and kspec.version == 2 + and not kspec.cross_mha + and kspec.flash_attention + and kspec.input_layout != InputLayout.SEPARATE_Q_K_V + ) + # SM90 non-flash ldgsts support (fixed seq len) + non_flash_valid: bool = ( + kspec.sm == 90 + and kspec.dtype in ["fp16", "bf16", "fp16_fp32"] + and kspec.head_size <= 256 + and bool(kspec.ldgsts_q) + and kspec.version == 2 + and not kspec.cross_mha + and not kspec.flash_attention + ) + # Clip/SigLip support + clip_valid: bool = ( + kspec.sm == 100 + and kspec.dtype in ["fp16", "bf16", "fp16_fp32", "e4m3", "e4m3_fp32"] + and kspec.head_size == 80 + and kspec.head_size_v == 0 + and kspec.sage_block_sizes is None + and kspec.version == 2 + and not kspec.cross_mha + and kspec.flash_attention + and kspec.input_layout != InputLayout.SEPARATE_Q_K_V + ) + # Deepseek MLA (generation 576/512 paged) + mla_valid_576_512: bool = ( + kspec.sm in [90, 100, 120] + and kspec.dtype in ["bf16", "e4m3_fp32"] + and kspec.head_size == 576 + and kspec.head_size_v == 512 + and kspec.input_layout == InputLayout.Q_PAGED_KV + and kspec.sage_block_sizes is None + and kspec.version == 2 + and not 
kspec.cross_mha + and kspec.flash_attention + and not kspec.warp_specialization + and bool(kspec.tiled) + ) + # Deepseek MLA (context 192/128 separate-q-k-v) + mla_valid_192_128: bool = ( + kspec.sm in [90, 100, 120] + and kspec.dtype in ["bf16", "e4m3", "e4m3_fp32"] + and kspec.head_size == 192 + and kspec.head_size_v == 128 + and kspec.input_layout == InputLayout.SEPARATE_Q_K_V + and kspec.sage_block_sizes is None + and kspec.version == 2 + and not kspec.cross_mha + and kspec.flash_attention + and ( + (kspec.warp_specialization and not kspec.alibi) # sm90 + or (not kspec.warp_specialization and bool(kspec.tiled)) + ) # non-sm90 + and not kspec.enable_attn_logit_softcapping + ) + # SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask) + sage_valid_sm90: bool = ( + kspec.sm == 90 + and kspec.head_size in [80, 128] + and kspec.version == 2 + and kspec.sage_block_sizes in [(64, 64, 256)] + and not kspec.cross_mha + and kspec.flash_attention + and kspec.warp_specialization + and kspec.input_layout == InputLayout.PACKED_QKV + and not kspec.alibi + and not kspec.enable_attn_logit_softcapping + ) + # SageAttention on Ada (head_size in (80, 128), packed QKV, padding mask) + sage_valid_sm89: bool = ( + kspec.sm == 89 + and kspec.head_size in [80, 128] + and kspec.sage_block_sizes in [(64, 32, 32)] + and kspec.output_dtype in ["fp16", "bf16"] + and kspec.version == 2 + and not kspec.cross_mha + and kspec.flash_attention + and not kspec.warp_specialization + and kspec.input_layout == InputLayout.PACKED_QKV + ) + # SM90 warp-specialized flash attention with SEPARATE_Q_K_V layout + # Supports standard attention (head_size_v == 0 means same as head_size) + flash_separate_qkv_valid: bool = ( + kspec.sm == 90 + and kspec.dtype in ["fp16", "bf16"] + and kspec.head_size <= 256 + and kspec.head_size_v == 0 + and kspec.sage_block_sizes is None + and kspec.version == 2 + and not kspec.cross_mha + and kspec.flash_attention + and kspec.warp_specialization + and 
kspec.input_layout == InputLayout.SEPARATE_Q_K_V + and not kspec.enable_attn_logit_softcapping + ) + + return ( + flash_valid + or non_flash_valid + or clip_valid + or mla_valid_576_512 + or mla_valid_192_128 + or sage_valid_sm90 + or sage_valid_sm89 + or flash_separate_qkv_valid + ) + + +def get_kernel_code(kspec: FMHAv2KernelSpec, kname: str, lname: str) -> Optional[str]: + min_cuda_version = 0 # no restriction + + # The architecture that determines the instruction. + effective_sm, sm_name = get_effective_sm_and_name(kspec) + + if effective_sm >= 80: + min_cuda_version = 11000 + + launcher_name = lname + causal_kernel_name = kname.replace("__placeholder__", "_causal") + custom_mask_kernel_name = kname.replace("__placeholder__", "_custom_mask") + sliding_or_chunked_causal_kernel_name = kname.replace( + "__placeholder__", "_sliding_or_chunked_causal" + ) + kernel_name = kname.replace("__placeholder__", "") + + # FIXME: use separate parameters when generating cubins for trtllm. + if not kspec.cross_mha: + params_type = "bert::Fused_multihead_attention_params_v{}".format(kspec.version) + else: + params_type = "bert::Fused_multihead_attention_params_mhca" + + if effective_sm < 90: + instruction_traits = sm_name.capitalize() + "_" + dtype2traits[kspec.dtype] + elif effective_sm == 90: + instruction_traits = ( + sm_name.capitalize() + "_" + hopper_dtype2traits[kspec.dtype] + ) + # for hopper, we differentiate instruction_traits_o and instruction_traits_p + instruction_traits_p, instruction_traits_o = get_hopper_instruction_traits( + instruction_traits, kspec + ) + + if effective_sm < 90: + if kspec.flash_attention: + kernel_variant = "flash_attention" + else: + kernel_variant = "1xN" if kspec.warps_m == 1 else "2x2" + elif effective_sm == 90: + if kspec.warps_n > 1: + # for hopper we slice the problem along the M dim. 
+ kernel_variant = "4xN" + "_hopper" + else: + kernel_variant = "4x1" + "_hopper" + + if effective_sm < 90: + kernel_traits = "Kernel_traits_" + elif effective_sm == 90: + kernel_traits = "FMHA_kernel_traits_hopper_" + + if kspec.interleaved: + kernel_traits += "interleaved_v2" + elif kspec.cross_mha: + kernel_traits += "fmhca" + else: + kernel_traits += "v{}".format(kspec.version) + + # decide whether to paged_kv kernel traits for ampere-style kernels. + if effective_sm < 90: + if kspec.input_layout == InputLayout.Q_PAGED_KV: + kernel_traits += "_paged_kv_cache" + elif kspec.input_layout == InputLayout.CONTIGUOUS_Q_KV: + kernel_traits += "_contiguous_kv_cache" + elif kspec.input_layout == InputLayout.SEPARATE_Q_K_V: + kernel_traits += "_q_k_v" + + flags = 0 + if kspec.ldgsts_q: + flags |= 1 + if kspec.ldgsts_k: + flags |= 2 + if kspec.ldgsts_v: + flags |= 4 + if kspec.share_smem_k_v and not kspec.limit_qk_fragments: + flags |= 8 + if kspec.has_scale_max: + flags |= 16 + if not kspec.head_interleaved: + flags |= 32 + if kspec.limit_qk_fragments: + flags |= 128 + if kspec.limit_v_fragments: + flags |= 256 + if kspec.has_noloop: + # NOTE do not use flags 512 = 0x200 as it is reserved; do not add to flags because it + # will be selectively added to no-loop kernel trait upon generating .cu templates + pass + if kspec.enable_attn_logit_softcapping: + flags |= 2048 + if kspec.tiled: + flags |= 4096 + if kspec.is_mtp: + flags |= 8192 + + # only generate certain needed combinations of input_layout and mask types for trt-llm. 
+ padding_mask, causal_mask, sliding_or_chunked_causal_mask, custom_mask = ( + selected_mask_types(kspec) + ) + + if any( + selected_mask_flag == "1" for selected_mask_flag in selected_mask_types(kspec) + ): + padding_mask, causal_mask, sliding_or_chunked_causal_mask, custom_mask = ( + selected_mask_types(kspec) + ) + else: + return None + + kernel_flags = "0x{:02x}u".format(flags) + + heads_interleaved_flag = pythonBoolean2cpp[kspec.head_interleaved] + + disable_fadd_trick = ( + 1 if effective_sm >= 86 else 0 + ) # this will force generating F2IP + + enable_mutex_flag = enable_mutex(kspec) + + has_alibi = pythonBoolean2cpp[kspec.alibi] + + input_layout_flag = str(int(kspec.input_layout)) + + run_fct_name = ( + "run_packed_qkv" + if kspec.input_layout == InputLayout.PACKED_QKV + else "run_separate_q_and_kv" + ) + + dma_reg_count, compute_reg_count = get_reg_count(kspec) + + use_tma_store_flag = enable_tma_store(kspec) + + enable_attn_logit_softcapping_flag = pythonBoolean2cpp[ + kspec.enable_attn_logit_softcapping + ] + + return_softmax_stats_flag = pythonBoolean2cpp[kspec.return_softmax_stats] + + enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax] + + # needed by warpspec kernels. 
+ fp8_kernel = kspec.dtype in ["e4m3", "e4m3_fp32"] + kernel_traits_header = ( + "fmha::ws::Kernel_traits_Hopper_qgmma_e4m3_fp32<" + if fp8_kernel + else f"fmha::ws::Kernel_traits::Host dma_host; + dma_host.init_params(params, launch_params, stream); + """ + params_str = "params" + attn_mask_type_str = "using Attention_mask_type = fmha::Attention_mask_type;" + bert_launch_params = ( + "using Launch_params = bert::Fused_multihead_attention_launch_params;" + ) + include_str = "" + num_compute_groups_str = "static constexpr int NUM_COMPUTE_GROUPS = 2;" + fused_multihead_attention_params_v2_str = f"{params_type}" + const_fused_multihead_attention_params_v2_str = f"const {params_type}" + setmaxnreg_dma_str = r""" + const int DMA_REG_COUNT = {dma_reg_count}; + asm volatile("{{setmaxnreg.dec.sync.aligned.u32 %0; \n\t}}" ::"n"(DMA_REG_COUNT));""".format( + dma_reg_count=dma_reg_count + ) + setmaxnreg_compute_str = r""" + const int COMPUTE_REG_COUNT = {compute_reg_count}; + asm volatile("{{setmaxnreg.inc.sync.aligned.u32 %0; \n\t}}" ::"n"(COMPUTE_REG_COUNT));""".format( + compute_reg_count=compute_reg_count + ) + local_ns_open = "" + local_ns_close = "" + + tmp = dict(locals(), **asdict(kspec)) + + template_dir = jit_env.FLASHINFER_CSRC_DIR / "fmha_v2" / "templates" + if effective_sm < 90: + if kspec.flash_attention: + with open(template_dir / "fa_kernel.jinja", "r") as f: + template = jinja2.Template(f.read()) + + tmp["MAX_STGS_PER_LOOP"] = MAX_STGS_PER_LOOP + tmp["use_multi_cta"] = False + # RELOAD_Q is needed when D > CTA_P_TILE_K (64 for D >= 64 with + # granular tiling). The tiled noloop kernel handles this correctly + # but requires ldgsts (cp.async). 
+ tmp["reload_q"] = kspec.tiled and kspec.head_size > 64 + code = template.render(tmp) + else: + with open(template_dir / "kernel.jinja", "r") as f: + template = jinja2.Template(f.read()) + tmp["MAX_STGS_PER_LOOP"] = MAX_STGS_PER_LOOP + use_multi_cta = 1 if kspec.ctas_per_head > 1 else 0 + tmp["use_multi_cta"] = use_multi_cta + code = template.render(tmp) + elif effective_sm == 90: + use_tma = 1 + if kspec.ldgsts_q: + use_tma = 0 + if kspec.warp_specialization: + with open(template_dir / "kernel_hopper_ws.jinja", "r") as f: + template = jinja2.Template(f.read()) + tmp["use_tma"] = use_tma + tmp["bytes_per_elt"] = dtype2bytes[kspec.dtype] + code = template.render(tmp) + else: + with open(template_dir / "kernel_hopper.jinja", "r") as f: + template = jinja2.Template(f.read()) + tmp["use_tma"] = use_tma + code = template.render(tmp) + else: + raise RuntimeError("No template found for this configuration.") + return code + + +def get_api_code(specs_names: list[Tuple[FMHAv2KernelSpec, str, str, str]]) -> str: + def get_signature(lname: str, version: int, cross_mha: int, use_tma: bool) -> str: + # The architecture that determines the instruction. 
+ effective_sm, sm_name = get_effective_sm_and_name(kspec) + if cross_mha: + return "void {}(const Params_mhca ¶ms, cudaStream_t stream);".format( + lname + ) + elif effective_sm >= 90: + # need to set tma desc in params + return "void {}(Params_v{} ¶ms, const Launch_params &launch_params, cudaStream_t stream);".format( + lname, version + ) + else: + return "void {}(const Params_v{} ¶ms, const Launch_params &launch_params, cudaStream_t stream);".format( + lname, version + ) + + signatures = [] + for kspec, _fname, lname, _kname in specs_names: + effective_sm, _ = get_effective_sm_and_name(kspec) + use_tma = effective_sm == 90 and not kspec.ldgsts_q + signatures.append(get_signature(lname, kspec.version, kspec.cross_mha, use_tma)) + if kspec.has_noloop and not kspec.tiled: + signatures.append( + get_signature(lname + "_nl", kspec.version, kspec.cross_mha, use_tma) + ) + elif kspec.tiled: + signatures.append( + get_signature( + lname + "_nl_tiled", kspec.version, kspec.cross_mha, use_tma + ) + ) + if not kspec.warp_specialization: + signatures.append("void {}_get_max_heads_per_wave(int*);".format(lname)) + signatures_str = "\n".join(signatures) + + # v1 + # - normal + # - no loop + # v2 + # - normal + # - no loop + # - normal interleaved + # - no loop interleaved + # - flash attention no loop + # - flash attention no loop tiled + # - flash attention warp_specialized (on Hopper) + + def gen_unroll_check(kspec: FMHAv2KernelSpec) -> str: + code = "if (!{has_noloop} || (!force_unroll && (ignore_b1opt || b > {unroll_threshold})))".format( + **asdict(kspec) + ) + if kspec.flash_attention: + code = "if (!{has_noloop} || (!force_unroll && (ignore_b1opt || b * h > {unroll_threshold})))".format( + **asdict(kspec) + ) + return code + + def gen_call(kspec: FMHAv2KernelSpec, lname: str) -> str: + effective_sm, _ = get_effective_sm_and_name(kspec) + data_type = dtype2typename[kspec.dtype] + output_data_type = data_type + if kspec.output_dtype: + output_data_type = 
dtype2typename[kspec.output_dtype] + il_check = "" + if kspec.version == 2 and kspec.dtype in ["fp16", "bf16"]: + il_check += ( + "&& use_flash_attention " + if kspec.flash_attention + else "&& !use_flash_attention " + ) + if kspec.version == 2: + # attention input layout. + il_check += f"&& attention_input_layout == {kspec.input_layout.value} " + # interleaved layout or not. + il_check += "&& interleaved " if kspec.interleaved else "&& !interleaved " + if effective_sm == 90: + il_check += "&& !use_tma " if kspec.ldgsts_q else "&& use_tma " + il_check += ( + "&& warp_specialization " + if kspec.warp_specialization + else "&& !warp_specialization " + ) + else: + il_check += "&& !warp_specialization && !use_tma " + # Different accumulation types. + if "_fp32" in kspec.dtype or "bf16" in kspec.dtype or kspec.dtype == "e4m3": + il_check += "&& force_fp32_acc " + else: + il_check += "&& !force_fp32_acc " + # whether support alibi or not. + if kspec.warp_specialization: + il_check += ( + "&& params.has_alibi " if kspec.alibi else "&& !params.has_alibi " + ) + il_check += ( + "&& params.softmax_stats_ptr != nullptr " + if kspec.return_softmax_stats + else "&& params.softmax_stats_ptr == nullptr " + ) + # use enable_attn_logit_softcapping or not. 
+ il_check += ( + "&& enable_attn_logit_softcapping " + if kspec.enable_attn_logit_softcapping + else "&& !enable_attn_logit_softcapping " + ) + # check sage block sizes + sage_block_size_q = 0 + sage_block_size_k = 0 + sage_block_size_v = 0 + if kspec.sage_block_sizes: + # override the data_type to output type, otherwise it is always E4M3 + data_type = output_data_type + sage_block_size_q = kspec.sage_block_sizes[0] + sage_block_size_k = kspec.sage_block_sizes[1] + sage_block_size_v = kspec.sage_block_sizes[2] + il_check += ( + f"&& sage_block_size_q == {sage_block_size_q} " + f"&& sage_block_size_k == {sage_block_size_k} " + f"&& sage_block_size_v == {sage_block_size_v} " + ) + + il_check += ( + "&& enable_skip_softmax " + if kspec.enable_skip_softmax + else "&& !enable_skip_softmax " + ) + + il_check += ( + "&& params.use_int8_scale_max " + if kspec.has_scale_max + else "&& !params.use_int8_scale_max " + ) + + slen = kspec.seq_len * kspec.ctas_per_head if not kspec.flash_attention else 0 + + ## NOTE: need to tune here + if kspec.has_noloop and not kspec.flash_attention: + call_stmt = """\ +if( data_type == {data_type} && output_data_type == {output_data_type} && s == {slen} && d == {head_size} && sm == {sm} + {il_check}) {{ + + {unroll_check} {{ + if (fmha_v2_verbose) printf("[FMHAv2] kernel: {lname}\\n"); + {lname}(params, launch_params, stream); + }} else {{ + if (fmha_v2_verbose) printf("[FMHAv2] kernel: {lname}_nl\\n"); + {lname}_nl(params, launch_params, stream); + }} + +}} """.format( + **asdict(kspec), + data_type=data_type, + output_data_type=output_data_type, + slen=slen, + lname=lname, + il_check=il_check, + unroll_check=gen_unroll_check(kspec), + ) + + elif kspec.flash_attention: # NOTE: flash attention uses no_loop as default + # TypeError: got multiple values for keyword argument if using key 'head_size_v', so 'dv' instead + dv = kspec.head_size_v or kspec.head_size + if kspec.tiled: # higher precedence; does not require bh_upper_thres + call_stmt = 
"""\ +if( data_type == {data_type} && output_data_type == {output_data_type} && d == {head_size} && dv == {dv} && sm == {sm} + {il_check} && use_tiled) {{ + + if (fmha_v2_verbose) printf("[FMHAv2] kernel: {lname}_nl_tiled\\n"); + {lname}_nl_tiled(params, launch_params, stream); +}} """.format( # type: ignore[str-format] + **asdict(kspec), + data_type=data_type, + output_data_type=output_data_type, + slen=slen, + lname=lname, + il_check=il_check, + dv=dv, + ) + # warp specialization kernels need launch_params + elif kspec.warp_specialization: + call_stmt = """\ +if( data_type == {data_type} && output_data_type == {output_data_type} && d == {head_size} && dv == {dv} && sm == {sm} + {il_check}) {{ + + if (fmha_v2_verbose) printf("[FMHAv2] kernel: {lname}\\n"); + {lname}(params, launch_params, stream); +}} """.format( # type: ignore[str-format] + **asdict(kspec), + data_type=data_type, + output_data_type=output_data_type, + slen=slen, + lname=lname, + il_check=il_check, + dv=dv, + ) + else: + call_stmt = """\ +if( data_type == {data_type} && output_data_type == {output_data_type} && d == {head_size} && dv == {dv} && sm == {sm} + && !use_tiled {il_check}) {{ + + if (fmha_v2_verbose) printf("[FMHAv2] kernel: {lname}_nl\\n"); + {lname}_nl(params, launch_params, stream); +}} """.format( # type: ignore[str-format] + **asdict(kspec), + data_type=data_type, + output_data_type=output_data_type, + slen=slen, + lname=lname, + il_check=il_check, + dv=dv, + ) + else: + call_stmt = """\ +if( data_type == {data_type} && output_data_type == {output_data_type} && s == {slen} && d == {head_size} && sm == {sm} + {il_check}) {{ + + if (fmha_v2_verbose) printf("[FMHAv2] kernel: {lname}\\n"); + {lname}(params, launch_params, stream); +}} """.format( + **asdict(kspec), + data_type=data_type, + output_data_type=output_data_type, + slen=slen, + lname=lname, + il_check=il_check, + ) + return call_stmt + + def gen_call_fmhca(kspec: FMHAv2KernelSpec, lname: str) -> str: + effective_sm, _ = 
get_effective_sm_and_name(kspec) + data_type = dtype2typename[kspec.dtype] + il_check = "" + if kspec.version == 2: + il_check = "&& interleaved " if kspec.interleaved else "&& !interleaved " + if effective_sm == 90: + il_check += "&& !use_tma " if kspec.ldgsts_q else "&& use_tma " + il_check += ( + "&& params.use_int8_scale_max " + if kspec.has_scale_max + else "&& !params.use_int8_scale_max " + ) + + s_kv_len = kspec.seq_len + if kspec.has_noloop: + call_stmt = """\ +if( data_type == {data_type} && s_kv == {s_kv_len} && d == {head_size} && sm == {sm} {il_check}) {{ + + {unroll_check} {{ + {lname}(params, stream); + }} else {{ + {lname}_nl(params, stream); + }} + +}} """.format( + **asdict(kspec), + data_type=data_type, + s_kv_len=s_kv_len, + lname=lname, + il_check=il_check, + unroll_check=gen_unroll_check(kspec), + ) + + else: + call_stmt = """\ +if( data_type == {data_type} && s_kv == {s_kv_len} && d == {head_size} && sm == {sm} {il_check}) {{ + {lname}(params, stream); + }} """.format( + **asdict(kspec), + data_type=data_type, + s_kv_len=s_kv_len, + lname=lname, + il_check=il_check, + ) + return call_stmt + + calls_v2 = [ + gen_call(kspec, lname) + for kspec, fname, lname, kname in specs_names + if kspec.version == 2 and kspec.cross_mha == 0 + ] + + calls_v2_str = "else ".join(calls_v2) if len(calls_v2) > 0 else "if( false ) {}" + + calls_v1 = [ + gen_call(kspec, lname) + for kspec, fname, lname, kname in specs_names + if kspec.version == 1 and kspec.cross_mha == 0 + ] + + calls_v1_str = "else ".join(calls_v1) if len(calls_v1) > 0 else "if( false ) {}" + + calls_mhca = [ + gen_call_fmhca(kspec, lname) + for kspec, fname, lname, kname in specs_names + if kspec.cross_mha == 1 + ] + + calls_mhca_str = ( + "else ".join(calls_mhca) if len(calls_mhca) > 0 else "if( false ) {}" + ) + + def gen_warp_spec(kspec: FMHAv2KernelSpec) -> str: + data_type = dtype2typename[kspec.dtype] + if kspec.sage_block_sizes is not None: + assert kspec.output_dtype is not None + # 
override the data_type to output type, otherwise it is always E4M3 + data_type = dtype2typename[kspec.output_dtype] + slen = kspec.seq_len * kspec.ctas_per_head + effective_sm, _ = get_effective_sm_and_name(kspec) + warp_spec_check = "" + nl_warps_m = kspec.warps_m if effective_sm == 90 else 1 + nl_warps_n = ( + kspec.warps_n if effective_sm == 90 else kspec.warps_m * kspec.warps_n + ) + if kspec.version == 2 and kspec.dtype in ["fp16", "bf16"]: + warp_spec_check += ( + "&& use_flash_attention " + if kspec.flash_attention + else "&& !use_flash_attention " + ) + if kspec.version == 2: + if effective_sm == 90: + warp_spec_check += "&& !use_tma " if kspec.ldgsts_q else "&& use_tma " + warp_spec_check += ( + "&& warp_specialization " + if kspec.warp_specialization + else "&& !warp_specialization " + ) + else: + warp_spec_check += "&& !use_tma && !warp_specialization " + + if kspec.flash_attention: # NOTE support any sequence + return """\ +if( data_type == {data_type} && d == {head_size} && sm == {sm} {warp_spec_check} + && version == {version} ) {{ + warps_m = {warps_m}; + warps_n = {warps_n}; +}} """.format( # type: ignore[str-format] + **locals(), **asdict(kspec), unroll_check=gen_unroll_check(kspec) + ) + return """\ +if( data_type == {data_type} && s == {slen} && d == {head_size} && sm == {sm} {warp_spec_check} + && version == {version} ) {{ + {unroll_check} {{ + warps_m = {warps_m}; + warps_n = {warps_n}; + }} else {{ + warps_m = {nl_warps_m}; + warps_n = {nl_warps_n}; + }} +}} """.format(**locals(), **asdict(kspec), unroll_check=gen_unroll_check(kspec)) + + warp_specs = "else ".join([gen_warp_spec(spec[0]) for spec in specs_names]) + if len(warp_specs) > 0: + warp_specs += 'else {\n\tassert(false && "Unsupported config");\n}' + + # Generate the cta spec. 
+ def gen_cta_spec(spec: Tuple[FMHAv2KernelSpec, str, str, str]) -> str: + kspec, _, lname, _ = spec + slen = kspec.seq_len * kspec.ctas_per_head + return """\ +if( data_type == {data_type} && s == {slen} && d == {head_size} && use_multi_ctas + && version == {version} ) {{ + + ctas_per_head = {ctas_per_head}; + {lname}_get_max_heads_per_wave(&max_heads_per_wave); + +}} """.format(**locals(), **asdict(kspec), data_type=dtype2typename[kspec.dtype]) + + cta_specs = "else ".join( + [gen_cta_spec(spec) for spec in specs_names if spec[0].ctas_per_head > 1] + ) + # pragma once + api_code = """\ +{copyright} + + +#include +#include +#include +#include +#include +#include + +using Params_v1 = bert::Fused_multihead_attention_params_v1; +using Params_v2 = bert::Fused_multihead_attention_params_v2; +using Params_mhca = bert::Fused_multihead_attention_params_mhca; +using Launch_params = bert::Fused_multihead_attention_launch_params; + +{signatures_str} + +inline void run_fmha_v1(Params_v1 ¶ms, + const Launch_params &launch_params, + Data_type data_type, + Data_type output_data_type, + int sm, + cudaStream_t stream=0){{ +const size_t s = params.s; +const size_t b = params.b; +const size_t d = params.d; +const bool force_unroll = launch_params.force_unroll; +const bool ignore_b1opt = launch_params.ignore_b1opt; + +const bool use_flash_attention = false; + +{calls_v1_str} +else {{ + assert(false && "Unsupported config."); +}} + +}} + +// Note: transitioning to moving kernel launch parameters into launch_params to reduce the +// occurrences the interface needs to be modified +inline void run_fmha_v2(Params_v2 ¶ms, + const Launch_params &launch_params, + Data_type data_type, + Data_type output_data_type, + int sm, + cudaStream_t stream=0) {{ + +const size_t s = params.s; +const size_t b = params.b; +const size_t h = params.h; +const size_t d = params.d; +const size_t dv = params.dv; +const size_t sage_block_size_q = params.sage.q.block_size; +const size_t sage_block_size_k = 
params.sage.k.block_size; +const size_t sage_block_size_v = params.sage.v.block_size; + +const bool interleaved = launch_params.interleaved; +const bool force_unroll = launch_params.force_unroll; +const bool ignore_b1opt = launch_params.ignore_b1opt; +const bool force_fp32_acc = launch_params.force_fp32_acc; +const bool warp_specialization = launch_params.warp_specialization; +const bool use_tma = launch_params.use_tma; +const bool use_flash_attention = launch_params.flash_attention; +const bool enable_attn_logit_softcapping = launch_params.enable_attn_logit_softcapping; +const bool enable_skip_softmax = launch_params.enable_skip_softmax; +const int attention_input_layout = static_cast(launch_params.attention_input_layout); +// tiled variant uses ldgsts +const bool use_tiled = launch_params.use_granular_tiling; + +static const bool fmha_v2_verbose = (std::getenv("FLASHINFER_FMHA_V2_VERBOSE") != nullptr); + +{calls_v2_str} +else {{ + assert(false && "Unsupported config."); +}} + +}} + +#if __guard_fmhca_placeholder__ // fmhca api header + +inline void run_fmhca(Params_mhca ¶ms, + const Launch_params &launch_params, + Data_type data_type, + int sm, + cudaStream_t stream=0) {{ + +const size_t s_kv = params.s; +const size_t b = params.b; +const size_t d = params.d_padded; + +const bool interleaved = launch_params.interleaved; +const bool force_unroll = launch_params.force_unroll; +const bool ignore_b1opt = launch_params.ignore_b1opt; + +{calls_mhca_str} +else {{ + assert(false && "Unsupported config"); +}} + +}} + +#endif // fmhca api header + +inline std::tuple get_warps(Launch_params& launch_params, + int sm, + Data_type data_type, + size_t s, + size_t b, + size_t d, + int version) {{ + size_t warps_m, warps_n, warps_k = 1; + const bool interleaved = launch_params.interleaved; + const bool use_tma = launch_params.use_tma; + const bool force_unroll = launch_params.force_unroll; + const bool ignore_b1opt = launch_params.ignore_b1opt; + const bool use_flash_attention = 
launch_params.flash_attention; + // tiled variant uses ldgsts + const bool use_tiled = launch_params.use_granular_tiling; + const bool warp_specialization = launch_params.warp_specialization; + +{warp_specs} + + return std::make_tuple(warps_m, warps_n, warps_k); +}} + +// The constant is defined in "setup.py". +constexpr int MAX_STGS_PER_LOOP = {MAX_STGS_PER_LOOP}; + +// The number of CTAs and threads per CTA to launch the kernel. +inline void get_grid_size(int &heads_per_wave, + int &ctas_per_head, + int sm, + Data_type data_type, + size_t b, + size_t s, + size_t h, + size_t d, + bool use_multi_ctas, + int version) {{ + + // Determine the number of CTAs per head (kernel constant). + int max_heads_per_wave = 0; + ctas_per_head = 1; + heads_per_wave = b*h; +{cta_specs} + + // Adjust the number of heads per wave. + if( heads_per_wave > max_heads_per_wave ) {{ + heads_per_wave = max_heads_per_wave; + }} +}} + +""".format(**locals(), copyright=copyright, MAX_STGS_PER_LOOP=MAX_STGS_PER_LOOP) + return api_code + + +def generate_jit_sources( + uri: str, + input_layout: str, + input_dtype: str, + output_dtype: Optional[str], + compilation_context: CompilationContext, +) -> list[pathlib.Path]: + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri + source_paths = [] + specs_names = [] + head_size_qk_values = [16, 32, 64, 128, 256, 512] + head_size_qk_warpspec_values = [32, 40, 48, 64, 80, 96, 104, 128, 160, 192, 256] + + # 0 means head_size_v = head_size_qk (required for flash_valid) + head_size_v_values = [0] + map_input_layout = { + "q_paged_kv_nhd": InputLayout.Q_PAGED_KV, + "q_paged_kv_hnd": InputLayout.Q_PAGED_KV, + "packed_qkv": InputLayout.PACKED_QKV, + "separate_q_k_v": InputLayout.SEPARATE_Q_K_V, + "contiguous_q_kv": InputLayout.CONTIGUOUS_Q_KV, + } + + input_layout_values = [map_input_layout[input_layout.lower()]] + dtype_values = [input_dtype] + output_dtype_values = [output_dtype] if output_dtype is not None else [None] + + is_mla_values = [False] + + 
enable_attn_logit_softcapping_values = [True, False] + return_softmax_values = [True, False] + alibi_values = [True, False] + target_major_archs = { + major for major, _minor in compilation_context.TARGET_CUDA_ARCHS + } + include_sm90_kernels = 9 in target_major_archs + include_sm120_kernels = 12 in target_major_archs + warp_spec_configs: itertools.product = itertools.product( + [90] if include_sm90_kernels else [], + dtype_values, + head_size_qk_warpspec_values, + head_size_v_values, + enable_attn_logit_softcapping_values, + return_softmax_values, + [False, True], # enable_skip_softmax + alibi_values, + is_mla_values, + input_layout_values, + output_dtype_values, + ) + + head_size_qk_sm120_values = [64, 128, 256] + sm120_configs: itertools.product = itertools.product( + [120] if include_sm120_kernels else [], + dtype_values, # fallback to avoid empty product + head_size_qk_sm120_values, + head_size_v_values, + enable_attn_logit_softcapping_values, + return_softmax_values, + [False], # no skip-softmax without warp specialization + alibi_values, + is_mla_values, + input_layout_values, + output_dtype_values, + ) + + other_configs: itertools.product = itertools.product( + [], + dtype_values, + head_size_qk_values, + head_size_v_values, + enable_attn_logit_softcapping_values, + return_softmax_values, + [False], # only warp-spec kernels support skip-softmax + alibi_values, + is_mla_values, + input_layout_values, + output_dtype_values, + ) + + config_lists = [other_configs] + if include_sm90_kernels: + config_lists.append(warp_spec_configs) + if include_sm120_kernels: + config_lists.append(sm120_configs) + + for config_list in config_lists: + for ( + sm_iter, + dtype_iter, + head_size_qk_iter, + head_size_v_iter, + enable_attn_logit_softcapping_iter, + return_softmax_iter, + enable_skip_softmax, + alibi_iter, + is_mla_iter, + input_layout_iter, + output_dtype_iter, + ) in config_list: + kspec = generate_kernel_spec( + sm=sm_iter, + head_size=head_size_qk_iter, + 
dtype=dtype_iter, + enable_skip_softmax=enable_skip_softmax, + return_softmax=return_softmax_iter, + enable_attn_logit_softcapping=enable_attn_logit_softcapping_iter, + alibi=alibi_iter, + is_mla=is_mla_iter, + input_layout=input_layout_iter, + head_size_v=head_size_v_iter, + output_dtype=output_dtype_iter, + ) + if not is_kernel_spec_valid(kspec): + continue + + fname, lname, kname = encode_name(kspec) + kernel_code = get_kernel_code(kspec, kname, lname) + if kernel_code is None: + continue + + # Write kernel source file + kernel_path = gen_directory / fname + write_if_different(kernel_path, kernel_code) + source_paths.append(kernel_path) + specs_names.append((kspec, fname, lname, kname)) + + api_code = get_api_code(specs_names) + api_path = gen_directory / "fmha_v2_api.h" + write_if_different(api_path, api_code) + return source_paths diff --git a/flashinfer/jit/attention/fmha_v2/generator_utils.py b/flashinfer/jit/attention/fmha_v2/generator_utils.py index d40307f393..8cede22036 100755 --- a/flashinfer/jit/attention/fmha_v2/generator_utils.py +++ b/flashinfer/jit/attention/fmha_v2/generator_utils.py @@ -107,79 +107,84 @@ class InputLayout(IntEnum): SEPARATE_Q_K_V = 3 -spec_fields = ( - "sm", - "dtype", - "seq_len", - "head_size", - "warps_m", - "warps_n", - "version", - "interleaved", - "ldgsts_q", - "ldgsts_k", - "ldgsts_v", - "share_smem_k_v", - "loop_step", - "has_noloop", - "noloop_step", - "unroll_threshold", - "has_scale_max", - "ctas_per_head", - "sm_mma", - "head_interleaved", - # new added fields (only used by flash attention implementation) - "flash_attention", - "kv_loop_step", - "flash_attention_bh_upper_threshold", # to deprecate; not actively used - "limit_qk_fragments", - "limit_v_fragments", - "tiled", - # fields for warp specialized kernel - "warp_specialization", - "q_tile_buffers", - "kv_tile_buffers", - "scheduling_mode", - # attention qkv input layout. - "input_layout", - # fused MHCA. 
- "cross_mha", - # other features - "alibi", - "enable_attn_logit_softcapping", - "return_softmax_stats", - "disabled_mask_types", - "head_size_v", - "sage_block_sizes", - "output_dtype", - "is_mtp", +kernel_spec = namedtuple( + "kernel_spec", + [ + "sm", + "dtype", + "seq_len", + "head_size", + "warps_m", + "warps_n", + "version", + "interleaved", + "ldgsts_q", + "ldgsts_k", + "ldgsts_v", + "share_smem_k_v", + "loop_step", + "has_noloop", + "noloop_step", + "unroll_threshold", + "has_scale_max", + "ctas_per_head", + "sm_mma", + "head_interleaved", + # new added fields (only used by flash attention implementation) + "flash_attention", + "kv_loop_step", + "flash_attention_bh_upper_threshold", # to deprecate; not actively used + "limit_qk_fragments", + "limit_v_fragments", + "tiled", + # fields for warp specialized kernel + "warp_specialization", + "q_tile_buffers", + "kv_tile_buffers", + "scheduling_mode", + # attention qkv input layout. + "input_layout", + # fused MHCA. + "cross_mha", + # other features + "alibi", + "enable_attn_logit_softcapping", + "return_softmax_stats", + "disabled_mask_types", + "head_size_v", + "sage_block_sizes", + "output_dtype", + "is_mtp", + "enable_skip_softmax", + ], + defaults=( + 1, # ctas_per_head + 1, # sm_mma + True, # head_interleaved + False, # flash_attention + 64, # kv_loop_step + -1, # flash_attention_bh_upper_threshold + False, # limit_qk_fragments + False, # limit_v_fragments + 0, # tiled + False, # warp_specialization + 1, # q_tile_buffers + 1, # kv_tile_buffers + 0, # scheduling_mode + InputLayout.PACKED_QKV, # input_layout + 0, # cross_mha + True, # alibi + False, # enable_attn_logit_softcapping + False, # return_softmax_stats + None, # disabled_mask_types + 0, # head_size_v + None, # sage_block_sizes + None, # output_dtype, same as dtype by default. 
+ False, # is_mtp + False, # enable_skip_softmax + ), ) -kernel_spec = namedtuple("kernel_spec", spec_fields) # type: ignore[misc] -kernel_spec.__new__.__defaults__ = ( - 1, # ctas_per_head - 1, # sm_mma - True, # head_interleaved - False, # flash_attention - 64, # kv_loop_step - -1, # flash_attention_bh_upper_threshold - False, # limit_qk_fragments - False, # limit_v_fragments - 0, # tiled - False, # warp_specialization - 1, # q_tile_buffers - 1, # kv_tile_buffers - 0, # scheduling_mode - InputLayout.PACKED_QKV, - 0, # cross_mha - True, # alibi - False, # enable_attn_logit_softcapping - False, # return_softmax_stats - None, # disabled_mask_types - 0, # head size of V - None, # sage_block_sizes - None, # output_dtype, same as dtype by default. - False, -) # use MTP or not +spec_fields = kernel_spec._fields generate_cu_trtllm = os.environ.get("GENERATE_CU_TRTLLM", "False").lower() == "true" @@ -274,7 +279,7 @@ class InputLayout(IntEnum): """ -def get_makefile_code(specs_names): +def get_makefile_code(specs_names: list[tuple[kernel_spec, str, str, str]]) -> str: objects = "\n".join( [ "OBJECTS_MHA += obj/{}.o".format(fname) @@ -1470,6 +1475,7 @@ def get_makefile_code(specs_names): USE_TMA_STORE, {enable_attn_logit_softcapping_flag}, {return_softmax_stats_flag}, + {enable_skip_softmax_flag}, {output_dtype_}, {sage_block_size_q}, {sage_block_size_k}, @@ -1493,6 +1499,7 @@ def get_makefile_code(specs_names): USE_TMA_STORE, {enable_attn_logit_softcapping_flag}, {return_softmax_stats_flag}, + {enable_skip_softmax_flag}, {output_dtype_}>; using Ktraits_sliding_or_chunked_causal = {kernel_traits_header} @@ -1513,6 +1520,7 @@ def get_makefile_code(specs_names): USE_TMA_STORE && false, {enable_attn_logit_softcapping_flag}, {return_softmax_stats_flag}, + {enable_skip_softmax_flag}, {output_dtype_}>; using Ktraits_custom_mask = {kernel_traits_header} @@ -1533,6 +1541,7 @@ def get_makefile_code(specs_names): USE_TMA_STORE && false, {enable_attn_logit_softcapping_flag}, 
{return_softmax_stats_flag}, + {enable_skip_softmax_flag}, {output_dtype_}>; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1829,7 +1838,7 @@ def get_makefile_code(specs_names): """ -def encode_name(kernel_spec): +def encode_name(kernel_spec: kernel_spec) -> tuple[str, str, str]: effective_sm, sm_name = get_effective_sm_and_name(kernel_spec) # Is it a kernel for the interleaved NC/32HW32 INT8 layout? il_tag = "_il" if kernel_spec.interleaved else "" @@ -1870,6 +1879,8 @@ def encode_name(kernel_spec): if kernel_spec.enable_attn_logit_softcapping: feature_tags += "_softcapping" + if kernel_spec.enable_skip_softmax: + feature_tags += "_skipSoftmax" if kernel_spec.sage_block_sizes: feature_tags += f"_sage_{'_'.join(map(str, kernel_spec.sage_block_sizes))}" if kernel_spec.output_dtype: @@ -1915,7 +1926,9 @@ def encode_name(kernel_spec): return fname, lname, kname -def get_GMMA_shape(instruction_traits, m, n, k, warps_n): +def get_GMMA_shape( + instruction_traits: str, m: int, n: int, k: int, warps_n: int +) -> tuple[int, int, int]: gmma_k = hopper_traits2shape[instruction_traits][-1] # gmma shape is 64xgmma_nx16, gmma_n should be as big as possible, but not bigger than n @@ -1936,13 +1949,13 @@ def get_GMMA_shape(instruction_traits, m, n, k, warps_n): return gmma_m, gmma_n, gmma_k -def enable_mutex(kspec): +def enable_mutex(kspec: kernel_spec) -> str: fp32_accu_dtype = kspec.dtype in ["fp16_fp32", "bf16"] enable_mutex = "false" if (fp32_accu_dtype or kspec.head_size <= 64) else "true" return enable_mutex -def enable_tma_store(kspec): +def enable_tma_store(kspec: kernel_spec) -> str: output_dtype = kspec.output_dtype if kspec.output_dtype is not None else kspec.dtype # TMA copies data in the 16B granularity. 
return ( @@ -1952,7 +1965,7 @@ def enable_tma_store(kspec): ) -def get_reg_count(kspec): +def get_reg_count(kspec: kernel_spec) -> tuple[int, int]: # if kspec.paged_kv_input and kspec.dtype in ['fp16', 'fp16_fp32', 'bf16']: # dma_reg_count = 72 # compute_reg_count = 216 @@ -1965,7 +1978,9 @@ def get_reg_count(kspec): return dma_reg_count, compute_reg_count -def get_hopper_instruction_traits(instruction_traits, kernel_spec): +def get_hopper_instruction_traits( + instruction_traits: str, kernel_spec: kernel_spec +) -> tuple[str, str]: gmma_shape_p = get_GMMA_shape( instruction_traits, kernel_spec.loop_step, @@ -1988,7 +2003,7 @@ def get_hopper_instruction_traits(instruction_traits, kernel_spec): return instruction_traits_p, instruction_traits_o -def get_effective_sm_and_name(kspec): +def get_effective_sm_and_name(kspec: kernel_spec) -> tuple[int, str]: sm = kspec.sm # Override the mma instruction with an older one. if kspec.sm_mma in sm2name: @@ -2000,7 +2015,7 @@ def get_effective_sm_and_name(kspec): return sm, sm_name -def selected_mask_types(kspec): +def selected_mask_types(kspec: kernel_spec) -> tuple[str, str, str, str]: # by default, we generate all combinations. # '1' means true, '0' means false. padding_mask = "1" @@ -2055,7 +2070,7 @@ def selected_mask_types(kspec): return padding_mask, causal_mask, sliding_or_chunked_causal_mask, custom_mask -def get_kernel_code(kspec, kname, lname): +def get_kernel_code(kspec: kernel_spec, kname: str, lname: str) -> str | None: min_cuda_version = 0 # no restriction # The architecture that determines the instruction. @@ -2195,6 +2210,8 @@ def get_kernel_code(kspec, kname, lname): return_softmax_stats_flag = pythonBoolean2cpp[kspec.return_softmax_stats] + enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax] + # needed by warpspec kernels. 
fp8_kernel = kspec.dtype in ["e4m3", "e4m3_fp32"] kernel_traits_header = ( @@ -2331,8 +2348,8 @@ def get_kernel_code(kspec, kname, lname): return code -def get_api_code(specs_names): - def get_signature(lname, version, cross_mha, use_tma): +def get_api_code(specs_names: list[tuple[kernel_spec, str, str, str]]) -> str: + def get_signature(lname: str, version: int, cross_mha: int, use_tma: bool) -> str: # The architecture that determines the instruction. effective_sm, sm_name = get_effective_sm_and_name(kspec) if cross_mha: @@ -2366,7 +2383,7 @@ def get_signature(lname, version, cross_mha, use_tma): ) if not kspec.warp_specialization: signatures.append("void {}_get_max_heads_per_wave(int*);".format(lname)) - signatures = "\n".join(signatures) + signatures = "\n".join(signatures) # type: ignore[assignment] # v1 # - normal @@ -2380,7 +2397,7 @@ def get_signature(lname, version, cross_mha, use_tma): # - flash attention no loop tiled # - flash attention warp_specialized (on Hopper) - def gen_unroll_check(kspec): + def gen_unroll_check(kspec: kernel_spec) -> str: code = "if (!{has_noloop} || (!force_unroll && (ignore_b1opt || b > {unroll_threshold})))".format( **kspec._asdict() ) @@ -2390,7 +2407,7 @@ def gen_unroll_check(kspec): ) return code - def gen_call(kspec, lname): + def gen_call(kspec: kernel_spec, lname: str) -> str: effective_sm, _ = get_effective_sm_and_name(kspec) data_type = dtype2typename[kspec.dtype] output_data_type = data_type @@ -2454,6 +2471,12 @@ def gen_call(kspec, lname): f"&& sage_block_size_v == {sage_block_size_v} " ) + il_check += ( + "&& enable_skip_softmax " + if kspec.enable_skip_softmax + else "&& !enable_skip_softmax " + ) + il_check += ( "&& params.use_int8_scale_max " if kspec.has_scale_max @@ -2553,7 +2576,7 @@ def gen_call(kspec, lname): ) return call_stmt - def gen_call_fmhca(kspec, lname): + def gen_call_fmhca(kspec: kernel_spec, lname: str) -> str: effective_sm, _ = get_effective_sm_and_name(kspec) data_type = 
dtype2typename[kspec.dtype] il_check = "" @@ -2606,7 +2629,7 @@ def gen_call_fmhca(kspec, lname): if kspec.version == 2 and kspec.cross_mha == 0 ] - calls_v2 = "else ".join(calls_v2) if len(calls_v2) > 0 else "if( false ) {}" + calls_v2 = "else ".join(calls_v2) if len(calls_v2) > 0 else "if( false ) {}" # type: ignore[assignment] calls_v1 = [ gen_call(kspec, lname) @@ -2614,7 +2637,7 @@ def gen_call_fmhca(kspec, lname): if kspec.version == 1 and kspec.cross_mha == 0 ] - calls_v1 = "else ".join(calls_v1) if len(calls_v1) > 0 else "if( false ) {}" + calls_v1 = "else ".join(calls_v1) if len(calls_v1) > 0 else "if( false ) {}" # type: ignore[assignment] calls_mhca = [ gen_call_fmhca(kspec, lname) @@ -2622,9 +2645,9 @@ def gen_call_fmhca(kspec, lname): if kspec.cross_mha == 1 ] - calls_mhca = "else ".join(calls_mhca) if len(calls_mhca) > 0 else "if( false ) {}" + calls_mhca = "else ".join(calls_mhca) if len(calls_mhca) > 0 else "if( false ) {}" # type: ignore[assignment] - def gen_warp_spec(kspec): + def gen_warp_spec(kspec: kernel_spec) -> str: data_type = dtype2typename[kspec.dtype] if kspec.sage_block_sizes is not None: assert kspec.output_dtype is not None @@ -2680,7 +2703,7 @@ def gen_warp_spec(kspec): warp_specs += 'else {\n\tassert(false && "Unsupported config");\n}' # Generate the cta spec. 
- def gen_cta_spec(spec): + def gen_cta_spec(spec: tuple[kernel_spec, str, str, str]) -> str: kspec, _, lname, _ = spec slen = kspec.seq_len * kspec.ctas_per_head return """\ @@ -2759,6 +2782,7 @@ def gen_cta_spec(spec): const bool use_tma = launch_params.use_tma; const bool use_flash_attention = launch_params.flash_attention; const bool enable_attn_logit_softcapping = launch_params.enable_attn_logit_softcapping; +const bool enable_skip_softmax = launch_params.enable_skip_softmax; const int attention_input_layout = static_cast(launch_params.attention_input_layout); // tiled variant uses ldgsts const bool use_tiled = launch_params.use_granular_tiling; @@ -2862,7 +2886,7 @@ def gen_cta_spec(spec): """ -def get_kernel_traits_code(specs_names): +def get_kernel_traits_code(specs_names: list[tuple[kernel_spec, str, str, str]]) -> str: print_kernel_specs = [] for kspec, fname, lname, kname in specs_names: # noqa: B007 (fname, lname used via locals()) @@ -2941,6 +2965,8 @@ def get_kernel_traits_code(specs_names): kspec.enable_attn_logit_softcapping ] + enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax] + tmp = dict(locals(), **kspec._asdict()) if effective_sm < 90: @@ -3068,7 +3094,8 @@ def get_kernel_traits_code(specs_names): {input_layout_flag}, __use_tma_store__ /* USE_TMA_STORE */, {enable_attn_logit_softcapping_flag}, - {return_softmax_stats_flag}>; + {return_softmax_stats_flag}, + {enable_skip_softmax_flag}>; printf("%s %d %d %s %d %d\\n", \"{kname}\", @@ -3227,7 +3254,7 @@ def get_kernel_traits_code(specs_names): print_kernel_specs.append(snippet_ws_custom_mask) # remove none. 
print_kernel_specs = [spec for spec in print_kernel_specs if spec is not None] - print_kernel_specs = "\n".join(print_kernel_specs) + print_kernel_specs = "\n".join(print_kernel_specs) # type: ignore[assignment] code = ktraits_code_template.format(print_kernel_specs=print_kernel_specs) return code @@ -3236,20 +3263,37 @@ def get_kernel_traits_code(specs_names): # For now: # 1. Hopper head_size 128 kernel uses cubins for performance regressions. # 2. Hopper sm89 with e4m3/e4m3_fp32 dtype uses cubins for accuracy regressions (will be fixed). +# 3. For skip-softmax attention feature, we force not to use cubins. # You should set the condition `use_cubin_header` to false if you have modified the source codes of those kernels that use cubins. # This ensures that the kernels will be recompiled using the updated source code rather than relying on precompiled cubins. -def use_cubin_header(sm, head_size, dtype): +def use_cubin_header( + sm: int, + head_size: int, + dtype: str, + output_dtype: str | None = None, + enable_skip_softmax: bool = False, +) -> bool: + if enable_skip_softmax: + return False + if "e4m3" in dtype and output_dtype in ["bf16", "fp16"]: + return False return (sm == 90 and head_size == 128) or (sm == 89 and "e4m3" in dtype) -def get_cubin_header(kernel_traits, specs_names): +def get_cubin_header( + kernel_traits: list[list[str]], specs_names: list[tuple[kernel_spec, str, str, str]] +) -> str: cubins = [] cubin_lens = [] - cubins_dict = {} - cubin_lens_dict = {} + cubins_dict: dict[int, list[str]] = {} + cubin_lens_dict: dict[int, list[str]] = {} for kspec, fname, lname, kname in specs_names: # noqa: B007 (lname, kname used via locals()) if generate_cu_trtllm and not use_cubin_header( - kspec.sm, kspec.head_size, kspec.dtype + kspec.sm, + kspec.head_size, + kspec.dtype, + kspec.output_dtype, + kspec.enable_skip_softmax, ): continue name = fname.replace(".", "_") @@ -3265,7 +3309,7 @@ def get_cubin_header(kernel_traits, specs_names): metadata_v1 = [] # Only 
metadata_v2 is used by TRT-LLM. metadata_v2 = [] - metadata_v2_dict = {} + metadata_v2_dict: dict[str, list[str]] = {} unroll_config_v1 = [] unroll_config_v2 = [] for kname, smem, threads, fname, unroll_step, unroll_threshold in kernel_traits: # noqa: B007 (smem, threads, unroll_threshold used via locals()) @@ -3291,6 +3335,7 @@ def get_cubin_header(kernel_traits, specs_names): .replace("ws_", "") .replace("softcapping_", "") .replace("sage_", "") + .replace("skipSoftmax_", "") .replace("output_", "") ) flash_attention = "flash_attention" in kname @@ -3317,7 +3362,7 @@ def get_cubin_header(kernel_traits, specs_names): toks.pop(-4) toks.pop(-3) else: - sage_block_sizes = (0, 0, 0) + sage_block_sizes = ["0", "0", "0"] head_size = toks[-3] if "x" in head_size: (head_size, head_size_v) = head_size.split("x") @@ -3389,6 +3434,9 @@ def get_cubin_header(kernel_traits, specs_names): sm != "90" or (sm == "90" and "_softmax" in kname) ] + enable_skip_softmax = "_skipSoftmax" in kname + enable_skip_softmax_flag = pythonBoolean2cpp[enable_skip_softmax] + # meta_unroll_step meta_unroll_step = unroll_step if ("_nl" in kname or "_ws" in kname) else "0" @@ -3413,7 +3461,13 @@ def get_cubin_header(kernel_traits, specs_names): if generate_cu_trtllm: def get_lname_from_kname(kname: str) -> str: - if use_cubin_header(int(sm), int(head_size), prec.lower()): + if use_cubin_header( + int(sm), + int(head_size), + prec.lower(), + output_prec.lower(), + enable_skip_softmax, + ): return "nullptr" lname = kname.replace("_kernel", "") mask_types = [ @@ -3434,15 +3488,21 @@ def get_lname_from_kname(kname: str) -> str: {sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \ {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \ {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \ -{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, 
{return_softmax_stats_flag}, {lname}}}\ +{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\ """.format(**locals()) - if use_cubin_header(int(sm), int(head_size), prec.lower()) + if use_cubin_header( + int(sm), + int(head_size), + prec.lower(), + output_prec.lower(), + enable_skip_softmax, + ) else """\ {{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \ {sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \ 0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \ {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \ -{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\ +{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\ """.format(**locals()) ) else: @@ -3451,7 +3511,7 @@ def get_lname_from_kname(kname: str) -> str: {sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \ {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \ {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \ -{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}}}\ +{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}}}\ """.format(**locals()) if sm in metadata_v2_dict: metadata_v2_dict[sm].append(code) @@ -3476,10 +3536,10 @@ def get_lname_from_kname(kname: str) -> str: else: raise AssertionError("Something terrible happened") - metadata_v1 = ",\n".join(metadata_v1) + metadata_v1 = ",\n".join(metadata_v1) # type: ignore[assignment] # Add macros to only include needed cubins during compilation. 
if bool(metadata_v2_dict): - metadata_v2 = "" + metadata_v2 = "" # type: ignore[assignment] for sm in metadata_v2_dict.keys(): macro_begin = f"#ifndef EXCLUDE_SM_{sm}" macro_end = "#endif\n\n" @@ -3487,19 +3547,19 @@ def get_lname_from_kname(kname: str) -> str: last_key = list(metadata_v2_dict.keys())[-1] metadata_v2 += ("" if sm == last_key else ",") + "\n" + macro_end else: - metadata_v2 = ",\n".join(metadata_v2) + metadata_v2 = ",\n".join(metadata_v2) # type: ignore[assignment] # Add macros to only include needed cubins during compilation. - for sm in cubins_dict.keys(): - macro_begin = f"#ifndef EXCLUDE_SM_{sm}" + for sm_key in cubins_dict.keys(): + macro_begin = f"#ifndef EXCLUDE_SM_{sm_key}" macro_end = "#endif\n" - cubins.extend([macro_begin] + cubins_dict[sm] + [macro_end]) - if sm in cubin_lens_dict: - cubin_lens.extend([macro_begin] + cubin_lens_dict[sm] + [macro_end]) - - unroll_config_v1 = ",\n".join(unroll_config_v1) - unroll_config_v2 = ",\n".join(unroll_config_v2) - cubins = "\n".join(cubins) - cubin_lens = "\n".join(cubin_lens) + cubins.extend([macro_begin] + cubins_dict[sm_key] + [macro_end]) + if sm_key in cubin_lens_dict: + cubin_lens.extend([macro_begin] + cubin_lens_dict[sm_key] + [macro_end]) + + unroll_config_v1 = ",\n".join(unroll_config_v1) # type: ignore[assignment] + unroll_config_v2 = ",\n".join(unroll_config_v2) # type: ignore[assignment] + cubins = "\n".join(cubins) # type: ignore[assignment] + cubin_lens = "\n".join(cubin_lens) # type: ignore[assignment] local_ns_open = ns_open local_ns_close = ns_close if generate_cu_trtllm else "}" launcher_line = ( @@ -3547,7 +3607,8 @@ def get_lname_from_kname(kname: str) -> str: bool mAlibiSupported; bool mTiled; bool mEnableAttnLogitSoftcapping; - bool mReturnSoftmaxStats;{launcher_line} + bool mReturnSoftmaxStats; + bool mEnableSkipSoftmax;{launcher_line} }} sMhaKernelMetaInfosV2[] = {{ {metadata_v2} }}; @@ -3608,6 +3669,7 @@ def get_lname_from_kname(kname: str) -> str: bool mTiled; bool 
mEnableAttnLogitSoftcapping; bool mReturnSoftmaxStats; + bool mEnableSkipSoftmax; }} metaV2[] = {{ {metadata_v2} }}; @@ -3619,11 +3681,11 @@ def get_lname_from_kname(kname: str) -> str: # This is used to add some kernels running in cubins for passing CI cases. -def modify_cubin_header(cubin_header): +def modify_cubin_header(cubin_header: str) -> str: result = cubin_header # for CI cases - def add_kernel_line(result, target, addition): + def add_kernel_line(result: str, target: str, addition: str) -> str: pos = result.find(target) if pos != -1: end_pos = result.find("\n", pos) @@ -3637,7 +3699,7 @@ def add_kernel_line(result, target, addition): extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len;""" result = add_kernel_line(result, target, addition) - def modify_kernel_line(result, target, new_line): + def modify_kernel_line(result: str, target: str, new_line: str) -> str: lines = result.split("\n") for i, line in enumerate(lines): if target in line: @@ -3646,7 +3708,7 @@ def modify_kernel_line(result, target, new_line): return "\n".join(lines) target = "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled" - new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, nullptr},' + new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, 
false, nullptr},' result = modify_kernel_line(result, target, new_line) # make sure only one empty line at the end @@ -3658,7 +3720,7 @@ def modify_kernel_line(result, target, new_line): return "\n".join(lines) -def generate_files(specs_names): +def generate_files(specs_names: list[tuple[kernel_spec, str, str, str]]) -> None: kfiles = [] valid_specs_names = [] @@ -3722,9 +3784,9 @@ def generate_files(specs_names): "bin/print_traits.exe", stdin=subprocess.PIPE, stdout=subprocess.PIPE ) output, error = process.communicate() - output = output.decode("utf-8").strip() + output_str = output.decode("utf-8").strip() # this gives: kname, smem bytes, threads_per_cta, loop_step - kernel_traits = [traits.split() for traits in output.splitlines()] + kernel_traits = [traits.split() for traits in output_str.splitlines()] cubin_header = get_cubin_header(kernel_traits, valid_specs_names) if generate_cu_trtllm: cubin_header = modify_cubin_header(cubin_header) @@ -3733,7 +3795,7 @@ def generate_files(specs_names): f.write(cubin_header) -def enumerate_hgmma_tma_kernels(specs, sm=90): +def enumerate_hgmma_tma_kernels(specs: list[kernel_spec], sm: int = 90) -> None: specs.append( kernel_spec( sm=sm, @@ -3759,7 +3821,9 @@ def enumerate_hgmma_tma_kernels(specs, sm=90): # Note this will be used in TRT-LLM. -def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype="fp16"): +def enumerate_hgmma_ldgsts_kernels( + specs: list[kernel_spec], sm: int = 90, dtype: str = "fp16" +) -> None: for enable_attn_logit_softcapping in [False, True]: specs.append( kernel_spec( @@ -3811,12 +3875,17 @@ def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype="fp16"): # Note this will be used in TRT-LLM. 
-def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype="fp16"): +def enumerate_hgmma_flash_warpspec_kernels( + specs: list[kernel_spec], + sm: int = 90, + dtype: str = "fp16", + enable_skip_softmax: bool = False, +) -> None: scheduling_mode = int(os.getenv("SCHEDULING_MODE", "1")) # use specialized kernels for cases without alibi scales. # there is a numeric issues when applying the exp2f scale optimization and alibi scale at the same time. - combinations = product( + combinations = product[tuple[bool, bool, InputLayout, bool]]( [False, True], [False, True], [ @@ -3876,6 +3945,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype="fp16"): return_softmax_stats=return_softmax, scheduling_mode=scheduling_mode, input_layout=input_layout, + enable_skip_softmax=enable_skip_softmax, ) ) @@ -3909,6 +3979,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype="fp16"): return_softmax_stats=return_softmax, scheduling_mode=scheduling_mode, input_layout=input_layout, + enable_skip_softmax=enable_skip_softmax, ) ) @@ -3942,6 +4013,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype="fp16"): return_softmax_stats=return_softmax, scheduling_mode=scheduling_mode, input_layout=input_layout, + enable_skip_softmax=enable_skip_softmax, ) ) """ @@ -3993,8 +4065,13 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype="fp16"): # Note this will be used in TRT-LLM. def enumerate_qgmma_flash_warpspec_kernels( - specs, sm=90, dtype="e4m3", sage_block_sizes=None, output_dtype=None -): + specs: list[kernel_spec], + sm: int = 90, + dtype: str = "e4m3", + sage_block_sizes: tuple[int, int, int] | None = None, + output_dtype: str | None = None, + enable_skip_softmax: bool = False, +) -> None: scheduling_mode = int(os.getenv("SCHEDULING_MODE", "1")) # use specialized kernels for cases without alibi scales. 
@@ -4060,6 +4137,7 @@ def enumerate_qgmma_flash_warpspec_kernels( input_layout=input_layout, sage_block_sizes=sage_block_sizes, output_dtype=output_dtype, + enable_skip_softmax=enable_skip_softmax, ) ) @@ -4096,6 +4174,7 @@ def enumerate_qgmma_flash_warpspec_kernels( input_layout=input_layout, sage_block_sizes=sage_block_sizes, output_dtype=output_dtype, + enable_skip_softmax=enable_skip_softmax, ) ) @@ -4132,6 +4211,7 @@ def enumerate_qgmma_flash_warpspec_kernels( input_layout=input_layout, sage_block_sizes=sage_block_sizes, output_dtype=output_dtype, + enable_skip_softmax=enable_skip_softmax, ) ) @@ -4174,7 +4254,7 @@ def enumerate_qgmma_flash_warpspec_kernels( ) -def enumerate_igmma_kernels(specs, sm=90): +def enumerate_igmma_kernels(specs: list[kernel_spec], sm: int = 90) -> None: specs.append( kernel_spec( sm=sm, @@ -4222,7 +4302,9 @@ def enumerate_igmma_kernels(specs, sm=90): ) -def enumerate_hmma_kernels(specs, sm=80, dtype="fp16"): +def enumerate_hmma_kernels( + specs: list[kernel_spec], sm: int = 80, dtype: str = "fp16" +) -> None: # The following kernels are hmma-based kernels tuned for sm90 if sm == 90: specs.append( @@ -4829,7 +4911,7 @@ def enumerate_hmma_kernels(specs, sm=80, dtype="fp16"): # - S=64 -def enumerate_hmma884_kernels(specs, sm=70): +def enumerate_hmma884_kernels(specs: list[kernel_spec], sm: int = 70) -> None: # - FP16 # - S=512: STEP=32, STEP NL=-- FLAGS=0x9 (0x9 for SM86!) 
specs.append( @@ -5044,14 +5126,18 @@ def enumerate_hmma884_kernels(specs, sm=70): ) -def enumerate_hmma_paged_kv_flash_kernels(specs, sm=80, dtype="fp16"): +def enumerate_hmma_paged_kv_flash_kernels( + specs: list[kernel_spec], sm: int = 80, dtype: str = "fp16" +) -> None: for enable_attn_logit_softcapping in [False, True]: enumerate_hmma_flash_kernels_base( specs, sm, dtype, InputLayout.PACKED_QKV, enable_attn_logit_softcapping ) -def enumerate_hmma_flash_kernels(specs, sm=80, dtype="fp16", head_size_v=0): +def enumerate_hmma_flash_kernels( + specs: list[kernel_spec], sm: int = 80, dtype: str = "fp16", head_size_v: int = 0 +) -> None: input_layouts = [ InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV, @@ -5070,13 +5156,13 @@ def enumerate_hmma_flash_kernels(specs, sm=80, dtype="fp16", head_size_v=0): # Note this will be used in TRT-LLM. def enumerate_hmma_flash_kernels_base( - specs, - sm=80, - dtype="fp16", - input_layout=InputLayout.PACKED_QKV, - enable_attn_logit_softcapping=False, - head_size_v=0, -): + specs: list[kernel_spec], + sm: int = 80, + dtype: str = "fp16", + input_layout: InputLayout = InputLayout.PACKED_QKV, + enable_attn_logit_softcapping: bool = False, + head_size_v: int = 0, +) -> None: # - FP16 Flash Attention (use nl as default) # Any Sequence Length H = 16/32/40/48/64/80/128/160/256/512 flash attention @@ -5250,7 +5336,7 @@ def enumerate_hmma_flash_kernels_base( ) -def enumerate_qgmma_kernels(specs, sm=90): +def enumerate_qgmma_kernels(specs: list[kernel_spec], sm: int = 90) -> None: specs.append( kernel_spec( sm=sm, @@ -5298,7 +5384,7 @@ def enumerate_qgmma_kernels(specs, sm=90): ) -def enumerate_qmma_kernels(specs, sm=89): +def enumerate_qmma_kernels(specs: list[kernel_spec], sm: int = 89) -> None: # SM89 (Ada) fp8 # Head Size 64 @@ -5428,13 +5514,13 @@ def enumerate_qmma_kernels(specs, sm=89): def enumerate_qmma_flash_kernels( - specs, - sm=89, - dtype="e4m3_fp32", - head_sizes=None, - sage_block_sizes=None, - output_dtype=None, -): + 
specs: list[kernel_spec], + sm: int = 89, + dtype: str = "e4m3_fp32", + head_sizes: list[int] | None = None, + sage_block_sizes: tuple[int, int, int] | None = None, + output_dtype: str | None = None, +) -> None: # ((head_size, head_size_v), (q_loop_step, kv_loop_step), tiled). params_q_kv_step = [ (32, (128, 128), 0), @@ -5510,7 +5596,7 @@ def enumerate_qmma_flash_kernels( ) -def enumerate_imma_kernels(specs, sm=80): +def enumerate_imma_kernels(specs: list[kernel_spec], sm: int = 80) -> None: if sm == 90: # The following kernels are imma-based kernels tuned for sm90 specs.append( @@ -6296,7 +6382,7 @@ def enumerate_imma_kernels(specs, sm=80): # - S=64 -def enumerate_cross_mha_kernels(specs): +def enumerate_cross_mha_kernels(specs: list[kernel_spec]) -> None: # TODO: combine cross_mha and mha kernel enumeration # - S_Q=4096, S_KV=128: STEP=64, STEP NL=64 # HEAD_SIZE: 64 @@ -6659,11 +6745,11 @@ def enumerate_cross_mha_kernels(specs): ) -def enumerate_kernels(): +def enumerate_kernels() -> None: if not os.path.exists("./generated"): os.mkdir("./generated") - specs = [] + specs: list[kernel_spec] = [] # TODO we have to select the unroll_threshold over a grid of b and h for each arch @@ -6678,12 +6764,25 @@ def enumerate_kernels(): enumerate_igmma_kernels(specs, sm=90) enumerate_qgmma_kernels(specs, sm=90) # need to add bf16 kernels if needed - enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype="fp16") - enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype="bf16") - enumerate_qgmma_flash_warpspec_kernels(specs, sm=90, dtype="e4m3") - enumerate_qgmma_flash_warpspec_kernels( - specs, sm=90, dtype="e4m3", output_dtype="bf16" - ) + for enable_skip_softmax in [False, True]: + if enable_skip_softmax and "DISABLE_SKIP_SOFTMAX" in os.environ: + continue + enumerate_hgmma_flash_warpspec_kernels( + specs, sm=90, dtype="fp16", enable_skip_softmax=enable_skip_softmax + ) + enumerate_hgmma_flash_warpspec_kernels( + specs, sm=90, dtype="bf16", 
enable_skip_softmax=enable_skip_softmax + ) + enumerate_qgmma_flash_warpspec_kernels( + specs, sm=90, dtype="e4m3", enable_skip_softmax=enable_skip_softmax + ) + enumerate_qgmma_flash_warpspec_kernels( + specs, + sm=90, + dtype="e4m3", + output_dtype="bf16", + enable_skip_softmax=enable_skip_softmax, + ) # For now SageAttention only needs BF16 # block_size_q should be divisible by 64 diff --git a/flashinfer/jit/attention/fmha_v2/utils.py b/flashinfer/jit/attention/fmha_v2/utils.py new file mode 100644 index 0000000000..2d8e0b6bb6 --- /dev/null +++ b/flashinfer/jit/attention/fmha_v2/utils.py @@ -0,0 +1,1013 @@ +import os +from collections import namedtuple +from enum import IntEnum +from dataclasses import asdict + +copyright = r"""/* +* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ +""" + + +sm2name = { + 70: "volta", + 72: "volta", + 75: "turing", + 80: "ampere", + 86: "ampere", + 87: "ampere", + 89: "ada", + 90: "hopper", + 120: "blackwell", +} + +dtype2traits = { + "int8": "imma_int8_int32_traits", + "fp16": "hmma_fp16_traits", + "fp16_fp32": "hmma_fp32_traits", + "bf16": "hmma_bf16_traits", + "e4m3": "qmma_e4m3_fp32_traits", + "e4m3_fp32": "qmma_e4m3_fp32_traits", + "e4m3_fp16": "qmma_e4m3_fp16_traits", +} + +dtype2OutputType = { + "int8": "int8_t", + "fp16": "fp16_t", + "fp16_fp32": "fp16_t", + "bf16": "bf16_t", + "e4m3": "e4m3_t", + "e4m3_fp32": "e4m3_t", + "e4m3_fp16": "e4m3_t", +} + +dtype2bytes = { + "int8": 1, + "fp16": 2, + "fp16_fp32": 2, + "bf16": 2, + "e4m3": 1, + "e4m3_fp32": 1, + "e4m3_fp16": 1, +} + +# TODO merge with above? +hopper_dtype2traits = { + "int8": "igmma_int8_int32_traits", + "fp16": "hgmma_fp16_traits", + "fp16_fp32": "hgmma_fp32_traits", + "bf16": "hgmma_bf16_traits", + "e4m3": "qgmma_e4m3_fp32_traits", + "e4m3_fp32": "qgmma_e4m3_fp32_traits", +} + +# The minimal instruction shapes per warp group. +# TODO should this not be known to the trait itself? +hopper_traits2shape = { + "Hopper_igmma_int8_int32_traits": (64, 8, 32), + "Hopper_hgmma_fp16_traits": (64, 8, 16), + "Hopper_hgmma_fp32_traits": (64, 8, 16), + "Hopper_hgmma_bf16_traits": (64, 8, 16), + "Hopper_qgmma_e4m3_fp32_traits": (64, 8, 32), +} + +dtype2typename = { + "int8": "DATA_TYPE_INT8", + "fp16": "DATA_TYPE_FP16", + "fp16_fp32": "DATA_TYPE_FP16", + "bf16": "DATA_TYPE_BF16", + "e4m3": "DATA_TYPE_E4M3", + "e4m3_fp16": "DATA_TYPE_E4M3", + "e4m3_fp32": "DATA_TYPE_E4M3", +} + +pythonBoolean2cpp = {True: "true", False: "false"} + + +# same definition as fused_multihead_attention.h. 
+class AttentionMaskType(IntEnum): + PADDING = 0 + CAUSAL = 1 + SLIDING_OR_CHUNKED_CAUSAL = 2 + CUSTOM_MASK = 3 + + +class InputLayout(IntEnum): + PACKED_QKV = 0 + CONTIGUOUS_Q_KV = 1 + Q_PAGED_KV = 2 + SEPARATE_Q_K_V = 3 + + +spec_fields = ( + "sm", + "dtype", + "seq_len", + "head_size", + "warps_m", + "warps_n", + "version", + "interleaved", + "ldgsts_q", + "ldgsts_k", + "ldgsts_v", + "share_smem_k_v", + "loop_step", + "has_noloop", + "noloop_step", + "unroll_threshold", + "has_scale_max", + "ctas_per_head", + "sm_mma", + "head_interleaved", + # new added fields (only used by flash attention implementation) + "flash_attention", + "kv_loop_step", + "flash_attention_bh_upper_threshold", # to deprecate; not actively used + "limit_qk_fragments", + "limit_v_fragments", + "tiled", + # fields for warp specialized kernel + "warp_specialization", + "q_tile_buffers", + "kv_tile_buffers", + "scheduling_mode", + # attention qkv input layout. + "input_layout", + # fused MHCA. + "cross_mha", + # other features + "alibi", + "enable_attn_logit_softcapping", + "return_softmax_stats", + "enable_skip_softmax", + "disabled_mask_types", + "head_size_v", + "sage_block_sizes", + "output_dtype", + "is_mtp", +) + +kernel_spec = namedtuple("kernel_spec", spec_fields) # type: ignore[misc] +kernel_spec.__new__.__defaults__ = ( + 1, # ctas_per_head + 1, # sm_mma + True, # head_interleaved + False, # flash_attention + 64, # kv_loop_step + -1, # flash_attention_bh_upper_threshold + False, # limit_qk_fragments + False, # limit_v_fragments + 0, # tiled + False, # warp_specialization + 1, # q_tile_buffers + 1, # kv_tile_buffers + 0, # scheduling_mode + InputLayout.PACKED_QKV, + 0, # cross_mha + True, # alibi + False, # enable_attn_logit_softcapping + False, # return_softmax_stats + False, # enable_skip_softmax + None, # disabled_mask_types + 0, # head size of V + None, # sage_block_sizes + None, # output_dtype, same as dtype by default. 
+ False, +) # use MTP or not + +generate_cu_trtllm = os.environ.get("GENERATE_CU_TRTLLM", "False").lower() == "true" + +ns_open = ( + r""" +namespace tensorrt_llm +{ +namespace kernels +{ +// clang-format off +""" + if generate_cu_trtllm + else "" +) + +ns_close = ( + r""" +// clang-format on +} // namespace kernels +} // namespace tensorrt_llm +""" + if generate_cu_trtllm + else "" +) + +copyright = ( + """\ +/*************************************************************************************************** + * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are not permit- + * ted. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +""" + if not generate_cu_trtllm + else r"""/* +* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & +* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +""" +) + + +MAX_STGS_PER_LOOP = 4 + + +def encode_name(kernel_spec): + effective_sm, sm_name = get_effective_sm_and_name(kernel_spec) + # Is it a kernel for the interleaved NC/32HW32 INT8 layout? + il_tag = "_il" if kernel_spec.interleaved else "" + # Is it using the quantization scaling factor as an approximation of the max in softmax? + scale_max_tag = "_scale_max" if kernel_spec.has_scale_max else "" + # Deal with multi-CTA kernels for which the sequence length is seq_len per CTA * # of CTAs. + seqlen = kernel_spec.seq_len * kernel_spec.ctas_per_head + # The qkv layout. + qkv_layout_tag = "" + if kernel_spec.input_layout == InputLayout.PACKED_QKV: + qkv_layout_tag = "_qkv" + elif kernel_spec.input_layout == InputLayout.Q_PAGED_KV: + qkv_layout_tag = "_q_paged_kv" + elif kernel_spec.input_layout == InputLayout.SEPARATE_Q_K_V: + qkv_layout_tag = "_q_k_v" + else: + qkv_layout_tag = "_q_kv" + # for SM90 kernels, let's also differentiate ldgsts and tma kernels + feature_tags = "" + if effective_sm == 90: + # let's think about where to insert tma/ldgsts in the string before MR. 
[Timmy] + if kernel_spec.ldgsts_q: + tma_or_ldgsts = "_ldgsts" + else: + tma_or_ldgsts = "_tma" + if kernel_spec.warp_specialization: + warp_specialization_tag = "_ws" + else: + warp_specialization_tag = "" + else: + tma_or_ldgsts = "" + warp_specialization_tag = "" + + # Add alibi and return_softmax_stats to feature_tags for all kernels + # to ensure unique filenames across all kernel configurations + if kernel_spec.alibi: + feature_tags += "_alibi" + if kernel_spec.return_softmax_stats: + feature_tags += "_softmax" + if kernel_spec.enable_attn_logit_softcapping: + feature_tags += "_softcapping" + if kernel_spec.enable_skip_softmax: + feature_tags += "_skipSoftmax" + if kernel_spec.sage_block_sizes: + feature_tags += f"_sage_{'_'.join(map(str, kernel_spec.sage_block_sizes))}" + if kernel_spec.output_dtype: + feature_tags += f"_output_{kernel_spec.output_dtype}" + if kernel_spec.is_mtp: + feature_tags += "_mtp" + if kernel_spec.ctas_per_head > 1: + fmt = ( + "fmha_v{version}{il_tag}_{dtype}_" + + str(seqlen) + + "_{head_size}{attrib}{scale_max_tag}{tma_or_ldgsts}_sm{sm}" + ) + elif kernel_spec.flash_attention: + fmt = "fmha_v{version}{il_tag}_flash_attention_{dtype}_{loop_step}_{kv_loop_step}_S{qkv_layout_tag}_{head_size}{head_size_v_str}{attrib}{feature_tags}{scale_max_tag}{tma_or_ldgsts}{warp_specialization_tag}_sm{sm}" + elif kernel_spec.cross_mha: + fmt = "fmha_mhca_{dtype}_{seq_len}_{head_size}{scale_max_tag}{tma_or_ldgsts}_sm{sm}" + else: + fmt = "fmha_v{version}{il_tag}_{dtype}_{seq_len}_{head_size}{attrib}{scale_max_tag}{tma_or_ldgsts}_sm{sm}" + head_size_v_str = ( + "" if kernel_spec.head_size_v == 0 else f"x{kernel_spec.head_size_v}" + ) + # Assemble the name of the kernel. 
+    name_base = fmt.format(
+        **kernel_spec._asdict(),
+        head_size_v_str=head_size_v_str,
+        il_tag=il_tag,
+        qkv_layout_tag=qkv_layout_tag,
+        scale_max_tag=scale_max_tag,
+        tma_or_ldgsts=tma_or_ldgsts,
+        warp_specialization_tag=warp_specialization_tag,
+        feature_tags=feature_tags,
+        attrib="__placeholder__",
+    )
+
+    # Produce file, launch function and kernel names.
+    fname = name_base.replace("__placeholder__", "")
+    if seqlen >= 1024 and not kernel_spec.flash_attention:
+        fname += ".no_i2f_f2i"
+    fname += ".cu"
+    lname = ("run_" + name_base).replace("__placeholder__", "")
+    kname = name_base + "_kernel"
+
+    # remove causal
+    fname = fname.replace("causal_", "")
+    return fname, lname, kname
+
+
+def get_GMMA_shape(instruction_traits, m, n, k, warps_n):
+    gmma_k = hopper_traits2shape[instruction_traits][-1]
+
+    # gmma shape is 64xgmma_nx16, gmma_n should be as big as possible, but not bigger than n
+    # gmma_n should also be smaller than 256
+    gmma_m = 64
+    gmma_n = 0
+    # find the largest supported n
+    n_supported = [(i + 1) * 8 for i in range(32)][::-1]
+    n_target = n // warps_n
+    assert n_target * warps_n == n
+    assert n_supported[0] == 256 and n_supported[-1] == 8
+    for cand_n in n_supported:
+        if n_target % cand_n == 0:
+            gmma_n = cand_n
+            break
+    assert gmma_n > 0, "No supported GMMA_N found!"
+
+    return gmma_m, gmma_n, gmma_k
+
+
+def enable_mutex(kspec):
+    # Mutex is needed to synchronize HGMMA operations between warp groups when head_size > 64.
+    # fp32-accumulation dtypes (fp16_fp32, bf16) never enable the mutex.
+    # enable_mutex = "false" if kspec.head_size <= 64 else "true"
+    fp32_accu_dtype = kspec.dtype in ["fp16_fp32", "bf16"]
+    enable_mutex = "false" if (fp32_accu_dtype or kspec.head_size <= 64) else "true"
+    return enable_mutex
+
+
+def enable_tma_store(kspec):
+    output_dtype = kspec.output_dtype if kspec.output_dtype is not None else kspec.dtype
+    # TMA copies data in the 16B granularity.
+ return ( + "true" + if (output_dtype in ["e4m3", "e4m3_fp32"] and kspec.head_size % 16 == 0) + else "false" + ) + + +def get_reg_count(kspec): + # if kspec.paged_kv_input and kspec.dtype in ['fp16', 'fp16_fp32', 'bf16']: + # dma_reg_count = 72 + # compute_reg_count = 216 + if kspec.input_layout == InputLayout.Q_PAGED_KV: + dma_reg_count = 56 + compute_reg_count = 224 + else: + dma_reg_count = 40 + compute_reg_count = 232 + return dma_reg_count, compute_reg_count + + +def get_hopper_instruction_traits(instruction_traits, kernel_spec): + gmma_shape_p = get_GMMA_shape( + instruction_traits, + kernel_spec.loop_step, + kernel_spec.seq_len, + kernel_spec.head_size, + kernel_spec.warps_n, + ) + + instruction_traits_p = f"{instruction_traits}<{', '.join([str(x) for x in gmma_shape_p])}, false, false>" + + gmma_shape_o = get_GMMA_shape( + instruction_traits, + kernel_spec.loop_step, + kernel_spec.head_size, + kernel_spec.seq_len, + 1, + ) + instruction_traits_o = f"{instruction_traits}<{', '.join([str(x) for x in gmma_shape_o])}, true, false>" + + return instruction_traits_p, instruction_traits_o + + +def get_effective_sm_and_name(kspec): + sm = kspec.sm + # Override the mma instruction with an older one. + if kspec.sm_mma in sm2name: + assert kspec.sm_mma <= kspec.sm, ( + "Instruction version should be at most target arch" + ) + sm = kspec.sm_mma + sm_name = sm2name[sm] + return sm, sm_name + + +def selected_mask_types(kspec): + # by default, we generate all combinations. + # '1' means true, '0' means false. + padding_mask = "1" + causal_mask = "1" + sliding_or_chunked_causal_mask = "1" + custom_mask = "1" + # only generate certain needed combinations of input_layout and mask types for trt-llm. 
+ if "GENERATE_CUBIN" in os.environ: + if kspec.sage_block_sizes: + # SageAttention only needs padding mask now + causal_mask = "0" + sliding_or_chunked_causal_mask = "0" + custom_mask = "0" + elif (kspec.head_size, kspec.head_size_v) == (192, 128): + # MLA context phase only needs causal mask and padding mask (for chunked prefill) now + sliding_or_chunked_causal_mask = "0" + custom_mask = "0" + elif (kspec.head_size, kspec.head_size_v) == (576, 512): + # MLA generation phase only needs padding mask (MtpMask) now + causal_mask = "0" + sliding_or_chunked_causal_mask = "0" + custom_mask = "0" + # encoder models (head_size = 32 / 64 / 128) need packed_qkv input layout + padding mask. + elif kspec.input_layout == InputLayout.PACKED_QKV: + # NOTE: 72/80 are added for vision transformer + if kspec.head_size not in [32, 64, 72, 80, 128]: + padding_mask = "0" + # only cross attention (head_size = 32/64/128) needs contiguous_q_kv input layout + padding mask / custom_mask. + elif kspec.input_layout == InputLayout.CONTIGUOUS_Q_KV: + causal_mask = "0" + sliding_or_chunked_causal_mask = "0" + if kspec.head_size not in [32, 64, 72, 128]: + padding_mask = "0" + custom_mask = "0" + # paged kv cache is always needed in gpt variants. + # cross-attention also needs paged kv cache. + elif kspec.input_layout == InputLayout.Q_PAGED_KV: + if kspec.head_size not in [32, 64, 128]: + padding_mask = "0" + + # alibi specialized kernels only need causal mask. + if kspec.alibi and kspec.warp_specialization: + padding_mask = "0" + sliding_or_chunked_causal_mask = "0" + custom_mask = "0" + + # enable_attn_logit_softcapping kernels only need causal mask or sliding_or_chunked_causal_mask. + if kspec.enable_attn_logit_softcapping: + padding_mask = "0" + custom_mask = "0" + + return padding_mask, causal_mask, sliding_or_chunked_causal_mask, custom_mask + + +def get_api_code(specs_names): + def get_signature(lname, version, cross_mha, use_tma): + # The architecture that determines the instruction. 
+ effective_sm, sm_name = get_effective_sm_and_name(kspec) + if cross_mha: + return "void {}(const Params_mhca ¶ms, cudaStream_t stream);".format( + lname + ) + elif effective_sm >= 90: + # need to set tma desc in params + return "void {}(Params_v{} ¶ms, const Launch_params &launch_params, cudaStream_t stream);".format( + lname, version + ) + else: + return "void {}(const Params_v{} ¶ms, const Launch_params &launch_params, cudaStream_t stream);".format( + lname, version + ) + + signatures = [] + for kspec, _fname, lname, _kname in specs_names: + effective_sm, _ = get_effective_sm_and_name(kspec) + use_tma = effective_sm == 90 and not kspec.ldgsts_q + signatures.append(get_signature(lname, kspec.version, kspec.cross_mha, use_tma)) + if kspec.has_noloop and not kspec.tiled: + signatures.append( + get_signature(lname + "_nl", kspec.version, kspec.cross_mha, use_tma) + ) + elif kspec.tiled: + signatures.append( + get_signature( + lname + "_nl_tiled", kspec.version, kspec.cross_mha, use_tma + ) + ) + if not kspec.warp_specialization: + signatures.append("void {}_get_max_heads_per_wave(int*);".format(lname)) + signatures = "\n".join(signatures) + + # v1 + # - normal + # - no loop + # v2 + # - normal + # - no loop + # - normal interleaved + # - no loop interleaved + # - flash attention no loop + # - flash attention no loop tiled + # - flash attention warp_specialized (on Hopper) + + def gen_unroll_check(kspec): + code = "if (!{has_noloop} || (!force_unroll && (ignore_b1opt || b > {unroll_threshold})))".format( + **kspec._asdict() + ) + if kspec.flash_attention: + code = "if (!{has_noloop} || (!force_unroll && (ignore_b1opt || b * h > {unroll_threshold})))".format( + **kspec._asdict() + ) + return code + + def gen_call(kspec, lname): + effective_sm, _ = get_effective_sm_and_name(kspec) + data_type = dtype2typename[kspec.dtype] + output_data_type = data_type + if kspec.output_dtype: + output_data_type = dtype2typename[kspec.output_dtype] + il_check = "" + if kspec.version 
== 2 and kspec.dtype in ["fp16", "bf16"]: + il_check += ( + "&& use_flash_attention " + if kspec.flash_attention + else "&& !use_flash_attention " + ) + if kspec.version == 2: + # attention input layout. + il_check += f"&& attention_input_layout == {kspec.input_layout.value} " + # interleaved layout or not. + il_check += "&& interleaved " if kspec.interleaved else "&& !interleaved " + if effective_sm == 90: + il_check += "&& !use_tma " if kspec.ldgsts_q else "&& use_tma " + il_check += ( + "&& warp_specialization " + if kspec.warp_specialization + else "&& !warp_specialization " + ) + else: + il_check += "&& !warp_specialization && !use_tma " + # Different accumulation types. + if "_fp32" in kspec.dtype or "bf16" in kspec.dtype or kspec.dtype == "e4m3": + il_check += "&& force_fp32_acc " + else: + il_check += "&& !force_fp32_acc " + # whether support alibi or not. + if kspec.warp_specialization: + il_check += ( + "&& params.has_alibi " if kspec.alibi else "&& !params.has_alibi " + ) + il_check += ( + "&& params.softmax_stats_ptr != nullptr " + if kspec.return_softmax_stats + else "&& params.softmax_stats_ptr == nullptr " + ) + # use enable_attn_logit_softcapping or not. 
+ il_check += ( + "&& enable_attn_logit_softcapping " + if kspec.enable_attn_logit_softcapping + else "&& !enable_attn_logit_softcapping " + ) + # check sage block sizes + sage_block_size_q = 0 + sage_block_size_k = 0 + sage_block_size_v = 0 + if kspec.sage_block_sizes: + # override the data_type to output type, otherwise it is always E4M3 + data_type = output_data_type + sage_block_size_q = kspec.sage_block_sizes[0] + sage_block_size_k = kspec.sage_block_sizes[1] + sage_block_size_v = kspec.sage_block_sizes[2] + il_check += ( + f"&& sage_block_size_q == {sage_block_size_q} " + f"&& sage_block_size_k == {sage_block_size_k} " + f"&& sage_block_size_v == {sage_block_size_v} " + ) + + il_check += ( + "&& enable_skip_softmax " + if kspec.enable_skip_softmax + else "&& !enable_skip_softmax " + ) + + il_check += ( + "&& params.use_int8_scale_max " + if kspec.has_scale_max + else "&& !params.use_int8_scale_max " + ) + + slen = kspec.seq_len * kspec.ctas_per_head if not kspec.flash_attention else 0 + + ## NOTE: need to tune here + if kspec.has_noloop and not kspec.flash_attention: + call_stmt = """\ +if( data_type == {data_type} && output_data_type == {output_data_type} && s == {slen} && d == {head_size} && sm == {sm} + {il_check}) {{ + + {unroll_check} {{ + {lname}(params, launch_params, stream); + }} else {{ + {lname}_nl(params, launch_params, stream); + }} + +}} """.format( + **kspec._asdict(), + data_type=data_type, + output_data_type=output_data_type, + slen=slen, + lname=lname, + il_check=il_check, + unroll_check=gen_unroll_check(kspec), + ) + + elif kspec.flash_attention: # NOTE: flash attention uses no_loop as default + # TypeError: got multiple values for keyword argument if using key 'head_size_v', so 'dv' instead + dv = kspec.head_size_v or kspec.head_size + if kspec.tiled: # higher precedence; does not require bh_upper_thres + call_stmt = """\ +if( data_type == {data_type} && output_data_type == {output_data_type} && d == {head_size} && dv == {dv} && sm == {sm} 
+ {il_check} && use_tiled) {{ + + {lname}_nl_tiled(params, launch_params, stream); + +}} """.format( # type: ignore[str-format] + **kspec._asdict(), + data_type=data_type, + output_data_type=output_data_type, + slen=slen, + lname=lname, + il_check=il_check, + dv=dv, + ) + # warp specialization kernels need launch_params + elif kspec.warp_specialization: + call_stmt = """\ +if( data_type == {data_type} && output_data_type == {output_data_type} && d == {head_size} && dv == {dv} && sm == {sm} + {il_check}) {{ + + {lname}(params, launch_params, stream); + +}} """.format( # type: ignore[str-format] + **kspec._asdict(), + data_type=data_type, + output_data_type=output_data_type, + slen=slen, + lname=lname, + il_check=il_check, + dv=dv, + ) + else: + call_stmt = """\ +if( data_type == {data_type} && output_data_type == {output_data_type} && d == {head_size} && dv == {dv} && sm == {sm} + && !use_tiled {il_check}) {{ + + {lname}_nl(params, launch_params, stream); + +}} """.format( # type: ignore[str-format] + **kspec._asdict(), + data_type=data_type, + output_data_type=output_data_type, + slen=slen, + lname=lname, + il_check=il_check, + dv=dv, + ) + else: + call_stmt = """\ +if( data_type == {data_type} && output_data_type == {output_data_type} && s == {slen} && d == {head_size} && sm == {sm} + {il_check}) {{ + + {lname}(params, launch_params, stream); + +}} """.format( + **kspec._asdict(), + data_type=data_type, + output_data_type=output_data_type, + slen=slen, + lname=lname, + il_check=il_check, + ) + return call_stmt + + def gen_call_fmhca(kspec, lname): + effective_sm, _ = get_effective_sm_and_name(kspec) + data_type = dtype2typename[kspec.dtype] + il_check = "" + if kspec.version == 2: + il_check = "&& interleaved " if kspec.interleaved else "&& !interleaved " + if effective_sm == 90: + il_check += "&& !use_tma " if kspec.ldgsts_q else "&& use_tma " + il_check += ( + "&& params.use_int8_scale_max " + if kspec.has_scale_max + else "&& !params.use_int8_scale_max " + ) + 
+ s_kv_len = kspec.seq_len + if kspec.has_noloop: + call_stmt = """\ +if( data_type == {data_type} && s_kv == {s_kv_len} && d == {head_size} && sm == {sm} {il_check}) {{ + + {unroll_check} {{ + {lname}(params, stream); + }} else {{ + {lname}_nl(params, stream); + }} + +}} """.format( + **kspec._asdict(), + data_type=data_type, + s_kv_len=s_kv_len, + lname=lname, + il_check=il_check, + unroll_check=gen_unroll_check(kspec), + ) + + else: + call_stmt = """\ +if( data_type == {data_type} && s_kv == {s_kv_len} && d == {head_size} && sm == {sm} {il_check}) {{ + {lname}(params, stream); + }} """.format( + **kspec._asdict(), + data_type=data_type, + s_kv_len=s_kv_len, + lname=lname, + il_check=il_check, + ) + return call_stmt + + calls_v2 = [ + gen_call(kspec, lname) + for kspec, fname, lname, kname in specs_names + if kspec.version == 2 and kspec.cross_mha == 0 + ] + + calls_v2 = "else ".join(calls_v2) if len(calls_v2) > 0 else "if( false ) {}" + + calls_v1 = [ + gen_call(kspec, lname) + for kspec, fname, lname, kname in specs_names + if kspec.version == 1 and kspec.cross_mha == 0 + ] + + calls_v1 = "else ".join(calls_v1) if len(calls_v1) > 0 else "if( false ) {}" + + calls_mhca = [ + gen_call_fmhca(kspec, lname) + for kspec, fname, lname, kname in specs_names + if kspec.cross_mha == 1 + ] + + calls_mhca = "else ".join(calls_mhca) if len(calls_mhca) > 0 else "if( false ) {}" + + def gen_warp_spec(kspec): + data_type = dtype2typename[kspec.dtype] + if kspec.sage_block_sizes is not None: + assert kspec.output_dtype is not None + # override the data_type to output type, otherwise it is always E4M3 + data_type = dtype2typename[kspec.output_dtype] + slen = kspec.seq_len * kspec.ctas_per_head + effective_sm, _ = get_effective_sm_and_name(kspec) + warp_spec_check = "" + nl_warps_m = kspec.warps_m if effective_sm == 90 else 1 + nl_warps_n = ( + kspec.warps_n if effective_sm == 90 else kspec.warps_m * kspec.warps_n + ) + if kspec.version == 2 and kspec.dtype in ["fp16", "bf16"]: + 
warp_spec_check += ( + "&& use_flash_attention " + if kspec.flash_attention + else "&& !use_flash_attention " + ) + if kspec.version == 2: + if effective_sm == 90: + warp_spec_check += "&& !use_tma " if kspec.ldgsts_q else "&& use_tma " + warp_spec_check += ( + "&& warp_specialization " + if kspec.warp_specialization + else "&& !warp_specialization " + ) + else: + warp_spec_check += "&& !use_tma && !warp_specialization " + + if kspec.flash_attention: # NOTE support any sequence + return """\ +if( data_type == {data_type} && d == {head_size} && sm == {sm} {warp_spec_check} + && version == {version} ) {{ + warps_m = {warps_m}; + warps_n = {warps_n}; +}} """.format( # type: ignore[str-format] + **locals(), **kspec._asdict(), unroll_check=gen_unroll_check(kspec) + ) + return """\ +if( data_type == {data_type} && s == {slen} && d == {head_size} && sm == {sm} {warp_spec_check} + && version == {version} ) {{ + {unroll_check} {{ + warps_m = {warps_m}; + warps_n = {warps_n}; + }} else {{ + warps_m = {nl_warps_m}; + warps_n = {nl_warps_n}; + }} +}} """.format(**locals(), **kspec._asdict(), unroll_check=gen_unroll_check(kspec)) + + warp_specs = "else ".join([gen_warp_spec(spec[0]) for spec in specs_names]) + if len(warp_specs) > 0: + warp_specs += 'else {\n\tassert(false && "Unsupported config");\n}' + + # Generate the cta spec. 
+ def gen_cta_spec(spec): + kspec, _, lname, _ = spec + slen = kspec.seq_len * kspec.ctas_per_head + return """\ +if( data_type == {data_type} && s == {slen} && d == {head_size} && use_multi_ctas + && version == {version} ) {{ + + ctas_per_head = {ctas_per_head}; + {lname}_get_max_heads_per_wave(&max_heads_per_wave); + +}} """.format(**locals(), **kspec._asdict(), data_type=dtype2typename[kspec.dtype]) + + cta_specs = "else ".join( + [gen_cta_spec(spec) for spec in specs_names if spec[0].ctas_per_head > 1] + ) + + api_code = """\ +{copyright} +#pragma once + +#include +#include +#include +#include + +using Params_v1 = bert::Fused_multihead_attention_params_v1; +using Params_v2 = bert::Fused_multihead_attention_params_v2; +using Params_mhca = bert::Fused_multihead_attention_params_mhca; +using Launch_params = bert::Fused_multihead_attention_launch_params; + +{signatures} + +inline void run_fmha_v1(Params_v1 ¶ms, + const Launch_params &launch_params, + Data_type data_type, + Data_type output_data_type, + int sm, + cudaStream_t stream=0){{ +const size_t s = params.s; +const size_t b = params.b; +const size_t d = params.d; +const bool force_unroll = launch_params.force_unroll; +const bool ignore_b1opt = launch_params.ignore_b1opt; + +const bool use_flash_attention = false; + +{calls_v1} +else {{ + assert(false && "Unsupported config."); +}} + +}} + +// Note: transitioning to moving kernel launch parameters into launch_params to reduce the +// occurrences the interface needs to be modified +inline void run_fmha_v2(Params_v2 ¶ms, + const Launch_params &launch_params, + Data_type data_type, + Data_type output_data_type, + int sm, + cudaStream_t stream=0) {{ + +const size_t s = params.s; +const size_t b = params.b; +const size_t h = params.h; +const size_t d = params.d; +const size_t dv = params.dv; +const size_t sage_block_size_q = params.sage.q.block_size; +const size_t sage_block_size_k = params.sage.k.block_size; +const size_t sage_block_size_v = 
params.sage.v.block_size; + +const bool interleaved = launch_params.interleaved; +const bool force_unroll = launch_params.force_unroll; +const bool ignore_b1opt = launch_params.ignore_b1opt; +const bool force_fp32_acc = launch_params.force_fp32_acc; +const bool warp_specialization = launch_params.warp_specialization; +const bool use_tma = launch_params.use_tma; +const bool use_flash_attention = launch_params.flash_attention; +const bool enable_attn_logit_softcapping = launch_params.enable_attn_logit_softcapping; +const int attention_input_layout = static_cast(launch_params.attention_input_layout); +// tiled variant uses ldgsts +const bool use_tiled = launch_params.use_granular_tiling; + +{calls_v2} +else {{ + assert(false && "Unsupported config."); +}} + +}} + +#if __guard_fmhca_placeholder__ // fmhca api header + +inline void run_fmhca(Params_mhca ¶ms, + const Launch_params &launch_params, + Data_type data_type, + int sm, + cudaStream_t stream=0) {{ + +const size_t s_kv = params.s; +const size_t b = params.b; +const size_t d = params.d_padded; + +const bool interleaved = launch_params.interleaved; +const bool force_unroll = launch_params.force_unroll; +const bool ignore_b1opt = launch_params.ignore_b1opt; + +{calls_mhca} +else {{ + assert(false && "Unsupported config"); +}} + +}} + +#endif // fmhca api header + +inline std::tuple get_warps(Launch_params& launch_params, + int sm, + Data_type data_type, + size_t s, + size_t b, + size_t d, + int version) {{ + size_t warps_m, warps_n, warps_k = 1; + const bool interleaved = launch_params.interleaved; + const bool use_tma = launch_params.use_tma; + const bool force_unroll = launch_params.force_unroll; + const bool ignore_b1opt = launch_params.ignore_b1opt; + const bool use_flash_attention = launch_params.flash_attention; + // tiled variant uses ldgsts + const bool use_tiled = launch_params.use_granular_tiling; + const bool warp_specialization = launch_params.warp_specialization; + +{warp_specs} + + return 
def gen_fmha_v2_module(
    input_layout: str, input_dtype: torch.dtype, output_dtype: torch.dtype = None
) -> JitSpec:
    """Build the JitSpec for one TRT-LLM FMHAv2 module specialization.

    Each (input_layout, input_dtype, output_dtype) combination gets its own
    URI, its own generated-kernel directory, and its own compiled extension.

    Parameters
    ----------
    input_layout : str
        Layout tag (e.g. ``"PACKED_QKV"``); lower-cased into the module URI.
    input_dtype : torch.dtype
        Q/K/V dtype — one of fp16, bf16, or float8_e4m3fn.
    output_dtype : torch.dtype, optional
        Output dtype; defaults to ``input_dtype`` when omitted.

    Returns
    -------
    JitSpec
        Spec whose sources are the generated kernels plus the static
        runner and binding translation units.
    """
    if output_dtype is None:
        output_dtype = input_dtype

    dtype_map = {
        torch.float16: "fp16",
        torch.bfloat16: "bf16",
        torch.float8_e4m3fn: "e4m3",
    }
    input_dtype_str = dtype_map[input_dtype]
    # output_dtype can no longer be None here (defaulted above), so the
    # previous `if output_dtype is not None` guard was dead — map directly.
    output_dtype_str = dtype_map[output_dtype]

    uri = f"trtllm_fmha_v2_{input_layout.lower()}_{input_dtype_str}_{output_dtype_str}"

    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    gen_directory.mkdir(parents=True, exist_ok=True)

    # Source directories
    csrc_dir = jit_env.FLASHINFER_CSRC_DIR
    fmha_v2_src_dir = csrc_dir / "fmha_v2"

    # Determine which SM major versions are available in the compilation
    # context so we only generate kernel sources for architectures that
    # will be compiled.
    source_paths = generate_jit_sources(
        uri,
        input_layout,
        input_dtype_str,
        output_dtype_str,
        compilation_context=current_compilation_context,
    )

    # Copy the static runner and binding sources next to the generated
    # kernels (write_if_different keeps JIT cache hits stable).
    for static_name in ("fmha_v2_run.cu", "fmha_v2_jit_binding.cu"):
        dst_path = gen_directory / static_name
        with open(csrc_dir / static_name, "r") as f:
            write_if_different(dst_path, f.read())
        source_paths.append(dst_path)

    # Setup compilation flags.
    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
        supported_major_versions=[9, 12]
    )
    nvcc_flags.extend(
        [
            # fmha_v2 kernel headers (was previously added twice under two
            # spellings of the same path — deduplicated).
            f"-I{fmha_v2_src_dir}",
            f"-I{gen_directory}",  # For fmha_v2_api.h
            f"-I{jit_env.FLASHINFER_INCLUDE_DIR}",  # For flashinfer headers
            "-Wno-deprecated-gpu-targets",
        ]
    )

    return gen_jit_spec(
        uri,
        source_paths,
        extra_cuda_cflags=nvcc_flags,
    )
gen_customize_batch_prefill_module, gen_fmha_cutlass_sm100a_module, + gen_fmha_v2_module, gen_single_prefill_module, get_batch_prefill_uri, get_single_prefill_uri, setup_cubin_loader, gen_trtllm_gen_fmha_module, - get_trtllm_fmha_v2_module, + gen_trtllm_fmha_v2_sm120_module, ) from .cudnn import cudnn_batch_prefill_with_kv_cache from .page import get_seq_lens @@ -86,6 +87,45 @@ def _split_scale_param(scale): return None, float(scale) +def _create_scale_bmm2_d_tensor( + scale_bmm2: float, data_dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + """Create a scale_bmm2_d tensor with the correct bit pattern for the TRT-LLM FMHAv2 kernel. + + This function replicates the C++ set_alpha logic for scale_type2 to avoid + cudaMemcpy synchronization in the kernel. The scale value is converted to + the appropriate floating-point format and stored as int32 bits on device. + + The scale_type2 logic (from C++): + - FP16 input -> scale stored as FP16 bits in lower 16 bits of uint32 + - BF16 input -> scale stored as BF16 bits in lower 16 bits of uint32 + - Other (FP8, INT8, etc.) -> scale stored as FP32 bits in uint32 + + Args: + scale_bmm2: The scale value for BMM2 (typically 1.0) + data_dtype: The input tensor dtype (determines scale_type2) + device: The target device for the tensor + + Returns: + A 1-element int32 tensor on device containing the scale bits + """ + if data_dtype == torch.float16: + # Create int32 buffer on device, write FP16 value to lower 16 bits via view + result = torch.zeros(1, dtype=torch.int32, device=device) + result.view(torch.float16)[0] = scale_bmm2 + return result + elif data_dtype == torch.bfloat16: + # Create int32 buffer on device, write BF16 value to lower 16 bits via view + result = torch.zeros(1, dtype=torch.int32, device=device) + result.view(torch.bfloat16)[0] = scale_bmm2 + return result + else: + # FP8, INT8, etc. 
use FP32 accumulation - create FP32 tensor and view as int32 + return torch.tensor([scale_bmm2], dtype=torch.float32, device=device).view( + torch.int32 + ) + + @functools.cache def get_fmha_module( dtype_q: torch.dtype, @@ -3788,6 +3828,11 @@ def trtllm_batch_context_with_kv_cache( ) +@functools.cache +def get_trtllm_fmha_v2_sm120_module(): + return gen_trtllm_fmha_v2_sm120_module().build_and_load() + + @flashinfer_api def fmha_v2_prefill_deepseek( query: torch.Tensor, @@ -3849,7 +3894,7 @@ def fmha_v2_prefill_deepseek( assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, ( "currently only support deepseek r1 192 query and 128 value" ) - module = get_trtllm_fmha_v2_module() + module = get_trtllm_fmha_v2_sm120_module() is_e4m3 = query.dtype == torch.float8_e4m3fn is_bf16_output = out.dtype == torch.bfloat16 scale_softmax = ( @@ -3876,3 +3921,328 @@ def fmha_v2_prefill_deepseek( return out, lse else: return out + + +@functools.cache +def get_trtllm_fmha_v2_module( + input_layout: str, input_dtype: torch.dtype, output_dtype: torch.dtype = None +): + return gen_fmha_v2_module(input_layout, input_dtype, output_dtype).build_and_load() + + +@flashinfer_api +def trtllm_fmha_v2_prefill( + qkv: Union[ + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ], + input_layout: str, + workspace_buffer: torch.Tensor, + seq_lens: torch.Tensor, + max_q_len: int, + max_kv_len: int, + bmm1_scale: float, + bmm2_scale: float, + batch_size: int, + cum_seq_lens_q: torch.Tensor, + cum_seq_lens_kv: torch.Tensor, + block_tables: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[Union[torch.dtype, str]] = None, + sinks: Optional[List[torch.Tensor]] = None, + pos_encoding_mode: str = None, + logits_soft_cap_scale: Optional[float] = None, + mask_mode: str = "causal", + window_left: int = -1, + chunked_attention_size: int = 0, + save_softmax_stats: bool = False, + 
@flashinfer_api
def trtllm_fmha_v2_prefill(
    qkv: Union[
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ],
    input_layout: str,
    workspace_buffer: torch.Tensor,
    seq_lens: torch.Tensor,
    max_q_len: int,
    max_kv_len: int,
    bmm1_scale: float,
    bmm2_scale: float,
    batch_size: int,
    cum_seq_lens_q: torch.Tensor,
    cum_seq_lens_kv: torch.Tensor,
    block_tables: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    out_dtype: Optional[Union[torch.dtype, str]] = None,
    sinks: Optional[List[torch.Tensor]] = None,
    pos_encoding_mode: str = None,
    logits_soft_cap_scale: Optional[float] = None,
    mask_mode: str = "causal",
    window_left: int = -1,
    chunked_attention_size: int = 0,
    save_softmax_stats: bool = False,
    skip_softmax_threshold_scale_factor: float = 0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""TRT-LLM FMHAv2 prefill attention.

    Parameters
    ----------
    qkv
        Query/key/value input; expected format is determined by :attr:`input_layout`.
    input_layout
        Specifies the layout of the query/key/value tensors:
        - ``PACKED_QKV``: ``qkv`` is a single tensor of shape
          ``[num_tokens, 3, num_heads, head_dim]``.
        - ``CONTIGUOUS_Q_KV``: ``qkv`` is ``(Q, KV)`` where Q has shape
          ``[num_tokens, num_heads, head_dim]`` and KV has shape
          ``[num_tokens, 2, num_kv_heads, head_dim]``
          (``KV[:, 0, ...]`` is key, ``KV[:, 1, ...]`` is value).
        - ``Q_PAGED_KV_HND``: ``qkv`` is ``(Q, paged_KV)`` where Q has shape
          ``[num_tokens, num_heads, head_dim]`` and paged_KV has shape
          ``[num_pages, 2, num_kv_heads, page_size, head_dim]``
          (``paged_KV[:, 0, ...]`` is key cache, ``paged_KV[:, 1, ...]`` is value cache).
        - ``Q_PAGED_KV_NHD``: same as ``Q_PAGED_KV_HND`` but paged_KV shape is
          ``[num_pages, 2, page_size, num_kv_heads, head_dim]``.
        - ``SEPARATE_Q_K_V``: ``qkv`` is ``(Q, K, V)`` where Q has shape
          ``[num_tokens, num_heads, head_dim]`` and K, V have shape
          ``[num_tokens, num_kv_heads, head_dim]``.
    workspace_buffer
        The workspace buffer. Must be initialized to 0 for its first use.
    seq_lens
        The KV sequence length of each request, shape: ``[batch_size]``.
    max_q_len
        The maximum sequence length for query.
    max_kv_len
        The maximum sequence length for KV cache.
    bmm1_scale
        The fused scale for BMM1 (QK^T) computation.
    bmm2_scale
        The fused scale for BMM2 (softmax(QK^T) * V) computation.
    batch_size
        The batch size.
    cum_seq_lens_q
        The cumulative sequence lengths for query, shape: ``[batch_size + 1]``.
    cum_seq_lens_kv
        The cumulative sequence lengths for KV cache, shape: ``[batch_size + 1]``.
    block_tables
        The page table for KV cache, shape: ``[batch_size, max_num_pages_per_seq]``.
        Required when using paged KV cache format.
    out
        The output tensor. If not provided, will be allocated with ``out_dtype``.
        If ``out_dtype`` is also not provided, will use the dtype of query.
    out_dtype
        The output dtype. If not provided, will use the dtype of ``out`` or query.
    sinks
        Additional value per head in the denominator of the softmax.
    pos_encoding_mode
        The position encoding mode, could be ``alibi``. Defaults to ``None``.
    logits_soft_cap_scale
        The logits soft cap scale. Defaults to ``None``, which means no soft cap.
    mask_mode
        The mask mode, could be ``causal``, ``sliding_window``, or ``chunked``.
        Defaults to ``causal``.
    window_left
        The left (inclusive) window size for the attention window, when set to ``-1``,
        the window size will be set to the full length of the sequence. Defaults to ``-1``.
        Only effective when :attr:`mask_mode` is ``sliding_window``.
    chunked_attention_size
        The chunked attention size. Defaults to ``0``, which means no chunked attention.
        Only effective when :attr:`mask_mode` is ``chunked``. Must be a power of 2.
    save_softmax_stats
        Whether to save the softmax statistics. Defaults to ``False``.
    skip_softmax_threshold_scale_factor
        The factor of skip-softmax (Sparse Attention),
        Skip softmax and BMM2 when exp(local_max - global_max) < threshold,
        where threshold = skip_softmax_threshold_scale_factor / seqlen.
        Defaults to ``0`` (disabled).

    Returns
    -------
    If :attr:`save_softmax_stats` is ``False``, the attention output tensor.
    If :attr:`save_softmax_stats` is ``True``, a tuple of two tensors:
      * The attention output tensor.
      * The softmax statistics tensor (LSE).
    """

    # --- Parse the input layout into (query, k_cache, v_cache). ---
    if input_layout == "PACKED_QKV":
        assert isinstance(qkv, torch.Tensor)
        if qkv.dim() != 4 or qkv.shape[1] != 3:
            raise ValueError(
                f"PACKED_QKV expects shape [tokens, 3, num_heads, head_dim], got {tuple(qkv.shape)}"
            )
        query = qkv
        k_cache, v_cache = qkv, qkv  # placeholders
    elif input_layout == "CONTIGUOUS_Q_KV":
        assert isinstance(qkv, tuple)
        query, kv_cache = qkv[0], qkv[1]
        if kv_cache.dim() != 4 or kv_cache.shape[1] != 2:
            raise ValueError(
                f"CONTIGUOUS_Q_KV expects KV shape [tokens, 2, num_kv_heads, head_dim], got {tuple(kv_cache.shape)}"
            )
        k_cache = kv_cache
        v_cache = kv_cache  # placeholder (not used for this layout)
    elif input_layout == "Q_PAGED_KV_NHD":
        assert isinstance(qkv, tuple)
        query, paged_kv = qkv[0], qkv[1]
        if paged_kv.dim() != 5 or paged_kv.shape[1] != 2:
            raise ValueError(
                f"Q_PAGED_KV_NHD expects paged_KV shape [pages, 2, page_size, num_kv_heads, head_dim], got {tuple(paged_kv.shape)}"
            )
        # TODO: implement native NHD support in the kernel to avoid this transpose
        kv_cache = paged_kv.transpose(-3, -2).contiguous()
        k_cache, v_cache = kv_cache.unbind(dim=1)
    elif input_layout == "Q_PAGED_KV_HND":
        assert isinstance(qkv, tuple)
        query, paged_kv = qkv[0], qkv[1]
        if paged_kv.dim() != 5 or paged_kv.shape[1] != 2:
            raise ValueError(
                f"Q_PAGED_KV_HND expects paged_KV shape [pages, 2, num_kv_heads, page_size, head_dim], got {tuple(paged_kv.shape)}"
            )
        k_cache, v_cache = paged_kv.unbind(dim=1)
    elif input_layout == "SEPARATE_Q_K_V":
        assert isinstance(qkv, tuple)
        query, k_cache, v_cache = qkv[0], qkv[1], qkv[2]
        if hasattr(torch, "float8_e4m3fn") and query.dtype == torch.float8_e4m3fn:
            raise ValueError(
                "FP8 (e4m3) is not supported for the SEPARATE_Q_K_V input layout. "
                "Use PACKED_QKV, CONTIGUOUS_Q_KV, or Q_PAGED_KV layout instead."
            )
        if logits_soft_cap_scale is not None and logits_soft_cap_scale > 0:
            raise ValueError(
                "Logits soft capping is not supported for the SEPARATE_Q_K_V input layout. "
                "Use PACKED_QKV, CONTIGUOUS_Q_KV, or Q_PAGED_KV layout instead."
            )
    else:
        raise ValueError(
            f"Unsupported input_layout: {input_layout!r}. Expected one of: "
            "PACKED_QKV, CONTIGUOUS_Q_KV, Q_PAGED_KV_HND, Q_PAGED_KV_NHD, SEPARATE_Q_K_V."
        )

    # --- Derive per-layout geometry (head counts, page size, value dim). ---
    if input_layout == "PACKED_QKV":
        # Packed QKV: query is [tokens, 3, H, D]
        num_qo_heads = query.shape[2]
        page_size = 0  # Not applicable for packed layouts
        head_dim_v = query.shape[3]  # Assume same as head_dim_qk
    elif input_layout in ("Q_PAGED_KV_NHD", "Q_PAGED_KV_HND"):
        # Q is 3D: [tokens, H, D], Paged KV (HND after any transpose): [num_pages, H_kv, page_size, D]
        num_qo_heads = query.shape[1]
        page_size = k_cache.shape[2]
        head_dim_v = v_cache.shape[3]
    elif input_layout == "CONTIGUOUS_Q_KV":
        # Q is 3D: [tokens, H, D], KV is 4D: [tokens, 2, H_kv, D]
        # k_cache holds the combined KV tensor
        num_qo_heads = query.shape[1]
        page_size = 0  # Not applicable for non-paged layouts
        head_dim_v = k_cache.shape[3]  # D from KV tensor
    else:
        # SEPARATE_Q_K_V: all 3D ragged [tokens, H, D]
        num_qo_heads = query.shape[1]
        page_size = 0  # Not applicable for non-paged layouts
        head_dim_v = v_cache.shape[2]

    # --- Validate mask-mode combinations. ---
    uses_sliding_window = window_left is not None and window_left >= 0
    uses_chunked = chunked_attention_size is not None and chunked_attention_size > 0
    is_non_causal = mask_mode is not None and mask_mode.lower() == "padding"

    if (uses_sliding_window or uses_chunked) and is_sm12x_supported(query.device):
        feature = "Sliding window" if uses_sliding_window else "Chunked"
        raise ValueError(
            f"{feature} attention is not yet supported for FMHAv2 on SM120 (Blackwell). "
            "Only CAUSAL masks are available. "
        )

    if (uses_sliding_window or uses_chunked) and is_non_causal:
        feature = "Sliding window" if uses_sliding_window else "Chunked"
        raise ValueError(
            f"{feature} attention requires causal masking. "
            f"The underlying kernel only supports SLIDING_OR_CHUNKED_CAUSAL mode. "
        )

    # Determine output dtype
    if out is not None:
        o_dtype = out.dtype
    elif out_dtype is not None:
        o_dtype = canonicalize_torch_dtype(out_dtype)
    else:
        o_dtype = query.dtype

    # Allocate output tensor if not provided
    # Use head_dim_v (actual value) not head_dim_vo (0 means "same as head_dim_qk")
    if out is None:
        out = torch.empty(
            (query.shape[0], num_qo_heads, head_dim_v),
            dtype=o_dtype,
            device=query.device,
        )

    # Handle scale parameters
    scale_bmm1 = float(bmm1_scale)
    scale_bmm2 = float(bmm2_scale)

    # Softmax scale: 1.0 for FP8, 0.0 (auto-detect) for FP16/BF16
    # C++ kernel auto-sets to 1.0 for FP16/E4M3 when 0.0 is passed
    is_e4m3 = (
        query.dtype == torch.float8_e4m3fn if hasattr(torch, "float8_e4m3fn") else False
    )
    if is_e4m3:
        if is_sm12x_supported(query.device):
            raise ValueError(
                "FP8 (e4m3) is not yet supported for FMHAv2 on SM120 (Blackwell). "
                "Use fp16 or bf16 instead."
            )
        if uses_sliding_window and input_layout in ("PACKED_QKV", "CONTIGUOUS_Q_KV"):
            _num_kv_heads = (
                num_qo_heads if input_layout == "PACKED_QKV" else k_cache.shape[2]
            )
            if batch_size == 16 and _num_kv_heads == 4 and head_dim_v == 256:
                raise ValueError(
                    "FP8 (e4m3) sliding window attention with batch_size=16, "
                    "num_kv_heads=4, head_dim=256 is not supported for "
                    f"{input_layout} layout due to a known issue."
                )
    # BUG FIX: previously `1.0 if is_e4m3 else 1.0` — a dead conditional that
    # always passed 1.0. Pass 0.0 for non-FP8 so the C++ side auto-detects,
    # matching the comment above and the fmha_v2_prefill_deepseek call path.
    scale_softmax = 1.0 if is_e4m3 else 0.0
    softcapping_scale = (
        logits_soft_cap_scale if logits_soft_cap_scale is not None else 0.0
    )

    module = get_trtllm_fmha_v2_module(
        input_layout,
        query.dtype,
        o_dtype if query.dtype == torch.float8_e4m3fn else None,
    )

    # Allocate LSE tensor if saving softmax stats
    # Kernel writes in ragged (flat) format: [total_q_tokens, num_qo_heads, 2]
    # total_q_tokens == query.shape[0] for all ragged layouts
    lse = None
    if save_softmax_stats:
        lse = torch.empty(
            (query.shape[0], num_qo_heads, 2),
            dtype=torch.float32,
            device=query.device,
        )

    # For Q_PAGED_KV layout, expand block_tables from [B, M] to [B, 2, M]
    # TRT-LLM kernel expects separate K and V block offset arrays.
    # FlashInfer layout: K for page i is at block index 2*i, V at 2*i+1
    expanded_block_tables = None
    if block_tables is not None and input_layout.lower().startswith("q_paged_kv"):
        # K offsets = page_idx * 2 (even blocks)
        # V offsets = page_idx * 2 + 1 (odd blocks)
        expanded_block_tables = torch.stack(
            [block_tables * 2, block_tables * 2 + 1], dim=1
        ).contiguous()  # [B, 2, M]

    scale_bmm2_d = _create_scale_bmm2_d_tensor(scale_bmm2, query.dtype, query.device)

    module.run(
        query,  # Q tensor
        k_cache,  # K tensor
        v_cache,  # V tensor
        out,  # Output tensor
        workspace_buffer,  # Workspace buffer
        workspace_buffer.numel()
        * workspace_buffer.element_size(),  # Workspace buffer size in bytes
        expanded_block_tables,  # Expanded block tables [B, 2, M] or None
        page_size,
        seq_lens,  # Sequence length for kv_cache
        cum_seq_lens_q,  # Cumulative sequence length for query
        cum_seq_lens_kv,  # Cumulative sequence length for kv_cache
        input_layout.lower(),  # Input layout
        max_q_len,  # Max sequence length for query
        max_kv_len,  # Max sequence length for kv_cache
        batch_size,  # Batch size
        mask_mode.lower(),  # Attention mask type
        scale_softmax,  # Softmax scale
        scale_bmm1,  # BMM1 scale
        scale_bmm2,  # BMM2 scale (float, still needed for set_alpha in C++)
        window_left,  # Window left
        chunked_attention_size,  # Chunked attention size
        pos_encoding_mode is not None
        and pos_encoding_mode.lower() == "alibi",  # Alibi mode
        softcapping_scale,  # Softcapping scale (0.0 = disabled)
        skip_softmax_threshold_scale_factor,  # threshold_scale_factor for skip-softmax (0.0 = disable)
        scale_bmm2_d,  # Pre-populated scale_bmm2 on device (avoids cudaMemcpy)
        lse,  # Optional LSE tensor (None if not saving softmax stats)
        sinks,  # Optional sinks tensor
    )

    if save_softmax_stats:
        return out, lse
    else:
        return out
(float, still needed for set_alpha in C++) + window_left, # Window left + chunked_attention_size, # Chunked attention size + pos_encoding_mode is not None + and pos_encoding_mode.lower() == "alibi", # Alibi mode + softcapping_scale, # Softcapping scale (0.0 = disabled) + skip_softmax_threshold_scale_factor, # threshold_scale_factor for skip-softmax (0.0 = disable) + scale_bmm2_d, # Pre-populated scale_bmm2 on device (avoids cudaMemcpy) + lse, # Optional LSE tensor (None if not saving softmax stats) + sinks, # Optional sinks tensor + ) + + if save_softmax_stats: + return out, lse + else: + return out diff --git a/tests/attention/test_fmha_v2_prefill.py b/tests/attention/test_fmha_v2_prefill.py new file mode 100644 index 0000000000..c31aeb2b94 --- /dev/null +++ b/tests/attention/test_fmha_v2_prefill.py @@ -0,0 +1,1336 @@ +import pytest +import torch +import math +from typing import Optional, Tuple, Union + +import flashinfer +from flashinfer.prefill import fmha_v2_prefill_deepseek +from tests.utils_fp8 import to_float8 +from flashinfer.utils import is_sm12x_supported, is_sm120a_supported + +_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 +_workspace_buffer: Optional[torch.Tensor] = None + + +def _get_workspace_buffer() -> torch.Tensor: + """Return a lazily-allocated, module-level workspace buffer that is reused + across test cases to avoid repeated 128 MiB CUDA allocations.""" + global _workspace_buffer + if _workspace_buffer is None: + _workspace_buffer = torch.zeros( + _WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda" + ) + else: + _workspace_buffer.zero_() + return _workspace_buffer + + +def attention_mla_ref_torch( + batch_size: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + causal: bool, + sm_scale: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + # tensors are (batch_size, seq_len, num_heads, head_dim) + qo_len = q.shape[1] + kv_len = k.shape[1] + logits = torch.einsum("bmhd,bnhd->bhmn", q.float(), k.float()) * sm_scale + + if causal: + mask 
= torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( + 1 + ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) + else: + mask = torch.ones(qo_len, kv_len, device=q.device) + + logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) + # LSE computation: logsumexp over the key dimension (last dim) + # logits shape: (batch, num_heads, seq_len, seq_len) + lse_ref = torch.logsumexp(logits, -1) # (batch, num_heads, seq_len) + # Transpose to match expected shape (batch, seq_len, num_heads) + lse_ref = lse_ref.transpose(1, 2) + p = torch.softmax(logits, dim=-1) + o_ref = torch.einsum("bhmn,bnhd->bmhd", p, v.float()).contiguous() + + # Return LSE in natural log (no conversion needed) + return o_ref, lse_ref + + +def attention_ref_torch( + qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + seq_lens: torch.Tensor, + cum_seq_lens_q: torch.Tensor, + sm_scale: float, + q_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, + causal: bool = True, + window_left: int = -1, + logits_soft_cap: float = 0.0, + block_tables: Optional[torch.Tensor] = None, + cum_seq_lens_kv: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pure-torch reference for attention supporting multiple input layouts. + + Layouts (auto-detected from qkv): + - Packed QKV: single tensor [total_tokens, 3, num_heads, head_dim] + - Separate Q, K, V: tuple (q, k, v) + - Contiguous Q + KV: tuple (q, kv) where kv is 4-D + - Paged Q + KV cache: tuple (q, paged_kv_cache) where paged_kv_cache is 5-D + (requires block_tables) + + Returns output tensor, or (output, lse) when return_lse=True. + LSE shape: [total_tokens, num_qo_heads]. 
+ """ + device = seq_lens.device + batch_size = seq_lens.shape[0] + + if cum_seq_lens_kv is None: + cum_seq_lens_kv = cum_seq_lens_q + + # --- parse input layout --- + is_paged = False + if isinstance(qkv, torch.Tensor): + # Packed QKV: [total, 3, H, D] + q_flat = qkv[:, 0, :, :].contiguous() + k_flat = qkv[:, 1, :, :].contiguous() + v_flat = qkv[:, 2, :, :].contiguous() + elif isinstance(qkv, tuple): + if len(qkv) == 3: + q_flat, k_flat, v_flat = qkv + elif len(qkv) == 2: + q_flat = qkv[0] + second = qkv[1] + if second.ndim == 5: + # Paged: (q, paged_kv_cache[num_pages, 2, page_size, H_kv, D]) + is_paged = True + paged_kv_cache = second + page_size = paged_kv_cache.shape[2] + else: + # Contiguous: (q, kv[total, 2, H_kv, D]) + k_flat = second[:, 0, :, :].contiguous() + v_flat = second[:, 1, :, :].contiguous() + else: + raise ValueError(f"Unexpected tuple length: {len(qkv)}") + else: + raise TypeError(f"Unexpected qkv type: {type(qkv)}") + + num_qo_heads = q_flat.shape[1] + head_dim = q_flat.shape[2] + if is_paged: + num_kv_heads = paged_kv_cache.shape[3] + else: + num_kv_heads = k_flat.shape[1] + heads_per_group = num_qo_heads // num_kv_heads + + q_float = q_flat.float() * q_scale + + outputs = [] + lse_outputs = [] + for b in range(batch_size): + seq_len = seq_lens[b].item() + q_start = cum_seq_lens_q[b].item() + q_end = cum_seq_lens_q[b + 1].item() + q_len = q_end - q_start + q_seq = q_float[q_start:q_end] + + if is_paged: + num_pages_needed = (seq_len + page_size - 1) // page_size + k_pages = [] + v_pages = [] + for p in range(num_pages_needed): + page_idx = block_tables[b, p].item() + k_pages.append(paged_kv_cache[page_idx, 0]) + v_pages.append(paged_kv_cache[page_idx, 1]) + k_seq = torch.cat(k_pages, dim=0)[:seq_len].float() * k_scale + v_seq = torch.cat(v_pages, dim=0)[:seq_len].float() * v_scale + else: + kv_start = cum_seq_lens_kv[b].item() + kv_end = cum_seq_lens_kv[b + 1].item() + k_seq = k_flat[kv_start:kv_end].float() * k_scale + v_seq = 
v_flat[kv_start:kv_end].float() * v_scale + + o_seq = torch.zeros( + q_len, num_qo_heads, head_dim, dtype=torch.float32, device=device + ) + if return_lse: + lse_seq = torch.zeros( + q_len, num_qo_heads, dtype=torch.float32, device=device + ) + + for h in range(num_qo_heads): + kv_h = h // heads_per_group + + q_h = q_seq[:, h, :] + k_h = k_seq[:, kv_h, :] + v_h = v_seq[:, kv_h, :] + + scores = torch.matmul(q_h, k_h.t()) * sm_scale + + if logits_soft_cap > 0.0: + scores = logits_soft_cap * torch.tanh(scores / logits_soft_cap) + + if causal: + q_indices = torch.arange(q_len, device=device).unsqueeze(1) + kv_indices = torch.arange(seq_len, device=device).unsqueeze(0) + offset = seq_len - q_len + causal_mask = (q_indices + offset) >= kv_indices + scores = scores.masked_fill(~causal_mask, float("-inf")) + + if window_left >= 0: + q_indices = torch.arange(q_len, device=device).unsqueeze(1) + kv_indices = torch.arange(seq_len, device=device).unsqueeze(0) + offset = seq_len - q_len + window_mask = kv_indices >= (q_indices + offset - window_left) + scores = scores.masked_fill(~window_mask, float("-inf")) + + if return_lse: + lse_seq[:, h] = torch.logsumexp(scores, dim=-1) + + attn = torch.softmax(scores, dim=-1) + o_seq[:, h, :] = torch.matmul(attn, v_h) + + outputs.append(o_seq) + if return_lse: + lse_outputs.append(lse_seq) + + out = torch.cat(outputs, dim=0) + if return_lse: + return out, torch.cat(lse_outputs, dim=0) + return out + + +def chunked_attention_ref_torch( + qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + seq_lens: torch.Tensor, + cum_seq_lens_q: torch.Tensor, + sm_scale: float, + chunked_attention_size: int, + cum_seq_lens_kv: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Pure-torch reference for chunked causal attention. + + Chunked attention divides the KV-space into non-overlapping chunks of + ``chunked_attention_size`` tokens. 
def chunked_attention_ref_torch(
    qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
    seq_lens: torch.Tensor,
    cum_seq_lens_q: torch.Tensor,
    sm_scale: float,
    chunked_attention_size: int,
    cum_seq_lens_kv: Optional[torch.Tensor] = None,
    block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Pure-torch reference for chunked causal attention.

    Chunked attention divides the KV-space into non-overlapping chunks of
    ``chunked_attention_size`` tokens. Each query token (mapped to its absolute
    KV-space position) can only attend causally *within its own chunk*.

    For a query at absolute KV position ``row``:
        chunk_start = floor(row / chunk_size) * chunk_size
        mask: chunk_start <= col <= row

    This matches the TRT-LLM FMHAv2 kernel logic in
    ``csrc/fmha_v2/fmha/warpspec/epilogue.h:compute_sliding_window_or_chunk_start``.
    """
    device = seq_lens.device
    batch_size = seq_lens.shape[0]

    # Self-attention default: Q and KV share the same ragged offsets.
    if cum_seq_lens_kv is None:
        cum_seq_lens_kv = cum_seq_lens_q

    # --- parse input layout (same auto-detection as attention_ref_torch) ---
    is_paged = False
    if isinstance(qkv, torch.Tensor):
        # Packed QKV: [total, 3, H, D]
        q_flat = qkv[:, 0, :, :].contiguous()
        k_flat = qkv[:, 1, :, :].contiguous()
        v_flat = qkv[:, 2, :, :].contiguous()
    elif isinstance(qkv, tuple):
        if len(qkv) == 3:
            q_flat, k_flat, v_flat = qkv
        elif len(qkv) == 2:
            q_flat = qkv[0]
            second = qkv[1]
            if second.ndim == 5:
                # Paged KV cache: [num_pages, 2, page_size, H_kv, D]
                is_paged = True
                paged_kv_cache = second
                page_size = paged_kv_cache.shape[2]
            else:
                # Contiguous KV: [total, 2, H_kv, D]
                k_flat = second[:, 0, :, :].contiguous()
                v_flat = second[:, 1, :, :].contiguous()
        else:
            raise ValueError(f"Unexpected tuple length: {len(qkv)}")
    else:
        raise TypeError(f"Unexpected qkv type: {type(qkv)}")

    num_qo_heads = q_flat.shape[1]
    # Output head dim is taken from Q — assumes V shares Q's head_dim (TODO
    # confirm for asymmetric-QK/V shapes).
    head_dim = q_flat.shape[2]
    if is_paged:
        num_kv_heads = paged_kv_cache.shape[3]
    else:
        num_kv_heads = k_flat.shape[1]
    # GQA mapping: query head h reads KV head h // heads_per_group.
    heads_per_group = num_qo_heads // num_kv_heads

    q_float = q_flat.float()

    outputs = []
    for b in range(batch_size):
        kv_len = seq_lens[b].item()
        q_start = cum_seq_lens_q[b].item()
        q_end = cum_seq_lens_q[b + 1].item()
        q_len = q_end - q_start
        q_seq = q_float[q_start:q_end]

        if is_paged:
            # Gather this request's pages and trim the tail page to kv_len.
            num_pages_needed = (kv_len + page_size - 1) // page_size
            k_pages = []
            v_pages = []
            for p in range(num_pages_needed):
                page_idx = block_tables[b, p].item()
                k_pages.append(paged_kv_cache[page_idx, 0])
                v_pages.append(paged_kv_cache[page_idx, 1])
            k_seq = torch.cat(k_pages, dim=0)[:kv_len].float()
            v_seq = torch.cat(v_pages, dim=0)[:kv_len].float()
        else:
            kv_start = cum_seq_lens_kv[b].item()
            kv_end = cum_seq_lens_kv[b + 1].item()
            k_seq = k_flat[kv_start:kv_end].float()
            v_seq = v_flat[kv_start:kv_end].float()

        o_seq = torch.zeros(
            q_len, num_qo_heads, head_dim, dtype=torch.float32, device=device
        )

        # Absolute KV-space positions for each query token.
        # Query token i corresponds to KV position (kv_len - q_len + i).
        offset = kv_len - q_len

        for h in range(num_qo_heads):
            kv_h = h // heads_per_group
            q_h = q_seq[:, h, :]
            k_h = k_seq[:, kv_h, :]
            v_h = v_seq[:, kv_h, :]

            # scores: [q_len, kv_len]
            scores = torch.matmul(q_h, k_h.t()) * sm_scale

            # Build chunked causal mask.
            # q_abs[i] = offset + i (absolute KV-space row for query token i)
            # kv_pos[j] = j (KV column index)
            q_abs = torch.arange(q_len, device=device).unsqueeze(1) + offset
            kv_pos = torch.arange(kv_len, device=device).unsqueeze(0)

            # Causal: col <= row
            causal_mask = kv_pos <= q_abs

            # Chunk left boundary: col >= floor(row / chunk_size) * chunk_size
            chunk_start = (q_abs // chunked_attention_size) * chunked_attention_size
            chunk_mask = kv_pos >= chunk_start

            mask = causal_mask & chunk_mask
            scores = scores.masked_fill(~mask, float("-inf"))

            attn = torch.softmax(scores, dim=-1)
            o_seq[:, h, :] = torch.matmul(attn, v_h)

        outputs.append(o_seq)

    return torch.cat(outputs, dim=0)
pytest.skip("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.") + torch.manual_seed(42) + + def initialize_tensors(batch_size, num_heads, head_dim_qk, head_dim_v, seq_len): + device = "cuda" + if qkv_dtype == torch.float8_e4m3fn: + q = torch.randn( + (batch_size, seq_len, num_heads, head_dim_qk), + dtype=torch.bfloat16, + device=device, + ) + k = torch.randn( + (batch_size, seq_len, num_heads, head_dim_qk), + dtype=torch.bfloat16, + device=device, + ) + v = torch.randn( + (batch_size, seq_len, num_heads, head_dim_v), + dtype=torch.bfloat16, + device=device, + ) + + q, q_scale = to_float8(q, dtype=torch.float8_e4m3fn) + k, k_scale = to_float8(k, dtype=torch.float8_e4m3fn) + v, v_scale = to_float8(v, dtype=torch.float8_e4m3fn) + q_scale = q_scale.item() + k_scale = k_scale.item() + v_scale = v_scale.item() + else: + q = torch.randn( + (batch_size, seq_len, num_heads, head_dim_qk), + dtype=qkv_dtype, + device=device, + ) + k = torch.randn( + (batch_size, seq_len, num_heads, head_dim_qk), + dtype=qkv_dtype, + device=device, + ) + v = torch.randn( + (batch_size, seq_len, num_heads, head_dim_v), + dtype=qkv_dtype, + device=device, + ) + # For non-FP8 case, scales are 1.0 + q_scale = 1.0 + k_scale = 1.0 + v_scale = 1.0 + + # Output and statistics + o = torch.zeros( + batch_size, seq_len, num_heads, head_dim_v, dtype=o_dtype, device=device + ) + lse = torch.zeros( + batch_size, seq_len, num_heads, 2, dtype=torch.float, device=device + ) + sm_scale = 1.0 / math.sqrt(head_dim_qk) + return q, k, v, o, lse, sm_scale, q_scale, k_scale, v_scale + + q, k, v, o, lse, sm_scale, q_scale, k_scale, v_scale = initialize_tensors( + batch_size, num_heads, head_dim_qk, head_dim_v, seq_len + ) + scale_bmm1 = q_scale * k_scale * sm_scale + scale_bmm2 = v_scale + scale_softmax = 1.0 if qkv_dtype == torch.float8_e4m3fn else 0.0 + out, lse = fmha_v2_prefill_deepseek( + q, + k, + v, + o, + num_heads, + head_dim_qk, + seq_len, + scale_softmax=scale_softmax, + scale_bmm1=scale_bmm1, + 
scale_bmm2=scale_bmm2, + return_lse=True, + lse=lse, + ) + # implementation gives [max(s_i), sum(exp(s_i - max(s_i)))], compute lse from this + if qkv_dtype == torch.float8_e4m3fn: + # For E4M3 the softmax is scaled by 256 (the largest power-of-2 below E4M3_MAX=448.0) + descale = 256 + lse = lse[:, :, :, 0] + torch.log(lse[:, :, :, 1] / descale) + else: + lse = lse[:, :, :, 0] + torch.log(lse[:, :, :, 1]) + + if qkv_dtype == torch.float8_e4m3fn: + q_32 = q.to(torch.float32) * q_scale + k_32 = k.to(torch.float32) * k_scale + v_32 = v.to(torch.float32) * v_scale + out_ref, lse_ref = attention_mla_ref_torch( + batch_size, q_32, k_32, v_32, causal=True, sm_scale=sm_scale + ) + else: + out_ref, lse_ref = attention_mla_ref_torch( + batch_size, q, k, v, causal=True, sm_scale=sm_scale + ) + out_ref = out_ref.to(o.dtype) + + if q.dtype == torch.float8_e4m3fn and o.dtype == torch.bfloat16: + rtol, atol = 4e-2, 6e-2 + torch.testing.assert_close(out, out_ref.to(o.dtype), rtol=rtol, atol=atol) + elif q.dtype == torch.bfloat16 and o.dtype == torch.bfloat16: + rtol, atol = 1e-2, 1e-2 + torch.testing.assert_close(out, out_ref, rtol=rtol, atol=atol) + else: + rtol, atol = 1e-2, 1e-3 + + torch.testing.assert_close(lse, lse_ref, rtol=1e-2, atol=1e-3) + + +def run_trtllm_fmha_v2_prefill_case( + input_layout: str, + batch_size: int, + max_seq_len: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: Optional[int], + dtype: torch.dtype, + o_dtype: torch.dtype, + causal: bool, + mask_mode: str, + window_left: int, + logits_soft_cap: float, + pos_encoding_mode: Optional[str], + save_softmax_stats: bool, + skip_softmax_threshold_scale_factor: float, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> None: + from flashinfer.prefill import trtllm_fmha_v2_prefill + from flashinfer.utils import is_sm90a_supported + + if not is_sm90a_supported(torch.device("cuda")) and not is_sm120a_supported( + torch.device("cuda") + ): + pytest.skip("FMHA v2 requires 
SM90+ (Hopper) or SM12x GPUs.") + + # Skip invalid combinations + is_sm120_plus = is_sm120a_supported(torch.device("cuda")) + if dtype == torch.float8_e4m3fn and is_sm120_plus: + pytest.skip("FP8 FMHAv2 not yet supported on SM120+") + if input_layout == "SEPARATE_Q_K_V" and dtype == torch.float8_e4m3fn: + pytest.skip("FP8 not supported for SEPARATE_Q_K_V layout") + if input_layout == "SEPARATE_Q_K_V" and is_sm120_plus: + pytest.skip( + "SEPARATE_Q_K_V requires SM90 warp-specialization, not available on SM120+" + ) + if ( + is_sm120_plus + and mask_mode is not None + and mask_mode.upper() == "SLIDING_WINDOW" + ): + pytest.skip("SLIDING_WINDOW mask not yet supported on SM120+ (only causal)") + if input_layout == "SEPARATE_Q_K_V" and logits_soft_cap > 0: + pytest.skip("Logits soft capping not supported for SEPARATE_Q_K_V layout") + # save_softmax_stats only supported for CONTIGUOUS_Q_KV (normal attention) + if save_softmax_stats and input_layout != "CONTIGUOUS_Q_KV": + pytest.skip( + "For normal attention, Only CONTIGUOUS_Q_KV layout supports " + "save_softmax_stats. For MLA only SEPARATE_Q_K_V layout supports " + "save_softmax_stats." 
+ ) + if skip_softmax_threshold_scale_factor > 0 and not is_sm90a_supported( + torch.device("cuda") + ): + pytest.skip("Skip softmax attention is only supported on SM90+ (Hopper) GPUs.") + + torch.manual_seed(42) + device = torch.device("cuda") + + seq_lens = torch.randint( + max_seq_len // 2, + max_seq_len + 1, + (batch_size,), + dtype=torch.int32, + device=device, + ) + max_kv_len = seq_lens.max().item() + cum_seq_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cum_seq_lens[1:] = torch.cumsum(seq_lens, dim=0) + total_tokens = cum_seq_lens[-1].item() + + sm_scale = 1.0 / math.sqrt(head_dim) + max_q_len = seq_lens.max().item() + block_tables = None + + # --- Create inputs and scales per layout --- + if input_layout == "PACKED_QKV": + if dtype == torch.float8_e4m3fn: + packed_bf16 = torch.randn( + total_tokens, + 3, + num_qo_heads, + head_dim, + dtype=torch.bfloat16, + device=device, + ) + packed_qkv, qkv_scale = to_float8(packed_bf16, dtype=torch.float8_e4m3fn) + qkv_scale = qkv_scale.item() + else: + packed_qkv = torch.randn( + total_tokens, + 3, + num_qo_heads, + head_dim, + dtype=dtype, + device=device, + ) + qkv_scale = 1.0 + qkv_arg = packed_qkv + q_scale, k_scale, v_scale = qkv_scale, qkv_scale, qkv_scale + + elif input_layout == "SEPARATE_Q_K_V": + q = torch.randn( + total_tokens, num_qo_heads, head_dim, dtype=dtype, device=device + ) + k = torch.randn( + total_tokens, num_kv_heads, head_dim, dtype=dtype, device=device + ) + v = torch.randn( + total_tokens, num_kv_heads, head_dim, dtype=dtype, device=device + ) + qkv_arg = (q, k, v) + q_scale, k_scale, v_scale = 1.0, 1.0, 1.0 + + elif input_layout == "CONTIGUOUS_Q_KV": + if dtype == torch.float8_e4m3fn: + q_bf16 = torch.randn( + total_tokens, + num_qo_heads, + head_dim, + dtype=torch.bfloat16, + device=device, + ) + q, q_scale = to_float8(q_bf16, dtype=torch.float8_e4m3fn) + q_scale = q_scale.item() + kv_bf16 = torch.randn( + total_tokens, + 2, + num_kv_heads, + head_dim, + 
dtype=torch.bfloat16, + device=device, + ) + kv, kv_scale = to_float8(kv_bf16, dtype=torch.float8_e4m3fn) + kv_scale = kv_scale.item() + else: + q = torch.randn( + total_tokens, num_qo_heads, head_dim, dtype=dtype, device=device + ) + kv = torch.randn( + total_tokens, + 2, + num_kv_heads, + head_dim, + dtype=dtype, + device=device, + ) + q_scale, kv_scale = 1.0, 1.0 + qkv_arg = (q, kv) + k_scale, v_scale = kv_scale, kv_scale + + elif input_layout in ("Q_PAGED_KV_NHD", "Q_PAGED_KV_HND"): + is_nhd = input_layout == "Q_PAGED_KV_NHD" + max_num_blocks = (max_kv_len + page_size - 1) // page_size + num_pages = batch_size * max_num_blocks + # NHD: [num_pages, 2, page_size, num_kv_heads, head_dim] + # HND: [num_pages, 2, num_kv_heads, page_size, head_dim] + paged_shape = ( + (num_pages, 2, page_size, num_kv_heads, head_dim) + if is_nhd + else (num_pages, 2, num_kv_heads, page_size, head_dim) + ) + if dtype == torch.float8_e4m3fn: + paged_bf16 = torch.randn(*paged_shape, dtype=torch.bfloat16, device=device) + paged_kv_cache, kv_scale = to_float8(paged_bf16, dtype=torch.float8_e4m3fn) + kv_scale = kv_scale.item() + q_bf16 = torch.randn( + total_tokens, + num_qo_heads, + head_dim, + dtype=torch.bfloat16, + device=device, + ) + q, q_scale = to_float8(q_bf16, dtype=torch.float8_e4m3fn) + q_scale = q_scale.item() + else: + paged_kv_cache = torch.randn(*paged_shape, dtype=dtype, device=device) + q = torch.randn( + total_tokens, num_qo_heads, head_dim, dtype=dtype, device=device + ) + q_scale, kv_scale = 1.0, 1.0 + block_tables = torch.zeros( + batch_size, + max_num_blocks, + dtype=torch.int32, + device=device, + ) + for i in range(batch_size): + num_blocks_needed = (seq_lens[i].item() + page_size - 1) // page_size + block_tables[i, :num_blocks_needed] = torch.arange( + i * max_num_blocks, + i * max_num_blocks + num_blocks_needed, + device=device, + ) + qkv_arg = (q, paged_kv_cache) + k_scale, v_scale = kv_scale, kv_scale + + # attention_ref_torch expects NHD paged KV cache; for 
HND, transpose back + if input_layout == "Q_PAGED_KV_HND": + ref_qkv_arg = (q, paged_kv_cache.transpose(-3, -2).contiguous()) + else: + ref_qkv_arg = qkv_arg + + # --- Compute BMM scales --- + if dtype == torch.float8_e4m3fn: + bmm1_scale = sm_scale * q_scale * k_scale + bmm2_scale = v_scale + else: + bmm1_scale = sm_scale + bmm2_scale = 1.0 + + o = torch.zeros(total_tokens, num_qo_heads, head_dim, dtype=o_dtype, device=device) + workspace_buffer = _get_workspace_buffer() + + # --- Run kernel --- + result = trtllm_fmha_v2_prefill( + qkv_arg, + input_layout, + workspace_buffer=workspace_buffer, + seq_lens=seq_lens, + max_q_len=max_q_len, + max_kv_len=max_kv_len, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + batch_size=batch_size, + cum_seq_lens_q=cum_seq_lens, + cum_seq_lens_kv=cum_seq_lens, + block_tables=block_tables, + out=o, + out_dtype=o_dtype, + mask_mode=mask_mode, + window_left=window_left, + logits_soft_cap_scale=logits_soft_cap if logits_soft_cap > 0 else None, + skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, + pos_encoding_mode=pos_encoding_mode, + save_softmax_stats=save_softmax_stats, + ) + + if save_softmax_stats: + output, kernel_lse = result + else: + output = result + + # --- Reference --- + ref_result = attention_ref_torch( + ref_qkv_arg, + seq_lens=seq_lens, + cum_seq_lens_q=cum_seq_lens, + sm_scale=sm_scale, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, + causal=causal, + window_left=window_left, + logits_soft_cap=logits_soft_cap, + block_tables=block_tables, + return_lse=save_softmax_stats, + ) + + if save_softmax_stats: + output_ref, lse_ref = ref_result + else: + output_ref = ref_result + + if dtype == torch.float8_e4m3fn: + if o_dtype == torch.float8_e4m3fn: + default_rtol, default_atol = 1e-1, 1e-1 + else: + default_rtol, default_atol = 4e-2, 8e-2 + else: + default_rtol, default_atol = 1e-2, 1e-2 + + output_rtol = default_rtol if rtol is None else rtol + output_atol = default_atol if atol is None 
else atol + torch.testing.assert_close( + output.float(), output_ref.float(), rtol=output_rtol, atol=output_atol + ) + + if save_softmax_stats: + # kernel_lse: [total_tokens, num_qo_heads, 2] -> [max, sum_exp] in ragged format + # + # The stored max format differs by architecture: + # - SM90 (Hopper) warp-spec: Softmax_saver_tma stores max / sqrt(head_dim), + # and uses exp2f with scale_bmm1 * M_LOG2E. The max tracks raw (unscaled) + # QK^T values, so max / sqrt(d) == max * sm_scale. + # - SM120 (Blackwell) tiled: Softmax_saver stores max directly (no sqrt(d) + # division). Elements are pre-scaled by scale_bmm1 during unpack, so + # the max already includes the sm_scale factor. + # + # In both cases, for non-softcap: + # lse = kernel_max * q_scale * k_scale + ln(sum_exp) + # + # For softcap: + # - SM90: max = max(softcapped) / sqrt(d), so lse = kernel_max / sm_scale + ln(sum) + # - SM120: max = max(softcapped) directly, so lse = kernel_max + ln(sum) + kernel_max = kernel_lse[:, :, 0] + kernel_sum_exp = kernel_lse[:, :, 1] + is_sm12x = is_sm12x_supported(torch.device("cuda")) + if logits_soft_cap > 0: + if is_sm12x: + # SM120 Softmax_saver: max stored directly (no sqrt(d) division) + lse_kernel = kernel_max + torch.log(kernel_sum_exp) + else: + # SM90 Softmax_saver_tma: max stored as max / sqrt(d) + lse_kernel = kernel_max / sm_scale + torch.log(kernel_sum_exp) + else: + lse_kernel = kernel_max * (q_scale * k_scale) + torch.log(kernel_sum_exp) + torch.testing.assert_close(lse_kernel, lse_ref, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize("batch_size", [1, 16]) +@pytest.mark.parametrize("max_seq_len", [1024]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize( + ("dtype", "o_dtype"), + [ + (torch.float16, torch.float16), + (torch.bfloat16, torch.bfloat16), + (torch.float8_e4m3fn, torch.float8_e4m3fn), + (torch.float8_e4m3fn, 
torch.bfloat16), + (torch.float8_e4m3fn, torch.float16), + ], +) +@pytest.mark.parametrize( + ("input_layout", "page_size", "save_softmax_stats"), + [ + ("PACKED_QKV", None, False), + ("CONTIGUOUS_Q_KV", None, False), + ("CONTIGUOUS_Q_KV", None, True), + ("SEPARATE_Q_K_V", None, False), + ("Q_PAGED_KV_NHD", 32, False), + ("Q_PAGED_KV_NHD", 128, False), + ("Q_PAGED_KV_HND", 32, False), + ("Q_PAGED_KV_HND", 128, False), + ], +) +@pytest.mark.parametrize( + ("causal", "window_left", "mask_mode"), + [ + (True, -1, "CAUSAL"), + (True, 127, "SLIDING_WINDOW"), + (True, 512, "SLIDING_WINDOW"), + ], +) +@pytest.mark.parametrize("pos_encoding_mode", [None]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +def test_trtllm_fmha_v2_prefill( + input_layout: str, + batch_size: int, + max_seq_len: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: Optional[int], + dtype: torch.dtype, + o_dtype: torch.dtype, + causal: bool, + mask_mode: str, + window_left: int, + logits_soft_cap: float, + pos_encoding_mode: str, + save_softmax_stats: bool, +) -> None: + # skip bs=16, q_heads=4, kv_heads=4, head_dim=256, dtype=float8_e4m3fn if packed/contiguous and sliding window due to bug + if ( + batch_size == 16 + and num_kv_heads == 4 + and head_dim == 256 + and dtype == torch.float8_e4m3fn + and input_layout in ["PACKED_QKV", "CONTIGUOUS_Q_KV"] + and mask_mode == "SLIDING_WINDOW" + ): + pytest.skip("Skip due to bug in fp8 sliding window") + run_trtllm_fmha_v2_prefill_case( + input_layout=input_layout, + batch_size=batch_size, + max_seq_len=max_seq_len, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_size, + dtype=dtype, + o_dtype=o_dtype, + causal=causal, + mask_mode=mask_mode, + window_left=window_left, + logits_soft_cap=logits_soft_cap, + pos_encoding_mode=pos_encoding_mode, + save_softmax_stats=save_softmax_stats, + skip_softmax_threshold_scale_factor=0.0, + ) + + +@pytest.mark.parametrize("batch_size", [1, 
4]) +@pytest.mark.parametrize("max_seq_len", [16384]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize( + ("dtype", "o_dtype"), + [ + (torch.float16, torch.float16), + (torch.bfloat16, torch.bfloat16), + (torch.float8_e4m3fn, torch.bfloat16), + ], +) +@pytest.mark.parametrize( + "input_layout", ["CONTIGUOUS_Q_KV", "Q_PAGED_KV_NHD", "Q_PAGED_KV_HND"] +) +@pytest.mark.parametrize( + ( + "skip_softmax_threshold_scale_factor", + "rtol", + "atol", + ), + [ + (500.0, 2e-2, 1.2e-1), + (10000.0, 2e-2, 2e-1), + ], +) +def test_trtllm_fmha_v2_prefill_skip_softmax( + input_layout: str, + batch_size: int, + max_seq_len: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + dtype: torch.dtype, + o_dtype: torch.dtype, + skip_softmax_threshold_scale_factor: float, + rtol: float, + atol: float, +) -> None: + run_trtllm_fmha_v2_prefill_case( + input_layout=input_layout, + batch_size=batch_size, + max_seq_len=max_seq_len, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=32, + dtype=dtype, + o_dtype=o_dtype, + causal=True, + mask_mode="CAUSAL", + window_left=-1, + logits_soft_cap=0.0, + pos_encoding_mode=None, + save_softmax_stats=False, + skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.parametrize("batch_size", [4, 16]) +@pytest.mark.parametrize("max_seq_len", [1024, 4096]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + ("causal", "window_left", "mask_mode"), + [ + (True, -1, "CAUSAL"), + (True, 127, "SLIDING_WINDOW"), + (True, 512, "SLIDING_WINDOW"), + ], +) +@pytest.mark.parametrize("pos_encoding_mode", [None]) +def 
test_trtllm_fmha_v2_prefill_attention_sinks( + batch_size: int, + max_seq_len: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + dtype: torch.dtype, + causal: bool, + window_left: int, + mask_mode: str, + pos_encoding_mode: str, +) -> None: + """ + Test trtllm_fmha_v2_prefill with attention sinks. + Compares against BatchAttentionWithAttentionSinkWrapper as reference. + """ + from flashinfer.prefill import trtllm_fmha_v2_prefill + from flashinfer.utils import is_sm90a_supported + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FMHA v2 requires SM90+ (Hopper) GPUs.") + + torch.manual_seed(42) + device = torch.device("cuda") + + seq_lens = torch.randint( + max_seq_len // 2, + max_seq_len + 1, + (batch_size,), + dtype=torch.int32, + device=device, + ) + max_kv_len = seq_lens.max().item() + cum_seq_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cum_seq_lens[1:] = torch.cumsum(seq_lens, dim=0) + total_tokens = cum_seq_lens[-1].item() + + # Create separate Q, K, V tensors + q = torch.randn(total_tokens, num_qo_heads, head_dim, dtype=dtype, device=device) + k = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=dtype, device=device) + v = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=dtype, device=device) + o = torch.zeros(total_tokens, num_qo_heads, head_dim, dtype=dtype, device=device) + workspace_buffer = _get_workspace_buffer() + + sm_scale = 1.0 / math.sqrt(head_dim) + max_q_len = max_kv_len + + # Create sink tensor with random values to properly test the feature + sink = torch.rand(num_qo_heads, device=device, dtype=torch.float32) * 5 + + # Test trtllm_fmha_v2_prefill with sinks parameter + output = trtllm_fmha_v2_prefill( + (q, k, v), + "SEPARATE_Q_K_V", + workspace_buffer=workspace_buffer, + seq_lens=seq_lens, + max_q_len=max_q_len, + max_kv_len=max_kv_len, + bmm1_scale=sm_scale, + bmm2_scale=1.0, + batch_size=batch_size, + cum_seq_lens_q=cum_seq_lens, + cum_seq_lens_kv=cum_seq_lens, + out=o, 
+ out_dtype=dtype, + sinks=sink, + mask_mode=mask_mode, + window_left=window_left, + pos_encoding_mode=pos_encoding_mode, + ) + + # Reference: use BatchAttentionWithAttentionSinkWrapper + workspace_buffer_ref = torch.empty( + 128 * 1024 * 1024, dtype=torch.uint8, device=device + ) + + # cumulative token counts per sequence + kv_indptr = torch.cat( + [ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(seq_lens, dim=0, dtype=torch.int32), + ] + ) + + # Create kv_indices for page_size=1 (each token is a page) + kv_indices = torch.arange(0, total_tokens, dtype=torch.int32, device=device) + paged_kv_last_page_len = torch.full( + (batch_size,), 1, dtype=torch.int32, device=device + ) + + wrapper_ref = flashinfer.BatchAttentionWithAttentionSinkWrapper( + workspace_buffer_ref, + kv_layout="NHD", + backend="fa3", + q_data_type=dtype, + kv_data_type=dtype, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + window_left=window_left, + ) + wrapper_ref.plan( + cum_seq_lens.cpu(), + kv_indptr.cpu(), + kv_indices.cpu(), + paged_kv_last_page_len.cpu(), + num_qo_heads, + num_kv_heads, + head_dim, + 1, # page_size + causal=causal, + window_left=window_left, + q_data_type=dtype, + kv_data_type=dtype, + ) + output_ref = wrapper_ref.run(q, (k, v), sink, sm_scale) + + rtol, atol = 1e-2, 1e-2 + torch.testing.assert_close(output.float(), output_ref.float(), rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("max_seq_len", [1024, 4096]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + ("input_layout", "page_size"), + [ + ("CONTIGUOUS_Q_KV", None), + ("SEPARATE_Q_K_V", None), + ("Q_PAGED_KV_NHD", 32), + ], +) +@pytest.mark.parametrize("chunked_attention_size", [64, 256]) +def test_trtllm_fmha_v2_prefill_chunked_attention( + batch_size: 
int, + max_seq_len: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + dtype: torch.dtype, + input_layout: str, + page_size: Optional[int], + chunked_attention_size: int, +) -> None: + """Test trtllm_fmha_v2_prefill with chunked attention mask against a + pure-PyTorch reference that builds the exact same mask pattern. + + Chunked attention divides the KV-space into fixed-size, non-overlapping + chunks. Each query token can only attend causally within its chunk. + """ + from flashinfer.prefill import trtllm_fmha_v2_prefill + from flashinfer.utils import is_sm90a_supported + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FMHA v2 requires SM90+ (Hopper) GPUs.") + + torch.manual_seed(42) + device = torch.device("cuda") + + # --- Generate per-sequence lengths --- + seq_lens = torch.randint( + max_seq_len // 2, + max_seq_len + 1, + (batch_size,), + dtype=torch.int32, + device=device, + ) + max_kv_len = seq_lens.max().item() + cum_seq_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cum_seq_lens[1:] = torch.cumsum(seq_lens, dim=0) + total_tokens = cum_seq_lens[-1].item() + + sm_scale = 1.0 / math.sqrt(head_dim) + block_tables = None + + # --- Create inputs per layout --- + if input_layout == "SEPARATE_Q_K_V": + q = torch.randn( + total_tokens, num_qo_heads, head_dim, dtype=dtype, device=device + ) + k = torch.randn( + total_tokens, num_kv_heads, head_dim, dtype=dtype, device=device + ) + v = torch.randn( + total_tokens, num_kv_heads, head_dim, dtype=dtype, device=device + ) + qkv_arg = (q, k, v) + + elif input_layout == "CONTIGUOUS_Q_KV": + q = torch.randn( + total_tokens, num_qo_heads, head_dim, dtype=dtype, device=device + ) + kv = torch.randn( + total_tokens, 2, num_kv_heads, head_dim, dtype=dtype, device=device + ) + qkv_arg = (q, kv) + + elif input_layout == "Q_PAGED_KV_NHD": + max_num_blocks = (max_kv_len + page_size - 1) // page_size + num_pages = batch_size * max_num_blocks + paged_kv_cache = torch.randn( + 
num_pages, + 2, + page_size, + num_kv_heads, + head_dim, + dtype=dtype, + device=device, + ) + q = torch.randn( + total_tokens, num_qo_heads, head_dim, dtype=dtype, device=device + ) + block_tables = torch.zeros( + batch_size, + max_num_blocks, + dtype=torch.int32, + device=device, + ) + for i in range(batch_size): + num_blocks_needed = (seq_lens[i].item() + page_size - 1) // page_size + block_tables[i, :num_blocks_needed] = torch.arange( + i * max_num_blocks, + i * max_num_blocks + num_blocks_needed, + device=device, + ) + qkv_arg = (q, paged_kv_cache) + else: + raise ValueError(f"Unsupported input_layout: {input_layout}") + + bmm1_scale = sm_scale + bmm2_scale = 1.0 + + workspace_buffer = _get_workspace_buffer() + + # --- Run kernel with chunked attention --- + output = trtllm_fmha_v2_prefill( + qkv_arg, + input_layout, + workspace_buffer=workspace_buffer, + seq_lens=seq_lens, + max_q_len=max_kv_len, + max_kv_len=max_kv_len, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + batch_size=batch_size, + cum_seq_lens_q=cum_seq_lens, + cum_seq_lens_kv=cum_seq_lens, + block_tables=block_tables, + out_dtype=dtype, + mask_mode="chunked", + chunked_attention_size=chunked_attention_size, + ) + + # --- Reference --- + output_ref = chunked_attention_ref_torch( + qkv_arg, + seq_lens=seq_lens, + cum_seq_lens_q=cum_seq_lens, + sm_scale=sm_scale, + chunked_attention_size=chunked_attention_size, + block_tables=block_tables, + ) + + rtol, atol = 1e-2, 1e-2 + torch.testing.assert_close(output.float(), output_ref.float(), rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("max_kv_len", [1024, 4096]) +@pytest.mark.parametrize("max_new_tokens", [64, 256]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [32, 128]) 
+@pytest.mark.parametrize("chunked_attention_size", [64, 256]) +def test_trtllm_fmha_v2_chunked_prefill_chunked_attention( + batch_size: int, + max_kv_len: int, + max_new_tokens: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + dtype: torch.dtype, + page_size: int, + chunked_attention_size: int, +) -> None: + """Test chunked prefill (Q < KV) with chunked attention mask on paged KV cache. + + Each sequence has kv_len total tokens in the paged KV cache but only q_len + new query tokens. The chunked attention mask divides the KV-space into + non-overlapping blocks of ``chunked_attention_size`` and restricts each + query to attend causally within its own block. + + Uses Q_PAGED_KV_NHD layout only. The non-paged layouts (CONTIGUOUS_Q_KV, + SEPARATE_Q_K_V) have a known kernel issue with chunked attention mask when + past_kv_length >= STEP_Q (64). + """ + from flashinfer.prefill import trtllm_fmha_v2_prefill + from flashinfer.utils import is_sm90a_supported + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FMHA v2 requires SM90+ (Hopper) GPUs.") + + torch.manual_seed(42) + device = torch.device("cuda") + + # --- Generate per-sequence KV and Q lengths (Q < KV) --- + kv_seq_lens = torch.randint( + max_kv_len // 2, + max_kv_len + 1, + (batch_size,), + dtype=torch.int32, + device=device, + ) + q_seq_lens = torch.randint( + max(1, max_new_tokens // 2), + max_new_tokens + 1, + (batch_size,), + dtype=torch.int32, + device=device, + ) + q_seq_lens = torch.minimum(q_seq_lens, kv_seq_lens) + + actual_max_kv_len = kv_seq_lens.max().item() + actual_max_q_len = q_seq_lens.max().item() + + cum_seq_lens_kv = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cum_seq_lens_kv[1:] = torch.cumsum(kv_seq_lens, dim=0) + + cum_seq_lens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cum_seq_lens_q[1:] = torch.cumsum(q_seq_lens, dim=0) + total_q_tokens = cum_seq_lens_q[-1].item() + + sm_scale = 1.0 / math.sqrt(head_dim) + + # 
--- Create paged KV cache and query --- + max_num_blocks = (actual_max_kv_len + page_size - 1) // page_size + num_pages = batch_size * max_num_blocks + paged_kv_cache = torch.randn( + num_pages, + 2, + page_size, + num_kv_heads, + head_dim, + dtype=dtype, + device=device, + ) + q = torch.randn(total_q_tokens, num_qo_heads, head_dim, dtype=dtype, device=device) + block_tables = torch.zeros( + batch_size, + max_num_blocks, + dtype=torch.int32, + device=device, + ) + for i in range(batch_size): + num_blocks_needed = (kv_seq_lens[i].item() + page_size - 1) // page_size + block_tables[i, :num_blocks_needed] = torch.arange( + i * max_num_blocks, + i * max_num_blocks + num_blocks_needed, + device=device, + ) + qkv_arg = (q, paged_kv_cache) + + workspace_buffer = _get_workspace_buffer() + + # --- Run kernel --- + output = trtllm_fmha_v2_prefill( + qkv_arg, + "Q_PAGED_KV_NHD", + workspace_buffer=workspace_buffer, + seq_lens=kv_seq_lens, + max_q_len=actual_max_q_len, + max_kv_len=actual_max_kv_len, + bmm1_scale=sm_scale, + bmm2_scale=1.0, + batch_size=batch_size, + cum_seq_lens_q=cum_seq_lens_q, + cum_seq_lens_kv=cum_seq_lens_kv, + block_tables=block_tables, + out_dtype=dtype, + mask_mode="chunked", + chunked_attention_size=chunked_attention_size, + ) + + # --- Reference --- + output_ref = chunked_attention_ref_torch( + qkv_arg, + seq_lens=kv_seq_lens, + cum_seq_lens_q=cum_seq_lens_q, + sm_scale=sm_scale, + chunked_attention_size=chunked_attention_size, + cum_seq_lens_kv=cum_seq_lens_kv, + block_tables=block_tables, + ) + + rtol, atol = 1e-2, 1e-2 + torch.testing.assert_close(output.float(), output_ref.float(), rtol=rtol, atol=atol) diff --git a/tests/attention/test_fmha_v2_prefill_deepseek.py b/tests/attention/test_fmha_v2_prefill_deepseek.py deleted file mode 100755 index 3d0b7bd865..0000000000 --- a/tests/attention/test_fmha_v2_prefill_deepseek.py +++ /dev/null @@ -1,170 +0,0 @@ -import pytest -import torch -import math - - -from flashinfer.prefill import 
fmha_v2_prefill_deepseek -from tests.utils_fp8 import to_float8 -from flashinfer.utils import is_sm12x_supported - - -def attention_ref( - batch_size, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - causal: bool, - sm_scale: float, -) -> torch.Tensor: - # tensors are (batch_size, seq_len, num_heads, head_dim) - qo_len = q.shape[1] - kv_len = k.shape[1] - logits = torch.einsum("bmhd,bnhd->bhmn", q.float(), k.float()) * sm_scale - - if causal: - mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( - 1 - ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) - else: - mask = torch.ones(qo_len, kv_len, device=q.device) - - logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) - # LSE computation: logsumexp over the key dimension (last dim) - # logits shape: (batch, num_heads, seq_len, seq_len) - lse_ref = torch.logsumexp(logits, -1) # (batch, num_heads, seq_len) - # Transpose to match expected shape (batch, seq_len, num_heads) - lse_ref = lse_ref.transpose(1, 2) - p = torch.softmax(logits, dim=-1) - o_ref = torch.einsum("bhmn,bnhd->bmhd", p, v.float()).contiguous() - - # Return LSE in natural log (no conversion needed) - return o_ref, lse_ref - - -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("num_heads", [8]) -@pytest.mark.parametrize("head_dim_qk", [192]) -@pytest.mark.parametrize("head_dim_v", [128]) -@pytest.mark.parametrize("seq_len", [1024, 4096, 8192]) -@pytest.mark.parametrize( - "qkv_dtype,o_dtype", - [ - (torch.bfloat16, torch.bfloat16), - (torch.float8_e4m3fn, torch.bfloat16), - (torch.float8_e4m3fn, torch.float16), - ], -) -def test_fmha_v2_prefill_deepseek( - batch_size, num_heads, head_dim_qk, head_dim_v, seq_len, qkv_dtype, o_dtype -): - if not is_sm12x_supported(torch.device("cuda")): - pytest.skip("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.") - torch.manual_seed(42) - - def initialize_tensors(batch_size, num_heads, head_dim_qk, head_dim_v, seq_len): - 
device = "cuda" - if qkv_dtype == torch.float8_e4m3fn: - q = torch.randn( - (batch_size, seq_len, num_heads, head_dim_qk), - dtype=torch.bfloat16, - device=device, - ) - k = torch.randn( - (batch_size, seq_len, num_heads, head_dim_qk), - dtype=torch.bfloat16, - device=device, - ) - v = torch.randn( - (batch_size, seq_len, num_heads, head_dim_v), - dtype=torch.bfloat16, - device=device, - ) - - q, q_scale = to_float8(q, dtype=torch.float8_e4m3fn) - k, k_scale = to_float8(k, dtype=torch.float8_e4m3fn) - v, v_scale = to_float8(v, dtype=torch.float8_e4m3fn) - q_scale = q_scale.item() - k_scale = k_scale.item() - v_scale = v_scale.item() - else: - q = torch.randn( - (batch_size, seq_len, num_heads, head_dim_qk), - dtype=qkv_dtype, - device=device, - ) - k = torch.randn( - (batch_size, seq_len, num_heads, head_dim_qk), - dtype=qkv_dtype, - device=device, - ) - v = torch.randn( - (batch_size, seq_len, num_heads, head_dim_v), - dtype=qkv_dtype, - device=device, - ) - # For non-FP8 case, scales are 1.0 - q_scale = 1.0 - k_scale = 1.0 - v_scale = 1.0 - - # Output and statistics - o = torch.zeros( - batch_size, seq_len, num_heads, head_dim_v, dtype=o_dtype, device=device - ) - lse = torch.zeros( - batch_size, seq_len, num_heads, 2, dtype=torch.float, device=device - ) - sm_scale = 1.0 / math.sqrt(head_dim_qk) - return q, k, v, o, lse, sm_scale, q_scale, k_scale, v_scale - - q, k, v, o, lse, sm_scale, q_scale, k_scale, v_scale = initialize_tensors( - batch_size, num_heads, head_dim_qk, head_dim_v, seq_len - ) - scale_bmm1 = q_scale * k_scale * sm_scale - scale_bmm2 = v_scale - scale_softmax = 1.0 if qkv_dtype == torch.float8_e4m3fn else 0.0 - out, lse = fmha_v2_prefill_deepseek( - q, - k, - v, - o, - num_heads, - head_dim_qk, - seq_len, - scale_softmax=scale_softmax, - scale_bmm1=scale_bmm1, - scale_bmm2=scale_bmm2, - return_lse=True, - lse=lse, - ) - # implementation gives [max(s_i), sum(exp(s_i - max(s_i)))], compute lse from this - if qkv_dtype == torch.float8_e4m3fn: - # 
For E4M3 the softmax is scaled by 256 (the largest power-of-2 below E4M3_MAX=448.0) - descale = 256 - lse = lse[:, :, :, 0] + torch.log(lse[:, :, :, 1] / descale) - else: - lse = lse[:, :, :, 0] + torch.log(lse[:, :, :, 1]) - - if qkv_dtype == torch.float8_e4m3fn: - q_32 = q.to(torch.float32) * q_scale - k_32 = k.to(torch.float32) * k_scale - v_32 = v.to(torch.float32) * v_scale - out_ref, lse_ref = attention_ref( - batch_size, q_32, k_32, v_32, causal=True, sm_scale=sm_scale - ) - else: - out_ref, lse_ref = attention_ref( - batch_size, q, k, v, causal=True, sm_scale=sm_scale - ) - out_ref = out_ref.to(o.dtype) - - if q.dtype == torch.float8_e4m3fn and o.dtype == torch.bfloat16: - rtol, atol = 4e-2, 6e-2 - torch.testing.assert_close(out, out_ref.to(o.dtype), rtol=rtol, atol=atol) - elif q.dtype == torch.bfloat16 and o.dtype == torch.bfloat16: - rtol, atol = 1e-2, 1e-2 - torch.testing.assert_close(out, out_ref, rtol=rtol, atol=atol) - else: - rtol, atol = 1e-2, 1e-3 - - torch.testing.assert_close(lse, lse_ref, rtol=1e-2, atol=1e-3)