From 3617b4e00d9cb79bd0e2f7540383cc08f54bfde9 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Thu, 5 Mar 2026 11:21:36 +0000 Subject: [PATCH 1/7] Add flashinfer fused moe kernels --- Cargo.toml | 4 +- src/flashinfer.rs | 6 +- src/kernels/Cargo.toml | 4 +- src/kernels/build.rs | 20 +- src/kernels/src/ffi.rs | 37 + src/kernels/src/flashinfer_adapter.cu | 6 +- src/kernels/src/flashinfer_moe_adapter.cu | 497 ++++++++ .../src/trtllm/trtllm_batched_gemm_runner.cu | 461 ++++++++ .../src/trtllm/trtllm_cutlass_heuristic.cpp | 776 +++++++++++++ .../src/trtllm/trtllm_fused_moe_dev_kernel.cu | 1011 +++++++++++++++++ .../trtllm_fused_moe_routing_deepseek.cu | 665 +++++++++++ .../trtllm/trtllm_fused_moe_routing_llama4.cu | 581 ++++++++++ .../trtllm_fused_moe_routing_renormalize.cu | 509 +++++++++ .../src/trtllm/trtllm_fused_moe_runner.cu | 566 +++++++++ src/lib.rs | 2 +- src/moe.rs | 241 ++++ 16 files changed, 5373 insertions(+), 13 deletions(-) create mode 100644 src/kernels/src/flashinfer_moe_adapter.cu create mode 100644 src/kernels/src/trtllm/trtllm_batched_gemm_runner.cu create mode 100644 src/kernels/src/trtllm/trtllm_cutlass_heuristic.cpp create mode 100644 src/kernels/src/trtllm/trtllm_fused_moe_dev_kernel.cu create mode 100644 src/kernels/src/trtllm/trtllm_fused_moe_routing_deepseek.cu create mode 100644 src/kernels/src/trtllm/trtllm_fused_moe_routing_llama4.cu create mode 100644 src/kernels/src/trtllm/trtllm_fused_moe_routing_renormalize.cu create mode 100644 src/kernels/src/trtllm/trtllm_fused_moe_runner.cu diff --git a/Cargo.toml b/Cargo.toml index f2cc12f..882b16c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "attention-rs" -version = "0.4.1" +version = "0.4.2" edition = "2021" description = "High-performance LLM attention kernels and operations (PagedAttention, Flahinfer, Mamba, MoE, RoPE) for Candle, optimized for CUDA and Metal." 
repository = "https://github.com/guoqingbao/attention.rs" @@ -21,7 +21,7 @@ half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_di tracing = "0.1.40" parking_lot = "0.12.4" rayon="1.10.0" -kernels = { path = "./src/kernels", version="0.4.1", optional = true} +kernels = { path = "./src/kernels", version="0.4.2", optional = true} metal = { version = "0.27.0", features = ["mps"], optional = true } metal-kernels = { path = "./src/metal-kernels", version="0.1.9", optional = true} diff --git a/src/flashinfer.rs b/src/flashinfer.rs index 14d69ea..64ae82f 100644 --- a/src/flashinfer.rs +++ b/src/flashinfer.rs @@ -73,11 +73,11 @@ thread_local! { } fn is_supported_flashinfer_gqa_group_size(group_size: usize) -> bool { - matches!(group_size, 1 | 2 | 3 | 4 | 8 | 16 | 32 | 64) + matches!(group_size, 1 | 2 | 3 | 4 | 6 | 8 | 16 | 32 | 64) } fn is_supported_flashinfer_decode_group_size(group_size: usize) -> bool { - matches!(group_size, 1 | 2 | 3 | 4 | 8 | 16 | 32 | 64) + matches!(group_size, 1 | 2 | 3 | 4 | 6 | 8 | 16 | 32 | 64) } fn is_supported_flashinfer_decode_shape(group_size: usize, head_dim: usize) -> bool { @@ -811,7 +811,7 @@ impl FlashInferPrefill { let group_size = self.num_qo_heads / self.num_kv_heads; if !is_supported_flashinfer_gqa_group_size(group_size) { candle::bail!( - "flashinfer prefill only supports gqa group_size in [1,2,3,4,8,16,32,64], got {}", + "flashinfer prefill only supports gqa group_size in [1,2,3,4,6,8,16,32,64], got {}", group_size ); } diff --git a/src/kernels/Cargo.toml b/src/kernels/Cargo.toml index b2a480a..28e1775 100644 --- a/src/kernels/Cargo.toml +++ b/src/kernels/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "kernels" -version = "0.4.1" +version = "0.4.2" edition = "2021" -description = "Paged attention kernels for Rust" +description = "Attention, MoE and GEMM kernels for Rust" categories = ["science"] license = "MIT OR Apache-2.0" diff --git a/src/kernels/build.rs b/src/kernels/build.rs index e573e8d..2621c77 
100644 --- a/src/kernels/build.rs +++ b/src/kernels/build.rs @@ -27,6 +27,14 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed=src/fp8_moe_cutlass.cu"); println!("cargo:rerun-if-changed=src/flashinfer_fp8_qquant.cu"); println!("cargo:rerun-if-changed=src/flashinfer_adapter_fp8.cu"); + println!("cargo:rerun-if-changed=src/flashinfer_moe_adapter.cu"); + println!("cargo:rerun-if-changed=src/trtllm/trtllm_batched_gemm_runner.cu"); + println!("cargo:rerun-if-changed=src/trtllm/trtllm_fused_moe_runner.cu"); + println!("cargo:rerun-if-changed=src/trtllm/trtllm_fused_moe_dev_kernel.cu"); + println!("cargo:rerun-if-changed=src/trtllm/trtllm_fused_moe_routing_renormalize.cu"); + println!("cargo:rerun-if-changed=src/trtllm/trtllm_fused_moe_routing_deepseek.cu"); + println!("cargo:rerun-if-changed=src/trtllm/trtllm_fused_moe_routing_llama4.cu"); + println!("cargo:rerun-if-changed=src/trtllm/trtllm_cutlass_heuristic.cpp"); println!("cargo:rerun-if-changed=src/gdn.cu"); let marlin_disabled = std::env::var("CARGO_FEATURE_NO_MARLIN").is_ok(); @@ -68,6 +76,7 @@ fn main() -> Result<()> { builder = builder.arg("-DUSE_CUTLASS").with_cutlass(None); if std::env::var("CARGO_FEATURE_FLASHINFER").is_ok() { + builder = builder.arg("-DENABLE_BF16").arg("-DENABLE_FP8"); if compute_cap >= 89 { builder = builder.arg("-DFLASHINFER_ENABLE_FP8_E8M0"); } @@ -89,8 +98,15 @@ fn main() -> Result<()> { builder = builder.arg("-DUSE_FLASHINFER").with_git_dependency( "flashinfer", "https://github.com/guoqingbao/flashinfer.git", - "960cb902ce15ec085d42aa1bbe7026979c9a04dd", // v0.6.2 - vec!["include"], + "3bffdb76eef5fec462254dde67a7de0c4bcb9905", // v0.6.2 + vec![ + "include", + "include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export", + "include/flashinfer/trtllm/gemm/trtllmGen_gemm_export", + "csrc/nv_internal", + "csrc/nv_internal/include", + "csrc/nv_internal/tensorrt_llm/cutlass_extensions/include", + ], false, ); } diff --git a/src/kernels/src/ffi.rs b/src/kernels/src/ffi.rs index 
a64d7fa..d528675 100644 --- a/src/kernels/src/ffi.rs +++ b/src/kernels/src/ffi.rs @@ -1106,6 +1106,43 @@ extern "C" { stream: i64, ); + #[cfg(feature = "flashinfer")] + pub fn flashinfer_fused_moe_bf16( + input: *const c_void, + topk_ids: *const i32, + topk_weights: *const f32, + gate_up_weights: *const c_void, + down_weights: *const c_void, + output: *mut c_void, + num_tokens: i32, + hidden_size: i32, + intermediate_size: i32, + num_experts: i32, + top_k: i32, + input_dtype: i32, + weight_dtype: i32, + stream: i64, + ) -> i32; + + #[cfg(feature = "flashinfer")] + pub fn flashinfer_fused_moe_fp8( + input: *const c_void, + topk_ids: *const i32, + topk_weights: *const f32, + gate_up_weights: *const u8, + gate_up_scales: *const f32, + down_weights: *const u8, + down_scales: *const f32, + output: *mut c_void, + num_tokens: i32, + hidden_size: i32, + intermediate_size: i32, + num_experts: i32, + top_k: i32, + input_dtype: i32, + stream: i64, + ) -> i32; + pub fn causal_conv1d_fwd_f32( x: *const f32, weight: *const f32, diff --git a/src/kernels/src/flashinfer_adapter.cu b/src/kernels/src/flashinfer_adapter.cu index 63c815e..065fdbd 100644 --- a/src/kernels/src/flashinfer_adapter.cu +++ b/src/kernels/src/flashinfer_adapter.cu @@ -226,7 +226,7 @@ static inline void FillSM90RaggedParams( #ifdef USE_FLASHINFER static inline bool IsSupportedDecodeGroupSize(uint32_t group_size) { return group_size == 1 || group_size == 2 || group_size == 3 || group_size == 4 || - group_size == 8 || group_size == 16 || group_size == 32 || group_size == 64; + group_size == 6 || group_size == 8 || group_size == 16 || group_size == 32 || group_size == 64; } static inline bool IsSupportedDecodeHeadDimForGroupSize(uint32_t group_size, uint32_t head_dim) { @@ -457,7 +457,7 @@ void flashinfer_decode_plan_wrapper( uint32_t group_size = static_cast(num_qo_heads / num_kv_heads); if (!IsSupportedDecodeGroupSize(group_size)) { fprintf(stderr, - "[flashinfer][decode_plan] unsupported group_size=%u 
(supported: 1,2,3,4,8,16,32,64)\n", + "[flashinfer][decode_plan] unsupported group_size=%u (supported: 1,2,3,4,6,8,16,32,64)\n", group_size); return; } @@ -600,7 +600,7 @@ void flashinfer_decode_run_wrapper( uint32_t group_size = static_cast(num_qo_heads / num_kv_heads); if (!IsSupportedDecodeGroupSize(group_size)) { fprintf(stderr, - "[flashinfer][decode_run] unsupported group_size=%u (supported: 1,2,3,4,8,16,32,64)\n", + "[flashinfer][decode_run] unsupported group_size=%u (supported: 1,2,3,4,6,8,16,32,64)\n", group_size); return; } diff --git a/src/kernels/src/flashinfer_moe_adapter.cu b/src/kernels/src/flashinfer_moe_adapter.cu new file mode 100644 index 0000000..467b10e --- /dev/null +++ b/src/kernels/src/flashinfer_moe_adapter.cu @@ -0,0 +1,497 @@ +#include + +#if defined(USE_FLASHINFER) && __has_include("trtllm/gen/CudaRunner.h") && \ + __has_include("tensorrt_llm/common/logger.h") + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "flashinfer/trtllm/fused_moe/runner.h" + +namespace trtllm_moe = tensorrt_llm::kernels::trtllmgen_moe; +namespace btg = batchedGemm::trtllm::gen; + +namespace { + +__global__ void pack_topk_to_bf16_packed(const int32_t* topk_ids, const float* topk_weights, + int32_t* packed, int64_t numel, int32_t num_experts) { + int64_t i = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (i >= numel) { + return; + } + + int32_t expert_id = topk_ids[i]; + if (expert_id < 0 || expert_id >= num_experts) { + expert_id = 0; + } + + __nv_bfloat16 score_bf16 = __float2bfloat16(topk_weights[i]); + uint16_t score_bits = *reinterpret_cast(&score_bf16); + uint16_t expert_bits = static_cast(expert_id); + packed[i] = (static_cast(expert_bits) << 16) | static_cast(score_bits); +} + +int32_t choose_tile_tokens_dim(int32_t num_tokens, int32_t top_k, int32_t num_experts) { + float avg_tokens_per_expert = + static_cast(num_tokens) * static_cast(top_k) / 
std::max(num_experts, 1); + int32_t tile_tokens_dim = 1; + while (tile_tokens_dim < static_cast(avg_tokens_per_expert) && tile_tokens_dim < 128) { + tile_tokens_dim <<= 1; + } + tile_tokens_dim = std::clamp(tile_tokens_dim, 8, 128); + return tile_tokens_dim; +} + +btg::Dtype parse_dtype(int32_t dtype_code) { + switch (dtype_code) { + case 0: + return btg::Dtype::Fp16; + case 1: + return btg::Dtype::Bfloat16; + default: + throw std::runtime_error("Unsupported dtype code for fused moe"); + } +} + +size_t dtype_size(btg::Dtype dtype) { + switch (dtype) { + case btg::Dtype::Fp16: + return sizeof(__half); + case btg::Dtype::Bfloat16: + return sizeof(__nv_bfloat16); + case btg::Dtype::E4m3: + return sizeof(uint8_t); + default: + throw std::runtime_error("Unsupported dtype size query"); + } +} + +struct DeviceBuffer { + void* ptr = nullptr; + size_t bytes = 0; + + void* ensure(size_t required_bytes, cudaStream_t stream) { + if (required_bytes == 0) { + return nullptr; + } + if (ptr != nullptr && bytes >= required_bytes) { + return ptr; + } + if (ptr != nullptr) { + cudaError_t free_err = cudaFreeAsync(ptr, stream); + if (free_err != cudaSuccess) { + throw std::runtime_error(std::string("cudaFreeAsync failed: ") + + cudaGetErrorString(free_err)); + } + ptr = nullptr; + bytes = 0; + } + cudaError_t err = cudaMallocAsync(&ptr, required_bytes, stream); + if (err != cudaSuccess) { + throw std::runtime_error(std::string("cudaMallocAsync failed: ") + + cudaGetErrorString(err)); + } + bytes = required_bytes; + return ptr; + } + + template + T* ensure_typed(size_t count, cudaStream_t stream) { + return static_cast(ensure(count * sizeof(T), stream)); + } +}; + +struct StreamMoECache { + int device = -1; + cudaStream_t stream = nullptr; + + int routing_tile_tokens_dim = -1; + std::unique_ptr routing_runner; + + int bf16_tile_tokens_dim = -1; + btg::Dtype bf16_input_dtype = btg::Dtype::Fp16; + btg::Dtype bf16_weight_dtype = btg::Dtype::Fp16; + std::unique_ptr bf16_runner; + + int 
fp8_tile_tokens_dim = -1; + std::unique_ptr fp8_runner; + + DeviceBuffer packed_topk; + DeviceBuffer num_tokens_per_expert; + DeviceBuffer expert_count_histogram; + DeviceBuffer permuted_idx_size; + DeviceBuffer expanded_idx_to_permuted_idx; + DeviceBuffer permuted_idx_to_token_idx; + DeviceBuffer expert_weights; + DeviceBuffer cta_idx_xy_to_batch_idx; + DeviceBuffer cta_idx_xy_to_mn_limit; + DeviceBuffer num_non_exiting_ctas; + + DeviceBuffer bmm1_workspace; + DeviceBuffer bmm2_workspace; + DeviceBuffer gemm1_output; + DeviceBuffer gemm1_output_scale; + DeviceBuffer activation_output; + DeviceBuffer activation_output_scale; + DeviceBuffer gemm2_output; +}; + +uint64_t make_cache_key(int device, cudaStream_t stream) { + uint64_t s = static_cast(reinterpret_cast(stream)); + return (static_cast(static_cast(device)) << 32) ^ s; +} + +StreamMoECache& get_stream_cache(cudaStream_t stream) { + int device = 0; + cudaError_t err = cudaGetDevice(&device); + if (err != cudaSuccess) { + throw std::runtime_error(std::string("cudaGetDevice failed: ") + + cudaGetErrorString(err)); + } + + static thread_local std::unordered_map> caches; + uint64_t key = make_cache_key(device, stream); + auto it = caches.find(key); + if (it == caches.end()) { + auto cache = std::make_unique(); + cache->device = device; + cache->stream = stream; + it = caches.emplace(key, std::move(cache)).first; + } + return *it->second; +} + +void run_routing_from_precomputed_topk( + const int32_t* topk_ids, const float* topk_weights, int32_t num_tokens, int32_t num_experts, + int32_t top_k, int32_t tile_tokens_dim, btg::Dtype dtype_elt, + trtllm_moe::MoE::MoEWorkspace& workspace, StreamMoECache& cache, cudaStream_t stream) { + int64_t expanded_elems = static_cast(num_tokens) * top_k; + + int32_t* packed_topk = cache.packed_topk.ensure_typed(expanded_elems, stream); + + int threads = 256; + int blocks = static_cast((expanded_elems + threads - 1) / threads); + if (blocks > 0) { + 
pack_topk_to_bf16_packed<<>>(topk_ids, topk_weights, packed_topk, + expanded_elems, num_experts); + cudaError_t launch_err = cudaGetLastError(); + if (launch_err != cudaSuccess) { + throw std::runtime_error(std::string("pack_topk_to_bf16_packed launch failed: ") + + cudaGetErrorString(launch_err)); + } + } + + int32_t max_num_ctas = trtllm_moe::Routing::getMaxNumCtasInBatchDim(num_tokens, top_k, + num_experts, tile_tokens_dim); + int32_t total_max_padded_tokens = trtllm_moe::Routing::getMaxPermutedPaddedCount( + num_tokens, top_k, num_experts, tile_tokens_dim); + + workspace.total_max_padded_tokens = total_max_padded_tokens; + workspace.ProjUpTileN = tile_tokens_dim; + + int32_t* num_tokens_per_expert = + cache.num_tokens_per_expert.ensure_typed(num_experts, stream); + int32_t* expert_count_histogram = + cache.expert_count_histogram.ensure_typed(std::max(num_experts * 2, 256 * 2), + stream); + + workspace.routing_expert_indexes = packed_topk; + workspace.permuted_idx_size = cache.permuted_idx_size.ensure_typed(1, stream); + workspace.total_num_padded_tokens = workspace.permuted_idx_size; + workspace.expanded_idx_to_permuted_idx = + cache.expanded_idx_to_permuted_idx.ensure_typed(expanded_elems, stream); + workspace.permuted_idx_to_expanded_idx = nullptr; + workspace.permuted_idx_to_token_idx = + cache.permuted_idx_to_token_idx.ensure_typed(total_max_padded_tokens, stream); + workspace.expert_weights = + cache.expert_weights.ensure_typed<__nv_bfloat16>(expanded_elems, stream); + workspace.token_scales = nullptr; + workspace.cta_idx_xy_to_batch_idx = + cache.cta_idx_xy_to_batch_idx.ensure_typed(max_num_ctas, stream); + workspace.cta_idx_xy_to_mn_limit = + cache.cta_idx_xy_to_mn_limit.ensure_typed(max_num_ctas, stream); + workspace.num_non_exiting_ctas = + cache.num_non_exiting_ctas.ensure_typed(1, stream); + + if (!cache.routing_runner || cache.routing_tile_tokens_dim != tile_tokens_dim) { + cache.routing_runner = std::make_unique(tile_tokens_dim); + 
cache.routing_tile_tokens_dim = tile_tokens_dim; + } + cache.routing_runner->run( + /*routingLogits=*/nullptr, + /*routingBias=*/nullptr, + num_tokens, + num_experts, + top_k, + /*nGroups=*/0, + /*topkGroups=*/0, + /*localExpertOffset=*/0, + /*localNumExperts=*/num_experts, + /*routedScalingFactor=*/1.f, + workspace.routing_expert_indexes, + expert_count_histogram, + workspace.permuted_idx_size, + workspace.expanded_idx_to_permuted_idx, + workspace.permuted_idx_to_expanded_idx, + workspace.permuted_idx_to_token_idx, + workspace.expert_weights, + num_tokens_per_expert, + workspace.cta_idx_xy_to_batch_idx, + workspace.cta_idx_xy_to_mn_limit, + workspace.num_non_exiting_ctas, + dtype_elt, + btg::Dtype::Bfloat16, + /*useRoutingScalesOnInput=*/false, + /*useDeepSeekFp8=*/false, + trtllm_moe::Routing::RoutingMethodType::Renormalize, + stream); +} + +int run_fused_moe_bf16(const void* input, const int32_t* topk_ids, const float* topk_weights, + const void* gate_up_weights, const void* down_weights, void* output, + int32_t num_tokens, int32_t hidden_size, int32_t intermediate_size, + int32_t num_experts, int32_t top_k, int32_t input_dtype_code, + int32_t weight_dtype_code, cudaStream_t stream) { + btg::Dtype input_dtype = parse_dtype(input_dtype_code); + btg::Dtype weight_dtype = parse_dtype(weight_dtype_code); + int32_t tile_tokens_dim = choose_tile_tokens_dim(num_tokens, top_k, num_experts); + + StreamMoECache& cache = get_stream_cache(stream); + trtllm_moe::MoE::MoEWorkspace workspace{}; + + run_routing_from_precomputed_topk(topk_ids, topk_weights, num_tokens, num_experts, top_k, + tile_tokens_dim, input_dtype, workspace, cache, stream); + + if (!cache.bf16_runner || cache.bf16_tile_tokens_dim != tile_tokens_dim || + cache.bf16_input_dtype != input_dtype || cache.bf16_weight_dtype != weight_dtype) { + cache.bf16_runner = std::make_unique( + input_dtype, weight_dtype, /*useDeepSeekFp8=*/false, tile_tokens_dim, + trtllm_moe::MoE::GatedActType::SwiGlu, 
/*useShuffledMatrixA=*/false, + batchedGemm::gemm::MatrixLayout::MajorK); + cache.bf16_tile_tokens_dim = tile_tokens_dim; + cache.bf16_input_dtype = input_dtype; + cache.bf16_weight_dtype = weight_dtype; + } + + trtllm_moe::MoE::MoERunnerArgs args{}; + args.hidden_states = const_cast(input); + args.gemm1_weights = const_cast(gate_up_weights); + args.gemm2_weights = const_cast(down_weights); + args.output = output; + + args.num_tokens = num_tokens; + args.num_experts = num_experts; + args.hidden_size = hidden_size; + args.intermediate_size = intermediate_size; + args.top_k = top_k; + args.local_expert_offset = 0; + args.local_num_experts = num_experts; + args.mDtypeElt = input_dtype; + args.mDtypeExpW = btg::Dtype::Bfloat16; + args.mDtypeOut = input_dtype; + args.mUseRoutingScalesOnInput = false; + args.mUseDeepSeekFp8 = false; + args.output_scale = nullptr; + args.do_finalize = true; + + int64_t config_idx = + cache.bf16_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + num_experts, num_tokens); + + auto workspace_sizes = cache.bf16_runner->getWorkspaceSizeInBytes(args, config_idx); + workspace.bmm1_workspace = + cache.bmm1_workspace.ensure(static_cast(std::get<0>(workspace_sizes)), stream); + workspace.bmm2_workspace = + cache.bmm2_workspace.ensure(static_cast(std::get<1>(workspace_sizes)), stream); + + int32_t max_num_padded_tokens = workspace.total_max_padded_tokens; + size_t gemm1_elem_size = dtype_size(input_dtype); + size_t gemm2_elem_size = dtype_size(input_dtype); + + workspace.hidden_states_scale_linear = nullptr; + workspace.gemm1_output = cache.gemm1_output.ensure( + static_cast(max_num_padded_tokens) * intermediate_size * gemm1_elem_size, stream); + workspace.gemm1_output_scale = nullptr; + workspace.activation_output = cache.activation_output.ensure( + static_cast(max_num_padded_tokens) * intermediate_size * gemm1_elem_size, stream); + workspace.activation_output_scale = nullptr; + workspace.gemm2_output = 
cache.gemm2_output.ensure( + static_cast(max_num_padded_tokens) * hidden_size * gemm2_elem_size, stream); + workspace.gemm2_output_scale = nullptr; + cache.bf16_runner->run(args, workspace, cache.device, stream, config_idx, /*enable_pdl=*/true); + return 0; +} + +int run_fused_moe_fp8(const void* input, const int32_t* topk_ids, const float* topk_weights, + const uint8_t* gate_up_weights, const float* gate_up_scales, + const uint8_t* down_weights, const float* down_scales, void* output, + int32_t num_tokens, int32_t hidden_size, int32_t intermediate_size, + int32_t num_experts, int32_t top_k, int32_t input_dtype_code, + cudaStream_t stream) { + btg::Dtype input_dtype = parse_dtype(input_dtype_code); + int32_t tile_tokens_dim = choose_tile_tokens_dim(num_tokens, top_k, num_experts); + + if (hidden_size % 128 != 0 || intermediate_size % 128 != 0) { + throw std::runtime_error("FP8 fused moe requires hidden/intermediate dims divisible by 128"); + } + + StreamMoECache& cache = get_stream_cache(stream); + trtllm_moe::MoE::MoEWorkspace workspace{}; + + run_routing_from_precomputed_topk(topk_ids, topk_weights, num_tokens, num_experts, top_k, + tile_tokens_dim, input_dtype, workspace, cache, stream); + + if (!cache.fp8_runner || cache.fp8_tile_tokens_dim != tile_tokens_dim) { + cache.fp8_runner = std::make_unique( + btg::Dtype::E4m3, /*useDeepSeekFp8=*/true, tile_tokens_dim, + /*useShuffledMatrixA=*/false, batchedGemm::gemm::MatrixLayout::MajorK); + cache.fp8_tile_tokens_dim = tile_tokens_dim; + } + + trtllm_moe::MoE::MoERunnerArgs args{}; + args.hidden_states = const_cast(input); + args.hidden_states_scale = nullptr; + args.gemm1_weights = const_cast(gate_up_weights); + args.gemm1_weights_scale = const_cast(gate_up_scales); + args.gemm2_weights = const_cast(down_weights); + args.gemm2_weights_scale = const_cast(down_scales); + args.output = output; + + args.num_tokens = num_tokens; + args.num_experts = num_experts; + args.hidden_size = hidden_size; + args.intermediate_size 
= intermediate_size; + args.top_k = top_k; + args.local_expert_offset = 0; + args.local_num_experts = num_experts; + args.mDtypeElt = input_dtype; + args.mDtypeExpW = btg::Dtype::Bfloat16; + args.mDtypeOut = btg::Dtype::Bfloat16; + args.mUseRoutingScalesOnInput = false; + args.mUseDeepSeekFp8 = true; + args.output_scale = nullptr; + args.do_finalize = true; + + int64_t config_idx = + cache.fp8_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + num_experts, num_tokens); + + auto workspace_sizes = cache.fp8_runner->getWorkspaceSizeInBytes(args, config_idx); + workspace.bmm1_workspace = + cache.bmm1_workspace.ensure(static_cast(std::get<0>(workspace_sizes)), stream); + workspace.bmm2_workspace = + cache.bmm2_workspace.ensure(static_cast(std::get<1>(workspace_sizes)), stream); + + int32_t max_num_padded_tokens_gemm1 = workspace.total_max_padded_tokens + num_experts; + int32_t max_num_padded_tokens_gemm2 = workspace.total_max_padded_tokens; + + workspace.hidden_states_scale_linear = nullptr; + workspace.gemm1_output = cache.gemm1_output.ensure( + static_cast(max_num_padded_tokens_gemm1) * 2 * intermediate_size, stream); + workspace.gemm1_output_scale = cache.gemm1_output_scale.ensure( + static_cast(2 * intermediate_size / 128) * max_num_padded_tokens_gemm1 * + sizeof(float), + stream); + workspace.activation_output = cache.activation_output.ensure( + static_cast(max_num_padded_tokens_gemm1) * intermediate_size, stream); + workspace.activation_output_scale = cache.activation_output_scale.ensure( + static_cast(intermediate_size / 128) * max_num_padded_tokens_gemm1 * + sizeof(float), + stream); + workspace.gemm2_output = cache.gemm2_output.ensure( + static_cast(max_num_padded_tokens_gemm2) * hidden_size * sizeof(__nv_bfloat16), + stream); + workspace.gemm2_output_scale = nullptr; + cache.fp8_runner->run(args, workspace, cache.device, stream, config_idx, /*enable_pdl=*/true); + return 0; +} + +} // namespace + +extern "C" int 
flashinfer_fused_moe_bf16(const void* input, const int32_t* topk_ids, + const float* topk_weights, + const void* gate_up_weights, + const void* down_weights, void* output, + int32_t num_tokens, int32_t hidden_size, + int32_t intermediate_size, int32_t num_experts, + int32_t top_k, int32_t input_dtype_code, + int32_t weight_dtype_code, int64_t stream) { + try { + return run_fused_moe_bf16(input, topk_ids, topk_weights, gate_up_weights, down_weights, + output, num_tokens, hidden_size, intermediate_size, num_experts, + top_k, input_dtype_code, weight_dtype_code, + reinterpret_cast(stream)); + } catch (const std::exception& e) { + std::fprintf(stderr, "flashinfer_fused_moe_bf16 failed: %s\n", e.what()); + return -1; + } +} + +extern "C" int flashinfer_fused_moe_fp8( + const void* input, const int32_t* topk_ids, const float* topk_weights, + const uint8_t* gate_up_weights, const float* gate_up_scales, const uint8_t* down_weights, + const float* down_scales, void* output, int32_t num_tokens, int32_t hidden_size, + int32_t intermediate_size, int32_t num_experts, int32_t top_k, int32_t input_dtype_code, + int64_t stream) { + try { + return run_fused_moe_fp8(input, topk_ids, topk_weights, gate_up_weights, gate_up_scales, + down_weights, down_scales, output, num_tokens, hidden_size, + intermediate_size, num_experts, top_k, input_dtype_code, + reinterpret_cast(stream)); + } catch (const std::exception& e) { + std::fprintf(stderr, "flashinfer_fused_moe_fp8 failed: %s\n", e.what()); + return -1; + } +} + +#else + +extern "C" int flashinfer_fused_moe_bf16( + const void*, + const int32_t*, + const float*, + const void*, + const void*, + void*, + int32_t, + int32_t, + int32_t, + int32_t, + int32_t, + int32_t, + int32_t, + int64_t) { + return -1; +} + +extern "C" int flashinfer_fused_moe_fp8( + const void*, + const int32_t*, + const float*, + const uint8_t*, + const float*, + const uint8_t*, + const float*, + void*, + int32_t, + int32_t, + int32_t, + int32_t, + int32_t, + int32_t, + 
int64_t) { + return -1; +} + +#endif // USE_FLASHINFER && required TRT-LLM headers diff --git a/src/kernels/src/trtllm/trtllm_batched_gemm_runner.cu b/src/kernels/src/trtllm/trtllm_batched_gemm_runner.cu new file mode 100644 index 0000000..acc6341 --- /dev/null +++ b/src/kernels/src/trtllm/trtllm_batched_gemm_runner.cu @@ -0,0 +1,461 @@ +/* + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(USE_FLASHINFER) && __has_include("trtllm/gen/CudaRunner.h") && __has_include("tensorrt_llm/common/logger.h") + +#include +#include + +#include "flashinfer/trtllm/batched_gemm/KernelRunner.h" +// #include "tensorrt_llm/common/assert.h" +#include "flashinfer/exception.h" +#include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h" +#include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h" +#include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h" +#include "flashinfer/trtllm/common.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/envUtils.h" + +namespace tensorrt_llm { +namespace kernels { + +using namespace batchedGemm::batchedGemm; +using namespace batchedGemm::gemm; +using namespace batchedGemm::trtllm::gen; + +static BatchedGemmInterface::ModuleCache globalTrtllmGenBatchedGemmModuleCache; + +std::vector prioritizePredefinedConfigs( + int m, int n, int k, std::vector const& sortedIndices, + 
batchedGemm::batchedGemm::BatchedGemmConfig const* configs) { + // Function to bubble up the pre-determined config. + auto bubbleUpConfig = [&configs](std::vector const& sortedIndices, + auto&& pred) -> std::vector { + std::vector prioritizedIndices_; + // Copy matching configs to new vector + std::copy_if(sortedIndices.begin(), sortedIndices.end(), + std::back_inserter(prioritizedIndices_), [&configs, &pred](int idx) { + BatchedGemmConfig const& config = configs[idx]; + return (pred(config)); + }); + // Copy the rest of the configs to new vector, if not already copied + std::copy_if(sortedIndices.begin(), sortedIndices.end(), + std::back_inserter(prioritizedIndices_), [&prioritizedIndices_](int idx) { + return std::find(prioritizedIndices_.begin(), prioritizedIndices_.end(), idx) == + prioritizedIndices_.end(); + }); + return prioritizedIndices_; + }; + + // Init empty vector + std::vector prioritizedIndices; + + // + // Dummy + // + + if (n /* out_dim */ == 0 && k /* in_dim */ == 0) { + auto pred = [](BatchedGemmConfig const& config) { + BatchedGemmOptions const& options = config.mOptions; + return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 256 && + options.mTileScheduler == TileScheduler::Persistent; + }; + prioritizedIndices = bubbleUpConfig(sortedIndices, pred); + } + // + // Fall back + // + else { + prioritizedIndices = sortedIndices; + } + + return prioritizedIndices; +} + +TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( + TrtllmGenBatchedGemmRunnerOptions const& options_) + : mOptions(options_) { + // Select a GEMM kernel config to use + auto const bmm = BatchedGemmInterface(); + auto const configs = bmm.getBatchedGemmConfigs(); + + mPassingConfigIndices.clear(); + + for (size_t i = 0; i < bmm.getNumBatchedGemmConfigs(); ++i) { + auto const options = configs[i].mOptions; + auto const tileSize = mOptions.transposeMmaOutput ? 
options.mTileN : options.mTileM; + // When we include low-latency kernels we can set transposeMmaOutput via constructor + if (options.mDtypeA == mOptions.dtypeA && options.mDtypeB == mOptions.dtypeB && + options.mDtypeC == mOptions.dtypeC && options.mUseDeepSeekFp8 == mOptions.deepSeekFp8 && + options.mTransposeMmaOutput == mOptions.transposeMmaOutput && + (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct && + options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch && + tileSize == mOptions.tileSize && + options.mUseShuffledMatrix == mOptions.useShuffledMatrixA && + options.mLayoutA == mOptions.weightLayout) { + if (options.mFusedAct) { + if (options.mActType != static_cast(mOptions.actType)) { + continue; + } + } + + if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) { + mPassingConfigIndices.push_back(i); + } + } + } + + std::ostringstream error_msg; + error_msg << "No kernel found for the given options: " + << "mDtypeA: " << tg::dtypeToString(mOptions.dtypeA) + << ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB) + << ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC) + << ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8 + << ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput + << ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct + << ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize; + FLASHINFER_CHECK(!mPassingConfigIndices.empty(), error_msg.str()); +} + +size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes( + int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, int32_t numTokens, + int32_t numBatches, int32_t maxNumCtasInBatchDim, int32_t configIndex) const { + BatchedGemmData gemmData; + gemmData.mProblemDimensions.mNumBatches = numBatches; + gemmData.mProblemDimensions.mNumTokens = numTokens; + gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput; + 
gemmData.mProblemDimensions.mBatchedM = + mOptions.transposeMmaOutput ? std::vector{} : batchedTokens; + gemmData.mProblemDimensions.mBatchedN = + mOptions.transposeMmaOutput ? batchedTokens : std::vector{}; + gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m; + gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n; + gemmData.mProblemDimensions.mK = k; + gemmData.mProblemDimensions.mRank = 0; + gemmData.mProblemDimensions.mWorldSize = 1; + gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; + + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + + auto bmm = BatchedGemmInterface(); + + auto const configs = bmm.getBatchedGemmConfigs(); + + auto const& config = configs[configIndex]; + + return bmm.getWorkspaceSizeInBytes(config, gemmData); +} + +void TrtllmGenBatchedGemmRunner::run( + int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, int32_t numTokens, + int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b, + void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC, + float const* scaleGateC, float const* ptrBias, float const* ptrAlpha, float const* ptrBeta, + float const* ptrClampLimit, void* c, void* outSfC, int32_t const* routeMap, + int32_t const* totalNumPaddedTokens, int32_t const* ctaIdxXyToBatchIdx, + int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas, void* workspace, + CUstream stream, int device, int32_t configIndex, bool enable_pdl) { + auto bmm = BatchedGemmInterface(); + + BatchedGemmData gemmData; + + auto const configs = bmm.getBatchedGemmConfigs(); + + auto const& config = configs[configIndex]; + + FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0"); + if (!mOptions.staticBatch) { + 
FLASHINFER_CHECK(totalNumPaddedTokens, + "Batched GEMM with dynamic batching requires totalNumPaddedTokens"); + FLASHINFER_CHECK(ctaIdxXyToBatchIdx, + "Batched GEMM with dynamic batching requires ctaIdxXyToBatchIdx"); + FLASHINFER_CHECK(ctaIdxXyToMnLimit, + "Batched GEMM with dynamic batching requires ctaIdxXyToMnLimit"); + FLASHINFER_CHECK(numNonExitingCtas, + "Batched GEMM with dynamic batching requires numNonExitingCtas"); + } + + if (!mOptions.staticBatch && numTokens != 0) { + FLASHINFER_CHECK(maxNumCtasInBatchDim > 0, + "Batched GEMM with dynamic batching requires maxNumCtasInBatchDim > 0"); + } + + if (mOptions.routeAct) { + FLASHINFER_CHECK(routeMap, "Batched GEMM with routeAct requires routeMap"); + FLASHINFER_CHECK(numTokens > 0, "Batched GEMM with routeAct requires numTokens > 0"); + } + + // Dims + gemmData.mProblemDimensions.mNumBatches = numBatches; + gemmData.mProblemDimensions.mNumTokens = numTokens; + gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput; + gemmData.mProblemDimensions.mBatchedM = + mOptions.transposeMmaOutput ? std::vector{} : batchedTokens; + gemmData.mProblemDimensions.mBatchedN = + mOptions.transposeMmaOutput ? batchedTokens : std::vector{}; + gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m; + gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n; + gemmData.mProblemDimensions.mK = k; + gemmData.mProblemDimensions.mRank = 0; + gemmData.mProblemDimensions.mWorldSize = 1; + + // Inputs + gemmData.mInputBuffers.mPtrA = mOptions.transposeMmaOutput ? b : a; + gemmData.mInputBuffers.mPtrSfA = mOptions.transposeMmaOutput ? sfB : sfA; + gemmData.mInputBuffers.mPtrB = mOptions.transposeMmaOutput ? a : b; + gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB; + gemmData.mInputBuffers.mPtrScaleC = scaleC; + gemmData.mInputBuffers.mPtrScaleGate = scaleGateC; + gemmData.mInputBuffers.mPtrPerTokenSfA = + mOptions.transposeMmaOutput ? 
perTokensSfB : perTokensSfA; + gemmData.mInputBuffers.mPtrPerTokenSfB = + mOptions.transposeMmaOutput ? perTokensSfA : perTokensSfB; + gemmData.mInputBuffers.mPtrBias = ptrBias; + gemmData.mInputBuffers.mPtrGatedActAlpha = ptrAlpha; + gemmData.mInputBuffers.mPtrGatedActBeta = ptrBeta; + gemmData.mInputBuffers.mPtrClampLimit = ptrClampLimit; + + gemmData.mInputBuffers.mPtrRouteMap = routeMap; + + gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; + + // Pointer to total number of padded tokens + gemmData.mInputBuffers.mPtrTotalNumPaddedTokens = totalNumPaddedTokens; + gemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; + gemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; + gemmData.mInputBuffers.mPtrNumNonExitingCtas = numNonExitingCtas; + + // Outputs + gemmData.mOutputBuffers.mPtrC = c; + gemmData.mOutputBuffers.mPtrSfC = outSfC; + + int32_t multiProcessorCount; + cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device); + + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + + // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere + bmm.runInitBeforeWorldSync(config, gemmData, static_cast(stream)); + + auto const err = bmm.run(config, workspace, gemmData, static_cast(stream), + multiProcessorCount, enable_pdl, globalTrtllmGenBatchedGemmModuleCache); + + FLASHINFER_CHECK(err == 0, + "Error occurred when running GEMM!" 
+ " (numBatches: ", + numBatches, ", GemmMNK: ", m, " ", n, " ", k, ", Kernel: ", config.mFunctionName, + ")"); +} + +void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, + std::vector const& batchedTokens, void const* a, + void const* sfA, void const* b, void const* sfB, void* c, + void* outSfC, void* workspace, CUstream stream, int device, + int32_t configIndex, bool enable_pdl) { + // Dispatch with block scaling factors and with static batching. + run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, + a, sfA, b, sfB, + /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, + /* scaleC */ nullptr, /* scaleGateC */ nullptr, /* ptrBias */ nullptr, /* ptrAlpha */ nullptr, + /* ptrBeta */ nullptr, /* ptrClampLimit */ nullptr, c, outSfC, + /* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr, + /* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr, + /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex, enable_pdl); +} + +void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, + std::vector const& batchedTokens, void const* a, + void const* sfA, void const* b, void const* sfB, + float const* ptrBias, float const* ptrAlpha, + float const* ptrBeta, float const* ptrClampLimit, void* c, + void* outSfC, void* workspace, CUstream stream, int device, + int32_t configIndex, bool enable_pdl) { + // Dispatch with block scaling factors and with static batching. 
+ run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, + a, sfA, b, sfB, + /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, + /* scaleC */ nullptr, /* scaleGateC */ nullptr, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, c, + outSfC, + /* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr, + /* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr, + /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex, enable_pdl); +} + +void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, + std::vector const& batchedTokens, void const* a, + void const* b, float const* scaleC, float const* scaleGateC, + void* c, void* workspace, CUstream stream, int device, + int32_t configIndex, bool enable_pdl) { + // Dispatch with block scaling factors and with static batching. + run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, + a, + /* sfA */ nullptr, b, /* sfB */ nullptr, /* perTokensSfA */ nullptr, + /* perTokensSfB */ nullptr, scaleC, scaleGateC, /* ptrBias */ nullptr, /* ptrAlpha */ nullptr, + /* ptrBeta */ nullptr, /* ptrClampLimit */ nullptr, c, + /* outSfC */ nullptr, + /* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr, + /* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr, + /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex, enable_pdl); +} + +std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( + int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, int32_t numTokens, + int32_t numBatches, int32_t maxNumCtasInBatchDim) const { + auto const bmm = BatchedGemmInterface(); + auto const configs = bmm.getBatchedGemmConfigs(); + + int32_t multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); + + BatchedGemmData gemmData; + // Dims + gemmData.mProblemDimensions.mNumBatches = numBatches; + gemmData.mProblemDimensions.mNumTokens = numTokens; + 
gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput; + gemmData.mProblemDimensions.mBatchedM = + mOptions.transposeMmaOutput ? std::vector{} : batchedTokens; + gemmData.mProblemDimensions.mBatchedN = + mOptions.transposeMmaOutput ? batchedTokens : std::vector{}; + gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m; + gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n; + gemmData.mProblemDimensions.mK = k; + gemmData.mProblemDimensions.mRank = 0; + gemmData.mProblemDimensions.mWorldSize = 1; + gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; + + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + + auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) { + auto const& optionsA = configs[idx0].mOptions; + auto const& optionsB = configs[idx1].mOptions; + int32_t sizeK = gemmData.mProblemDimensions.mK; + + // Tier 0: K < tileK, prefer higher efficiency. + if (optionsA.mTileK != optionsB.mTileK) { + // Both waste computation, prefer higher efficiency. + if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK) { + double eff_a = (double)sizeK / optionsA.mTileK; + double eff_b = (double)sizeK / optionsB.mTileK; + return eff_a > eff_b; + } + // If either can be utilized, sort by tileK. + else { + return optionsA.mTileK > optionsB.mTileK; + } + } + + // Tier 1: When tileK is the same, prefer unroll loop 2x for mma. + if (optionsA.mUseUnrollLoop2xForMma != optionsB.mUseUnrollLoop2xForMma) { + return optionsA.mUseUnrollLoop2xForMma; + } + + // Tier 2+: When previous comparators are the same, prefer higher tileM. + if (optionsA.mTileM != optionsB.mTileM) { + return optionsA.mTileM > optionsB.mTileM; + } + + // Tier 2+: When previous comparators are the same, prefer higher tileN. 
+ if (optionsA.mTileN != optionsB.mTileN) { + return optionsA.mTileN > optionsB.mTileN; + } + + // Tier 2+: When previous comparators are the same, and when the number of estimated CTAs is on + // the larger side, prefer persistent tile scheduler. + if (optionsA.mTileScheduler != optionsB.mTileScheduler) { + auto options = bmm.getOptionsFromConfigAndData(configs[idx0], gemmData); + auto numCtas = bmm.getNumCtas(options, gemmData.mProblemDimensions.mMaxNumCtasInTokenDim); + if (numCtas > multiProcessorCount) { + return optionsA.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent; + } else { + return optionsB.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent; + } + } + + return false; + }; + + // Sort configs by options. + std::vector sortedIndices = mPassingConfigIndices; + std::sort(sortedIndices.begin(), sortedIndices.end(), cmpFunc); + + // Special rules for corner cases, if applicable. + std::vector prioritizedIndices = + prioritizePredefinedConfigs(m, n, k, sortedIndices, configs); + + // Filter out invalid configs. 
+ std::vector validConfigIndices; + for (auto const& configIndex : prioritizedIndices) { + auto isValidConfig = bmm.isValidConfig(configs[configIndex], gemmData); + if (isValidConfig) { + validConfigIndices.push_back(configIndex); + } + } + + FLASHINFER_CHECK(!validConfigIndices.empty(), + "No valid config found for the given problem shape"); + + return validConfigIndices; +} + +int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex( + int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, int32_t numTokens, + int32_t numBatches, int32_t maxNumCtasInBatchDim) const { + auto const validConfigIndices = + getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim); + + return validConfigIndices[0]; +} + +bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t m, int32_t n, + int32_t k, + std::vector const& batchedTokens, + int32_t numTokens, int32_t numBatches, + int32_t maxNumCtasInBatchDim) const { + auto const bmm = BatchedGemmInterface(); + auto const configs = bmm.getBatchedGemmConfigs(); + + BatchedGemmData gemmData; + // Dims + gemmData.mProblemDimensions.mNumBatches = numBatches; + gemmData.mProblemDimensions.mNumTokens = numTokens; + gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput; + gemmData.mProblemDimensions.mBatchedM = + mOptions.transposeMmaOutput ? std::vector{} : batchedTokens; + gemmData.mProblemDimensions.mBatchedN = + mOptions.transposeMmaOutput ? batchedTokens : std::vector{}; + gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m; + gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? 
m : n; + gemmData.mProblemDimensions.mK = k; + gemmData.mProblemDimensions.mRank = 0; + gemmData.mProblemDimensions.mWorldSize = 1; + gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; + + auto const& config = configs[configIndex]; + + // FIXME: temporarily disable split-k as renormalize routing plus expert number 256 failed in + // trtllm-gen ac83afb + return bmm.isValidConfig(config, gemmData) && config.mOptions.mClusterDimZ == 1; +} + +} // namespace kernels +} // namespace tensorrt_llm + +#endif // USE_FLASHINFER diff --git a/src/kernels/src/trtllm/trtllm_cutlass_heuristic.cpp b/src/kernels/src/trtllm/trtllm_cutlass_heuristic.cpp new file mode 100644 index 0000000..e799937 --- /dev/null +++ b/src/kernels/src/trtllm/trtllm_cutlass_heuristic.cpp @@ -0,0 +1,776 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#if defined(USE_FLASHINFER) && __has_include("trtllm/gen/CudaRunner.h") && __has_include("tensorrt_llm/common/logger.h") + +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" + +#include "tensorrt_llm/common/cudaBf16Wrapper.h" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_types.h" +#include "tensorrt_llm/common/assert.h" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC + +#include + +#include +#include + +using namespace tensorrt_llm::cutlass_extensions; + +namespace tensorrt_llm { +namespace kernels { +namespace cutlass_kernels { + +struct TileShape { + int m; + int n; +}; + +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { + switch (tile_config) { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + return TileShape{16, 128}; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + return TileShape{16, 256}; + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + return TileShape{64, 64}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + return TileShape{128, 64}; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + 
return TileShape{256, 128}; + case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: + return TileShape{16, 256}; + default: + TLLM_THROW("[get_grid_shape_for_config] Invalid config"); + } +} + +bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, + TileShape const tile_shape, int const split_k_factor, + size_t const workspace_bytes, bool const is_weight_only) { + // All tile sizes have a k_tile of 64. + static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; + } + + if ((k % split_k_factor) != 0) { + return false; + } + + int const k_elements_per_split = k / split_k_factor; + if ((k_elements_per_split % k_tile) != 0) { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + int const required_ws_bytes = + split_k_factor == 1 ? 
0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) { + return false; + } + + return true; +} + +std::vector get_candidate_tiles( + int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + enum class CutlassGemmType : char { Default, WeightOnly, Simt, Int8, Fp8 }; + + CutlassGemmType gemm_type = CutlassGemmType::Default; + if (config_type_param & CutlassGemmConfig::SIMT_ONLY) { + gemm_type = CutlassGemmType::Simt; + } else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) { + gemm_type = CutlassGemmType::WeightOnly; + } else if (config_type_param & CutlassGemmConfig::INT8_ONLY) { + gemm_type = CutlassGemmType::Int8; + } else if (config_type_param & CutlassGemmConfig::FP8_ONLY) { + gemm_type = CutlassGemmType::Fp8; + } + + std::vector base_configs{ + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; + if (sm >= 75) { + base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); + } + + switch (gemm_type) { + case CutlassGemmType::Simt: + return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + case CutlassGemmType::WeightOnly: + if (sm >= 75) { + return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + } else { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}; + } + case CutlassGemmType::Int8: + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + 
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + case CutlassGemmType::Fp8: + if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) { + if (sm == 89 || sm == 120 || sm == 121) { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128}; + } else { + // no valid ampere style fp8 configs for sm90 + return {}; + } + } else { + if (sm == 89 || sm >= 120) { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x64x128_WarpShape64x32x128, + CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128}; + } else { + return {}; + } + } + default: + return base_configs; + } +} + +std::vector get_candidate_tiles_sm90( + CutlassGemmConfig::CandidateConfigTypeParam const config) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM90 + return {CutlassTileConfigSM90::CtaShape128x128x128B}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + if (config & CutlassGemmConfig::WEIGHT_ONLY) { + return { + CutlassTileConfigSM90::CtaShape64x16x128B, 
CutlassTileConfigSM90::CtaShape64x32x128B, + CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B}; + } else { + return { + CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; + } + } else { + return { + CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B, + CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B}; + } +#endif +} + +bool sm90_supports_coop(CutlassTileConfigSM90 const tile) { +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{ + CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B, + CutlassTileConfigSM90::CtaShape256x256x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely +// to improve compilation speed. 
+bool sm90_supports_mcast_along_m(CutlassTileConfigSM90 const tile) { +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{ + CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely +// to improve compilation speed. +bool sm90_supports_mcast_along_n(CutlassTileConfigSM90 const tile) { +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{ + CutlassTileConfigSM90::CtaShape64x128x128B, CutlassTileConfigSM90::CtaShape64x256x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +std::vector get_candidate_configs_sm100_dynamic_cluster_shape( + int sm, CutlassGemmConfig::CandidateConfigTypeParam const config, EpilogueScheduleType schedule, + ClusterShape const dynamic_cluster_shape, ClusterShape const fallback_cluster_shape) { + auto cluster1sm = ClusterShape::ClusterShape_1x1x1; + auto cluster2sm = ClusterShape::ClusterShape_2x1x1; + bool supports_2sm = dynamic_cluster_shape == ClusterShape::Undefined || + std::get<0>(enum_to_shape_tuple(dynamic_cluster_shape)) % 2 == 0; + + std::vector candidate_configs; + if ((config & CutlassGemmConfig::FP4_ONLY) != 0) { + if (sm == 100) { + if (schedule != EpilogueScheduleType::TMA) return {}; + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, schedule, + cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + if (supports_2sm) { + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x64x128B, 
MainloopScheduleType::AUTO, schedule, + cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + } + } + + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, + schedule, cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, + schedule, cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + if (supports_2sm) { + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, schedule, + cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, schedule, + cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + } + return candidate_configs; + } + + std::vector> tile_configs{ + {CutlassTileConfigSM100::CtaShape128x128x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x256x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x32x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x64x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x32x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x128x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x256x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x64x128B, cluster1sm}, + }; + + if (supports_2sm) { + tile_configs.push_back({CutlassTileConfigSM100::CtaShape64x128x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape64x256x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape64x64x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x64x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x128x128B, cluster2sm}); + 
tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x256x128B, cluster2sm}); + } + + if (config & CutlassGemmConfig::FP8_ONLY) { + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, cluster1sm}); + // TODO: re-enable when handled by the MoE GEMM dispatch + // tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, + // ClusterShape::ClusterShape_1x1x1 }); + } + + for (auto [tile, cluster] : tile_configs) { + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, schedule, + cluster, dynamic_cluster_shape, fallback_cluster_shape, + sm}; + candidate_configs.push_back(config); + } + return candidate_configs; +} + +std::vector get_candidate_configs_sm100( + CutlassGemmConfig::CandidateConfigTypeParam const config, int sm) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM100 + return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::TMA, + ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined, + ClusterShape::Undefined, sm}}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + std::vector candidate_configs; + for (auto schedule : {EpilogueScheduleType::TMA, EpilogueScheduleType::NO_SMEM}) { + // TODO The tactic profiling is a bit long with all of these shapes enabled + // Shape 4x4x1 shapes do not seem to give better performance in the cases I tested so we + // disable it here + auto cluster_shapes = { + ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_4x1x1, + ClusterShape::ClusterShape_4x2x1 /*, ClusterShape::ClusterShape_4x4x1*/}; + for (auto cluster_shape : cluster_shapes) { + auto fallback_cluster_shape = cluster_shape == ClusterShape::ClusterShape_1x1x1 + ? 
ClusterShape::ClusterShape_1x1x1 + : ClusterShape::ClusterShape_2x1x1; + auto configs = get_candidate_configs_sm100_dynamic_cluster_shape( + sm, config, schedule, cluster_shape, fallback_cluster_shape); + candidate_configs.insert(candidate_configs.end(), configs.begin(), configs.end()); + } + + auto configs = get_candidate_configs_sm100_dynamic_cluster_shape( + sm, config, schedule, ClusterShape::Undefined, ClusterShape::Undefined); + candidate_configs.insert(candidate_configs.end(), configs.begin(), configs.end()); + } + return candidate_configs; + } else { + TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined."); + } +#endif +} + +std::vector get_candidate_configs_sm90( + CutlassGemmConfig::CandidateConfigTypeParam const config) { + auto tiles = get_candidate_tiles_sm90(config); + std::vector candidate_configs; + for (auto const& tile_config : tiles) { + bool const has_m_mcast = sm90_supports_mcast_along_m(tile_config); + bool const has_n_mcast = sm90_supports_mcast_along_n(tile_config); + bool const has_w4afp8 = + (config & CutlassGemmConfig::WEIGHT_ONLY) && (config & CutlassGemmConfig::GROUPED_GEMM); + if (has_w4afp8) { + bool const has_coop_supported = sm90_supports_coop(tile_config); + std::set mainloop_schedules{MainloopScheduleType::PINGPONG}; + if (has_coop_supported) { + mainloop_schedules.insert(MainloopScheduleType::COOPERATIVE); + } + auto const epilogue_schedule = EpilogueScheduleType::AUTO; + for (auto const& mainloop_schedule : mainloop_schedules) { + CutlassGemmConfig candidate(tile_config, mainloop_schedule, epilogue_schedule, + ClusterShape::ClusterShape_1x1x1); + candidate_configs.push_back(candidate); + candidate = CutlassGemmConfig(tile_config, mainloop_schedule, epilogue_schedule, + ClusterShape::ClusterShape_2x1x1); + candidate_configs.push_back(candidate); + candidate = CutlassGemmConfig(tile_config, mainloop_schedule, epilogue_schedule, + ClusterShape::ClusterShape_1x2x1); + candidate_configs.push_back(candidate); + 
candidate = CutlassGemmConfig(tile_config, mainloop_schedule, epilogue_schedule, + ClusterShape::ClusterShape_2x2x1); + candidate_configs.push_back(candidate); + } + } else { + CutlassGemmConfig candidate(tile_config, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); + candidate_configs.push_back(candidate); + if (has_m_mcast) { + CutlassGemmConfig candidate(tile_config, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1); + candidate_configs.push_back(candidate); + } + + if (has_n_mcast) { + CutlassGemmConfig candidate(tile_config, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1); + candidate_configs.push_back(candidate); + } + + if (has_m_mcast && has_n_mcast) { + CutlassGemmConfig candidate(tile_config, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x2x1); + candidate_configs.push_back(candidate); + } + } + } + // add cuda kernel profiler to tactics for weight-only plugins + if (config & CutlassGemmConfig::WEIGHT_ONLY) { + if (tiles.size() > 0 && !(config & CutlassGemmConfig::GROUPED_GEMM)) { + CutlassGemmConfig CudaKernelConfig(tiles[0], MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x1x1); + CudaKernelConfig.enableCudaKernel = true; + candidate_configs.push_back(CudaKernelConfig); + } + } + return candidate_configs; +} + +/*std::vector get_candidate_configs_sm100( + CutlassGemmConfig::CandidateConfigTypeParam const config) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM100 + return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x1x1}}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + std::vector candidate_configs; + if ((config & CutlassGemmConfig::FP4_ONLY) != 0) { + 
candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape256x128x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape256x256x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1}); + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + return candidate_configs; + } + + for (int cluster_m = 1; cluster_m <= 2; cluster_m++) { + bool Is2SM = cluster_m == 2; + for (int cluster_n = 1; cluster_n <= 2; cluster_n++) { + std::vector base = {// M=128 + CutlassTileConfigSM100::CtaShape128x128x128B, + CutlassTileConfigSM100::CtaShape128x256x128B}; + + if (Is2SM) { + if (cluster_n == 1) { + base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B); + base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B); + } + + std::vector twosm = {// M=256 + CutlassTileConfigSM100::CtaShape256x128x128B, + CutlassTileConfigSM100::CtaShape256x256x128B}; + std::copy(twosm.begin(), twosm.end(), 
std::back_inserter(base)); + } else { + if (cluster_n == 1) { + base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B); + if ((config & CutlassGemmConfig::FP8_ONLY) != 0) { + base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B); + } + } + + std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B, + CutlassTileConfigSM100::CtaShape64x128x128B, + CutlassTileConfigSM100::CtaShape64x256x128B, + CutlassTileConfigSM100::CtaShape128x64x128B}; + std::copy(onesm.begin(), onesm.end(), std::back_inserter(base)); + } + + constexpr std::array cluster_shapes = { + std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1}, + std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}}; + auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1]; + for (auto tile : base) { + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + cluster}; + candidate_configs.push_back(config); + } + } + } + return candidate_configs; + } else { + TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined."); + } +#endif +}*/ + +std::vector get_candidate_configs_sm110( + CutlassGemmConfig::CandidateConfigTypeParam const config) { +#ifdef FAST_BUILD + // Fast build disables all configs except this + return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::TMA, + ClusterShape::ClusterShape_1x1x1}}; +#else + std::vector candidate_configs; + for (int cluster_m = 1; cluster_m <= 2; cluster_m++) { + bool Is2SM = cluster_m == 2; + for (int cluster_n = 1; cluster_n <= 2; cluster_n++) { + std::vector base = {// M=128 + CutlassTileConfigSM100::CtaShape128x128x128B, + CutlassTileConfigSM100::CtaShape128x256x128B}; + + if (Is2SM) { + if (cluster_n == 1) { + base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B); + base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B); + } + + std::vector twosm = {// M=256 + 
CutlassTileConfigSM100::CtaShape256x128x128B, + CutlassTileConfigSM100::CtaShape256x256x128B}; + std::copy(twosm.begin(), twosm.end(), std::back_inserter(base)); + } else { + if (cluster_n == 1) { + base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B); + if ((config & CutlassGemmConfig::FP8_ONLY) != 0) { + base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B); + } + } + + std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B, + CutlassTileConfigSM100::CtaShape64x128x128B, + CutlassTileConfigSM100::CtaShape64x256x128B, + CutlassTileConfigSM100::CtaShape128x64x128B}; + std::copy(onesm.begin(), onesm.end(), std::back_inserter(base)); + } + + constexpr std::array cluster_shapes = { + std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1}, + std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}}; + auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1]; + for (auto tile : base) { + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::TMA, + cluster}; + candidate_configs.push_back(config); + } + } + } + return candidate_configs; +#endif +} + +std::vector get_candidate_configs_sm120( + CutlassGemmConfig::CandidateConfigTypeParam const config) { +#ifdef FAST_BUILD + // Fast build disables all configs except this + if (config & CutlassGemmConfig::GROUPED_GEMM) { + return {CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x1x1}}; + } else { + return {CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x256B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x1x1}}; + } +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + std::vector candidate_configs; + if ((config & CutlassGemmConfig::FP4_ONLY) != 0) { + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM120::CtaShape128x128x128B, MainloopScheduleType::AUTO, 
+ EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x64B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x256x64B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM120::CtaShape256x128x64B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + return candidate_configs; + } else { + TLLM_THROW("Not Implemented: SM120 group GEMM only supports nvfp4."); + } + } else { + std::vector candidate_configs; + if ((config & CutlassGemmConfig::FP4_ONLY) != 0) { + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM120::CtaShape128x128x256B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM120::CtaShape256x128x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + return candidate_configs; + } else { + TLLM_THROW("Not Implemented: SM120 GEMM only supports nvfp4."); + } + } +#endif + +} // namespace kernels + +std::vector get_candidate_configs( + int sm, int const max_split_k, + CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + if ((config_type_param & CutlassGemmConfig::FP4_ONLY) && + !(config_type_param & CutlassGemmConfig::BLACKWELL)) { + // FP4 is only supported on blackwell + return {}; + } + + if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER)) { + return get_candidate_configs_sm90(config_type_param); + } + if (sm == 110 && (config_type_param & CutlassGemmConfig::BLACKWELL)) { + return get_candidate_configs_sm110(config_type_param); + } + if (sm >= 100 && sm < 120 && 
(config_type_param & CutlassGemmConfig::BLACKWELL)) { + return get_candidate_configs_sm100(config_type_param, sm); + } + if (sm >= 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) { + return get_candidate_configs_sm120(config_type_param); + } + + std::vector tiles = get_candidate_tiles(sm, config_type_param); + + std::vector candidate_configs; + + bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY; + int const min_stages = (sm == 89) ? 3 : int8_configs_only ? 3 : 2; + int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); + for (auto const& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages); + candidate_configs.push_back(config); + if (sm >= 75) { + for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) { + auto config = + CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}; + candidate_configs.push_back(config); + } + } + } + } + // add cuda kernel profiler to tactics for weight-only plugins + if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) { + if (tiles.size() > 0) { + CutlassGemmConfig CudaKernelConfig(tiles[0], SplitKStyle::NO_SPLIT_K, 1, min_stages); + CudaKernelConfig.enableCudaKernel = true; + candidate_configs.push_back(CudaKernelConfig); + } + } + return candidate_configs; +} + +CutlassGemmConfig estimate_best_config_from_occupancies( + std::vector const& candidate_configs, std::vector const& occupancies, + int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts, + int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, + int const is_weight_only) { + if (occupancies.size() != candidate_configs.size()) { + TLLM_THROW( + "[estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score 
will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config_sm80); + int occupancy = occupancies[ii]; + + if (occupancy == 0) { + continue; + } + + // Keep small tile sizes when possible. + if (best_config.tile_config_sm80 != CutlassTileConfig::ChooseWithHeuristic && + m < current_m_tile && current_m_tile < tile_shape.m) { + continue; + } + + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + + for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { + if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, + is_weight_only)) { + int const ctas_per_wave = occupancy * multi_processor_count; + int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave); + float const current_score = float(num_waves_total) - num_waves_fractional; + + float const score_slack = 0.1f; + if (current_score < config_score || + ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = + split_k_factor > 1 ? 
SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig(candidate_config.tile_config_sm80, split_style, + split_k_factor, candidate_config.stages); + current_m_tile = tile_shape.m; + } else if (current_score == config_score && + (best_config.stages < candidate_config.stages || + split_k_factor < best_config.split_k_factor || current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = + split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig(candidate_config.tile_config_sm80, split_style, + split_k_factor, candidate_config.stages); + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config_sm80 == CutlassTileConfig::ChooseWithHeuristic) { + TLLM_THROW("Heuristic failed to find a valid config."); + } + + return best_config; +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm + +#endif // USE_FLASHINFER diff --git a/src/kernels/src/trtllm/trtllm_fused_moe_dev_kernel.cu b/src/kernels/src/trtllm/trtllm_fused_moe_dev_kernel.cu new file mode 100644 index 0000000..fce38f2 --- /dev/null +++ b/src/kernels/src/trtllm/trtllm_fused_moe_dev_kernel.cu @@ -0,0 +1,1011 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#if defined(USE_FLASHINFER) && __has_include("trtllm/gen/CudaRunner.h") && __has_include("tensorrt_llm/common/logger.h") + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "flashinfer/exception.h" +#include "flashinfer/trtllm/fused_moe/DevKernel.h" +#include "flashinfer/utils.cuh" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Helper function for array conversion +template +__host__ __device__ constexpr static U arrayConvert(T const& input) { + cutlass::NumericArrayConverter converter; + return converter(input); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace moe::dev { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace activation { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace tg = batchedGemm::trtllm::gen; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float silu(float x) { return x / (1.0f + expf(-x)); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void activationKernel(KernelParams params) { + using Type = typename KernelParams::Type; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // immediately trigger the secondary kernel when using PDL, then wait on primary + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + cudaGridDependencySynchronize(); + } +#endif + + for (int tokenIdx = blockIdx.z; tokenIdx < params.numTokens; tokenIdx += gridDim.z) { + // Look over experts per token + for (int k = blockIdx.y; k < params.topK; k += gridDim.y) { + int const expandedIdx = tokenIdx * params.topK + k; + int const permutedIdx = 
params.expandedIdxToPermutedIdx[expandedIdx]; + if (permutedIdx == -1) continue; + + // Loop over hidden dim + for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2; + hiddenIdx += blockDim.x * gridDim.x) { + int const baseIdx = permutedIdx * params.innerDim + hiddenIdx; + + float x1 = (float)params.inPtr[baseIdx]; + float x2 = (float)params.inPtr[baseIdx + params.innerDim / 2]; + + float act = silu(x2); + Type out = (Type)(act * x1); + + int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx; + params.outPtr[outIdx] = out; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Float4Max { + __device__ __forceinline__ float4 operator()(float4 const& a, float4 const& b) const { + float4 result; + result.x = fmaxf(a.x, b.x); + result.y = fmaxf(a.y, b.y); + result.z = fmaxf(a.z, b.z); + result.w = fmaxf(a.w, b.w); + return result; + } +}; + +struct Float2Max { + __device__ __forceinline__ float2 operator()(float2 const& a, float2 const& b) const { + float2 result; + result.x = fmaxf(a.x, b.x); + result.y = fmaxf(a.y, b.y); + return result; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ VecType packedTypeFromArray(float data[size]) { + return {}; +} + +template <> +__device__ __forceinline__ float4 packedTypeFromArray(float data[4]) { + float4 result; + result.x = data[0]; + result.y = data[1]; + result.z = data[2]; + result.w = data[3]; + return result; +} + +template <> +__device__ __forceinline__ float2 packedTypeFromArray(float data[2]) { + float2 result; + result.x = data[0]; + result.y = data[1]; + return result; +} + +template <> +__device__ __forceinline__ float packedTypeFromArray(float data[1]) { + return data[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+__device__ __forceinline__ cutlass::Array arrayFromPackedType(PackedType data) { + return cutlass::Array{}; +} + +template <> +__device__ __forceinline__ cutlass::Array arrayFromPackedType(float4 data) { + return cutlass::Array{data.x, data.y, data.z, data.w}; +} + +template <> +__device__ __forceinline__ cutlass::Array arrayFromPackedType(float2 data) { + return cutlass::Array{data.x, data.y}; +} + +template <> +__device__ __forceinline__ cutlass::Array arrayFromPackedType(float data) { + return cutlass::Array{data}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct KernelTraits; + +template <> +struct KernelTraits<4> { + using MaxOp = Float4Max; + using PackedType = float4; +}; + +template <> +struct KernelTraits<2> { + using MaxOp = Float2Max; + using PackedType = float2; +}; + +template <> +struct KernelTraits<1> { +#if CUDA_VERSION >= 12090 + using MaxOp = cuda::maximum<>; +#else + using MaxOp = cub::Max; +#endif + using PackedType = float; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA = 128; + +template +__global__ void activationDeepSeekKernel(KernelParams params) { + using Type = typename KernelParams::Type; + int32_t constexpr NumTokensPerCta = KernelParams::NumTokensPerCta; + using KernelTraits = KernelTraits; + using MaxOp = typename KernelTraits::MaxOp; + using PackedType = typename KernelTraits::PackedType; + using BlockReduce = cub::BlockReduce; + + __shared__ float s_scaleOutArr[NumTokensPerCta]; + __shared__ typename BlockReduce::TempStorage tempStorage; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // immediately trigger the secondary kernel when using PDL, then wait on primary + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + cudaGridDependencySynchronize(); + } +#endif + + // The largest (finite) value 
that can be represented using E4m3. + float constexpr E4m3MaxVal{448.f}; + + int const totalNumPaddedTokens = params.totalNumPaddedTokens[0]; + // Loop over tokens + float scale1Arr[NumTokensPerCta]; + float scale2Arr[NumTokensPerCta]; + float dataX1Arr[NumTokensPerCta]; + float dataX2Arr[NumTokensPerCta]; + float outArr[NumTokensPerCta]; + float absOutArr[NumTokensPerCta]; + int permutedIdxArr[NumTokensPerCta]; + + // Loop over tokens + for (int k = blockIdx.z; k < params.topK; k += gridDim.z) { + for (int tokenCtaIdx = blockIdx.y * NumTokensPerCta; tokenCtaIdx < params.numTokens; + tokenCtaIdx += gridDim.y * NumTokensPerCta) { + for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2; + hiddenIdx += blockDim.x * gridDim.x) { +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + scale1Arr[tokenInCtaIdx] = 0.0f; + scale2Arr[tokenInCtaIdx] = 0.0f; + dataX1Arr[tokenInCtaIdx] = 0.0f; + dataX2Arr[tokenInCtaIdx] = 0.0f; + outArr[tokenInCtaIdx] = 0.0f; + absOutArr[tokenInCtaIdx] = 0.0f; + } +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + int const tokenIdx = tokenCtaIdx + tokenInCtaIdx; + if (tokenIdx >= params.numTokens) { + break; + } + + int const expandedIdx = tokenIdx * params.topK + k; + int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx]; + permutedIdxArr[tokenInCtaIdx] = permutedIdx; + if (permutedIdx == -1) { + continue; + } + + // Process blocks for this CTA + int const baseIdx = permutedIdx * params.innerDim + hiddenIdx; + + int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128); + int const scale2Idx = permutedIdx + totalNumPaddedTokens * + ((hiddenIdx / 128) + (params.innerDim / 2 / 128)); + + scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx]; + scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx]; + dataX1Arr[tokenInCtaIdx] = static_cast(params.inPtr[baseIdx]); + 
dataX2Arr[tokenInCtaIdx] = + static_cast(params.inPtr[baseIdx + params.innerDim / 2]); + } + +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx]; + float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx]; + float act = silu(x2); + float out = act * x1; + outArr[tokenInCtaIdx] = out; + absOutArr[tokenInCtaIdx] = fabsf(out); + } + + auto absOutPacked = packedTypeFromArray(absOutArr); + auto aMaxPacked = BlockReduce(tempStorage).Reduce(absOutPacked, MaxOp{}); + auto aMaxArr = arrayFromPackedType(aMaxPacked); + +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + if (threadIdx.x == 0) { + auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx; + if (tokenIdx >= params.numTokens) { + break; + } + int const permutedIdx = permutedIdxArr[tokenInCtaIdx]; + if (permutedIdx == -1) { + continue; + } + // Make sure the scale is strictly positive to avoid division by zero in case the + // maximum is zero. 
+ float scaleOut = + fmaxf(aMaxArr[tokenInCtaIdx] / E4m3MaxVal, std::numeric_limits::min()); + s_scaleOutArr[tokenInCtaIdx] = scaleOut; + int const scaleOut_idx = + permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128); + params.outDqSfsPtr[scaleOut_idx] = scaleOut; + } + } + __syncthreads(); + +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx; + if (tokenIdx >= params.numTokens) { + break; + } + int const permutedIdx = permutedIdxArr[tokenInCtaIdx]; + if (permutedIdx == -1) { + continue; + } + float const scaleOut = s_scaleOutArr[tokenInCtaIdx]; + int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx; + params.outPtr[outIdx] = static_cast(outArr[tokenInCtaIdx] / scaleOut); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run(Data const& data, void* stream) { + if (data.mDtypeElt == tg::Dtype::E2m1) { + // Note: this should be unreachable because the options are checked beforehand. + // E2m1 requires using higher-precision intermediate data (bf16). + FLASHINFER_CHECK(false, "Activation with E2m1_t isn't supported."); + return; + } + + if (data.mUseDeepSeekFp8) { + constexpr int NUM_ELTS_PER_LOAD = 1; + constexpr int NUM_ELTS_PER_SF = 128; + + int device{-1}; + cudaGetDevice(&device); + int numSms = 0; + cudaDeviceGetAttribute(&numSms, cudaDevAttrMultiProcessorCount, device); + + // Output dimension is innerDim / 2, and each scale block is 128 elements + int const outputDim = data.innerDim / 2; + int const numScaleBlocks = (outputDim + NUM_ELTS_PER_SF - 1) / NUM_ELTS_PER_SF; + int const gridSizeX = (numScaleBlocks + NUM_ELTS_PER_LOAD - 1) / NUM_ELTS_PER_LOAD; + + auto numCtas = gridSizeX * data.numTokens * data.topK; + // FIXME: This is heruistic based on very short benchmark. 
+ int numTokensPerCta = 1; + if (numCtas > numSms * 32) { + numTokensPerCta = 4; + } else if (numCtas > numSms * 4) { + numTokensPerCta = 2; + } else { + numTokensPerCta = 1; + } + + int const gridSizeY = std::min(8192, (data.numTokens + numTokensPerCta - 1) / numTokensPerCta); + + const dim3 grid(gridSizeX, gridSizeY, data.topK); + + LAUNCH_ACTIVATION(data, activationDeepSeekKernel, numTokensPerCta, grid, + DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA, 0, stream); + } else { + int const numThreads = 256; + const dim3 grid(data.innerDim / 128, data.topK, data.numTokens); + + LAUNCH_ACTIVATION(data, activationKernel, 1, grid, numThreads, 0, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace activation + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace convertsf { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace tg = batchedGemm::trtllm::gen; + +namespace dev { +// Compute the offset that corresponds to (dataRowIdx, dataBlkColIdx) in the SF tensor where +// dataRowIdx and dataBlkColIdx are the respective indices of the row and the block of 16 elts +// from the K dim in the tensor of data. +inline __device__ int64_t getSfOffset(int32_t dataRowIdx, int32_t dataBlkColIdx, + int32_t numDataBlksPerRow) { + // The number of rows of SF per block. + static int32_t constexpr NumRowsPerSfBlock = 128; + // The number of cols of SF per block. + static int32_t constexpr NumColsPerSfBlock = 4; + // The size of each SF block. + static int32_t constexpr NumBytesPerSfBlock = NumRowsPerSfBlock * NumColsPerSfBlock; + + // The number of rows of data per SF block. + static int32_t constexpr NumDataRowsPerSfBlock = NumRowsPerSfBlock; + // The number of cols of blocks of data per SF block. 
+ static int32_t constexpr NumDataBlkColsPerSfBlock = NumColsPerSfBlock; + + // The row of the SF block in the SF tensor. + int sfBlkRowIdx = dataRowIdx / NumDataRowsPerSfBlock; + // The col of the SF block in the SF tensor. + int sfBlkColIdx = dataBlkColIdx / NumDataBlkColsPerSfBlock; + // The blocks are stored row-major in the tensor of scaling factors. + int sfBlkIdx = sfBlkRowIdx * numDataBlksPerRow / NumDataBlkColsPerSfBlock + sfBlkColIdx; + + // Find the row in the SF block. + int sfRowIdx = (dataRowIdx % 32) * 4 + (dataRowIdx % NumDataRowsPerSfBlock) / 32; + // Find the col in the SF block. + int sfColIdx = (dataBlkColIdx % 4); + + // Compute the offset in bytes. + return sfBlkIdx * NumBytesPerSfBlock + sfRowIdx * NumColsPerSfBlock + sfColIdx; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Given the GMEM address of an output element, compute the offset of the corresponding scaling +// factor in the SF tensor. Optionally, a startTokenIndex can be provided if the first token is not +// the start token in the SF tensor. This is useful when inflight batching is enabled in TRT-LLM, +// where the context and generation output are stored as one output tensor. In this case, the +// generation output may not start with zero offset in the SF output tensor. +template +inline __device__ int64_t getSfOffset(int64_t gmemOffsetInBytes, int32_t hiddenDim, + int32_t startTokenIdx = 0) { + // The number of elements per sf. + int32_t constexpr NumEltsPerSf = 16; + // The GMEM offset of the output element. + int64_t gmemOffset = gmemOffsetInBytes * 8 /*bits*/ / NumBitsPerElt; + // The row/col indices of the corresponding SF element. + int32_t sfRowIdx = gmemOffset / hiddenDim + startTokenIdx; + int32_t sfColIdx = (gmemOffset % hiddenDim) / NumEltsPerSf; + // Compute the SF offset. 
+ return getSfOffset(sfRowIdx, sfColIdx, hiddenDim / NumEltsPerSf); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// TODO(tizheng): Refactor to track gmem offset instead of doing pointer subtraction. +template +inline __device__ int64_t getSfOffset(void const* gmemOutPtr, void const* gmemBasePtr, + int32_t hiddenDim, int32_t startTokenIdx = 0) { + return getSfOffset( + reinterpret_cast(gmemOutPtr) - reinterpret_cast(gmemBasePtr), + hiddenDim, startTokenIdx); +} + +} // namespace dev + +// TODO: it would be nice to move some of that logic to Fp4Utils.h +template +inline __device__ int32_t getSfOffset(int32_t dataRowIdx, int32_t dataBlkColIdx, + int32_t numDataBlksPerRow) { + if constexpr (Layout == tg::SfLayout::Linear) { + return numDataBlksPerRow * dataRowIdx + dataBlkColIdx; + } else if constexpr (Layout == tg::SfLayout::R128c4) { + return static_cast(dev::getSfOffset(dataRowIdx, dataBlkColIdx, numDataBlksPerRow)); + } else if constexpr (Layout == tg::SfLayout::R8c4 || Layout == tg::SfLayout::R8c16) { + static int32_t constexpr NumRowsPerSfBlock = 8; + static int32_t constexpr NumColsPerSfBlock = (Layout == tg::SfLayout::R8c4) ? 4 : 16; + static int32_t constexpr NumBytesPerSfBlock = NumRowsPerSfBlock * NumColsPerSfBlock; + int sfBlkRowIdx = dataRowIdx / NumRowsPerSfBlock; + int sfBlkColIdx = dataBlkColIdx / NumColsPerSfBlock; + int sfBlkIdx = sfBlkRowIdx * numDataBlksPerRow / NumColsPerSfBlock + sfBlkColIdx; + int sfRowIdx = dataRowIdx % NumRowsPerSfBlock; + int sfColIdx = dataBlkColIdx % NumColsPerSfBlock; + return sfBlkIdx * NumBytesPerSfBlock + sfRowIdx * NumColsPerSfBlock + sfColIdx; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ void convertSfCommon(KernelParams params) { + // Note: it's assumed that the number of scaling factors per row is a multiple of 4. 
+ constexpr int VecSize = 4; + using VecType = uint32_t; + static_assert(sizeof(VecType) == VecSize); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // Immediately trigger the secondary kernel when using PDL, then wait on primary. + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + cudaGridDependencySynchronize(); + } +#endif + + // TODO: consider optimizing if used in production. + // This is a naive kernel. It's not doing coalesced loads. + + int const numSfPerRow = params.hiddenDimSf; + + for (int tokenIdx = blockIdx.y; tokenIdx < params.numTokens; tokenIdx += gridDim.y) { + for (int hiddenSfVecIdx = threadIdx.x + blockDim.x * blockIdx.x; + hiddenSfVecIdx < numSfPerRow / VecSize; hiddenSfVecIdx += blockDim.x * gridDim.x) { + // Index of the first SF in the vector. + int const hiddenSfIdx = VecSize * hiddenSfVecIdx; + + // Load scale factors. + int sfIdxIn = getSfOffset(tokenIdx, hiddenSfIdx, numSfPerRow); + const VecType sfVec = reinterpret_cast(params.inSfPtr)[sfIdxIn / VecSize]; + + // Store scale factors. + int const sfIdxOut = getSfOffset(tokenIdx, hiddenSfIdx, numSfPerRow); + reinterpret_cast(params.outSfPtr)[sfIdxOut / VecSize] = sfVec; + } + } +} + +#define CONVERT_FP4_SF_KERNEL(LayoutSrc, LayoutDst) \ + template \ + __global__ void convertSf##LayoutSrc##To##LayoutDst##Kernel(KernelParams params) { \ + convertSfCommon(params); \ + } +// We only need a conversion to the linear layout. 
CONVERT_FP4_SF_KERNEL(R128c4, Linear);
CONVERT_FP4_SF_KERNEL(R8c4, Linear);
CONVERT_FP4_SF_KERNEL(R8c16, Linear);
#undef CONVERT_FP4_SF_KERNEL

////////////////////////////////////////////////////////////////////////////////////////////////////

// Launches the scaling-factor layout conversion selected by (data.sfLayoutSrc, data.sfLayoutDst).
//
// Only the three "-> Linear" conversions instantiated above are dispatched. For any other
// (src, dst) pair no kernel is launched and the function simply returns.
//
// data:   conversion parameters (layouts, hiddenDimSf, numTokens, SF pointers).
// stream: CUDA stream handle forwarded to the launch macro.
void run(Data const& data, void* stream) {
  // Each thread moves VecSize scaling factors at once (see convertSfCommon).
  constexpr int VecSize = 4;
  // CUDA limits gridDim.y to 65535. The conversion kernel strides over tokens by gridDim.y, so a
  // clamped grid still visits every token instead of failing the launch for large batches.
  constexpr int MaxGridDimY = 65535;
  int const numThreads = 128;
  // ceil((hiddenDimSf / VecSize) / numThreads) blocks along x.
  int const numBlocksX = (data.hiddenDimSf / VecSize - 1 + numThreads) / numThreads;
  int const numTokens = data.numTokens;
  int const numBlocksY = numTokens < MaxGridDimY ? numTokens : MaxGridDimY;
  dim3 numBlocks(numBlocksX, numBlocksY);
#define CONVERT_FP4_SF_LAUNCH(LayoutSrc, LayoutDst)                                             \
  if (data.sfLayoutSrc == tg::SfLayout::LayoutSrc &&                                            \
      data.sfLayoutDst == tg::SfLayout::LayoutDst) {                                            \
    LAUNCH_PDL(data, false, cutlass::float_e4m3_t, convertSf##LayoutSrc##To##LayoutDst##Kernel, \
               numBlocks, numThreads, 0, stream);                                               \
    return;                                                                                     \
  }
  CONVERT_FP4_SF_LAUNCH(R128c4, Linear);
  CONVERT_FP4_SF_LAUNCH(R8c4, Linear);
  CONVERT_FP4_SF_LAUNCH(R8c16, Linear);
#undef CONVERT_FP4_SF_LAUNCH
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace convertsf

namespace permute {

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace tg = batchedGemm::trtllm::gen;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Scatters every token row into its topK permuted rows.
//
// For token t and slot k, the destination row is expandedIdxToPermutedIdx[t * topK + k]. A value
// of -1 marks a slot without a destination row and is skipped, matching the convention used by
// the activation and finalize kernels in this file.
//
// When params.useDeepSeekFp8 is set, the per-128-element dequantization scale factors are
// scattered as well. Input scales are indexed as [scaleIdx][tokenIdx] with row length numTokens,
// output scales as [scaleIdx][permutedIdx] with row length totalNumPaddedTokens[0].
template <typename KernelParams>
__global__ void permuteKernel(KernelParams params) {
  using Type = typename KernelParams::Type;

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // immediately trigger the secondary kernel when using PDL, then wait on primary
  if constexpr (KernelParams::UsePdl) {
    cudaTriggerProgrammaticLaunchCompletion();
    cudaGridDependencySynchronize();
  }
#endif

  for (int tokenIdx = blockIdx.y; tokenIdx < params.numTokens; tokenIdx += gridDim.y) {
    // Loop over hidden dim
    for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.hiddenDim;
         hiddenIdx += blockDim.x * gridDim.x) {
      // Load chunk of token into registers
      const Type data = params.inPtr[tokenIdx * params.hiddenDim + hiddenIdx];

      // Write to topK places
      for (int k = 0; k < params.topK; k++) {
        int const expandedIdx = tokenIdx * params.topK + k;
        int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
        // A negative row index would write before the start of outPtr.
        if (permutedIdx == -1) {
          continue;
        }
        params.outPtr[permutedIdx * params.hiddenDim + hiddenIdx] = data;
      }
    }
    if (params.useDeepSeekFp8) {
      // One scale factor per block of 128 hidden elements.
      for (int scaleIdx = threadIdx.x + blockDim.x * blockIdx.x; scaleIdx < params.hiddenDim / 128;
           scaleIdx += blockDim.x * gridDim.x) {
        for (int k = 0; k < params.topK; k++) {
          int const expandedIdx = tokenIdx * params.topK + k;
          int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
          // Same guard as above: -1 has no destination row in the scale tensor either.
          if (permutedIdx == -1) {
            continue;
          }

          int const idx_in = tokenIdx + params.numTokens * scaleIdx;
          int const idx_out = permutedIdx + params.totalNumPaddedTokens[0] * scaleIdx;

          params.outDqSfsPtr[idx_out] = params.inDqSfsPtr[idx_in];
        }
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Launches permuteKernel with 256 threads per block, ceil(hiddenDim / 256) blocks along x and
// one block per token along y (clamped to the hardware limit).
//
// data:   permutation parameters (hiddenDim, numTokens, pointers).
// stream: CUDA stream handle forwarded to the launch macro.
void run(Data const& data, void* stream) {
  // CUDA limits gridDim.y to 65535. permuteKernel strides over tokens by gridDim.y, so a clamped
  // grid still visits every token instead of failing the launch for large batches.
  constexpr int MaxGridDimY = 65535;
  int const numThreads = 256;
  int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads;
  int const numTokens = data.numTokens;
  int const numBlocksY = numTokens < MaxGridDimY ? numTokens : MaxGridDimY;
  dim3 numBlocks(numBlocksX, numBlocksY);

  LAUNCH(data, permuteKernel, numBlocks, numThreads, 0, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace permute

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace finalize {

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace tg = batchedGemm::trtllm::gen;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Gathers the topK permuted rows of every token back into a single output row.
//
// For token t, the rows expandedIdxToPermutedIdx[t * topK + k] (k in [0, topK)) of inPtr are
// summed in float, each multiplied by expertWeightsPtr[t * topK + k] when expertWeightsPtr is
// non-null. Slots whose permuted index is -1 contribute nothing. Input rows have stride
// hiddenDimPadded; output rows have stride hiddenDim.
template <typename KernelParams>
__global__ void finalizeKernel(KernelParams params) {
  using Type = typename KernelParams::Type;
  using TypeExpW = typename KernelParams::TypeExpW;

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // wait on primary kernel when using PDL
  if constexpr (KernelParams::UsePdl) {
    cudaGridDependencySynchronize();
  }
#endif

  for (int tokenIdx = blockIdx.y; tokenIdx < params.numTokens; tokenIdx += gridDim.y) {
    // Loop over hidden dim
    for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.hiddenDim;
         hiddenIdx += blockDim.x * gridDim.x) {
      // Accumulate chunk of token into registers
      float data = 0.0F;

      // Accumulate the contributions of the topK experts
      for (int k = 0; k < params.topK; k++) {
        int const expandedIdx = tokenIdx * params.topK + k;
        int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];

        if (permutedIdx == -1) {
          continue;
        }

        if (params.expertWeightsPtr != nullptr) {
          TypeExpW const scale = params.expertWeightsPtr[expandedIdx];
          data +=
              float{scale} * float{params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]};
        } else {
          data += float{params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]};
        }
      }

      params.outPtr[tokenIdx * params.hiddenDim + hiddenIdx] = static_cast<Type>(data);
    }
  }
}

constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;

// Loads four consecutive floats from global memory with one ld.global.v4.f32 instruction.
// NOTE(review): a v4.f32 access presumably requires ptr to be 16-byte aligned - confirm callers.
__device__ float4 vectorizedLoadPtx(float4 const* ptr) {
  float4 ret;
  asm volatile("ld.global.v4.f32 {%0, %1, %2, %3}, [%4];"
               : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w)
               : "l"(ptr));
  return ret;
}

// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip
// connection.
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int MaxTopK = 64; + +typedef struct __CUDA_ALIGN__(4) { + cutlass::bfloat16_t array[2]; +} bfloat16_2; + +typedef struct __CUDA_ALIGN__(8) { + cutlass::bfloat16_t array[4]; +} bfloat16_4; + +typedef struct __CUDA_ALIGN__(8) { + half array[4]; +} half_4; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ScaleTraitsStruct; + +template <> +struct ScaleTraitsStruct<1, cutlass::bfloat16_t> { + using PackedType = cutlass::bfloat16_t; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<2, cutlass::bfloat16_t> { + using PackedType = bfloat16_2; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<4, cutlass::bfloat16_t> { + using PackedType = bfloat16_4; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<1, float> { + using PackedType = float; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<2, float> { + using PackedType = float2; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<4, float> { + using PackedType = float4; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<1, half> { + using PackedType = half; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<2, half> { + using PackedType = half2; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<4, half> { + using PackedType = half_4; + using ArrayType = cutlass::Array; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FinalizeTraits; + +template +struct FinalizeTraits<1, TypeExpW_> { + using IdxPackedType = int; + using IdxArrayType = cutlass::Array; + using ScaleTraits = ScaleTraitsStruct<1, TypeExpW_>; + using 
ScalePackedType = typename ScaleTraits::PackedType; + using ScaleArrayType = typename ScaleTraits::ArrayType; +}; + +template +struct FinalizeTraits<2, TypeExpW_> { + using IdxPackedType = int2; + using IdxArrayType = cutlass::Array; + using ScaleTraits = ScaleTraitsStruct<2, TypeExpW_>; + using ScalePackedType = typename ScaleTraits::PackedType; + using ScaleArrayType = typename ScaleTraits::ArrayType; +}; + +template +struct FinalizeTraits<4, TypeExpW_> { + using IdxPackedType = int4; + using IdxArrayType = cutlass::Array; + using ScaleTraits = ScaleTraitsStruct<4, TypeExpW_>; + using ScalePackedType = typename ScaleTraits::PackedType; + using ScaleArrayType = typename ScaleTraits::ArrayType; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void finalizeKernelVecLoad(KernelParams params) { + using Type = typename KernelParams::Type; + using TypeExpW = typename KernelParams::TypeExpW; + int constexpr TopKUnrollFactor = KernelParams::TopKUnrollFactor; + + static_assert(TopKUnrollFactor == 1 || TopKUnrollFactor == 2 || TopKUnrollFactor == 4, + "TopKUnrollFactor must be 1, 2, or 4"); + using FinalizeTraits = FinalizeTraits; + using IdxPackedType = typename FinalizeTraits::IdxPackedType; + using IdxArrayType = typename FinalizeTraits::IdxArrayType; + using ScalePackedType = typename FinalizeTraits::ScalePackedType; + using ScaleArrayType = typename FinalizeTraits::ScaleArrayType; + + int const hiddenDimPaddedBits = params.hiddenDimPadded * cutlass::sizeof_bits::value; + int const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits::value; + assert(hiddenDimPaddedBits % 128 == 0); + assert(hiddenDimBits % 128 == 0); + + // Load 128-bits per thread, according to the smallest data type we read/write + constexpr int64_t FINALIZE_ELEM_PER_THREAD = 128 / cutlass::sizeof_bits::value; + using InputElem = cutlass::Array; + using OutputElem = cutlass::Array; + using ComputeElem = 
cutlass::Array; + + int64_t const tokenIdx = blockIdx.x; + int64_t const startOffset = threadIdx.x; + int64_t const stride = FINALIZE_THREADS_PER_BLOCK; + int64_t const numElemsInPaddedCol = params.hiddenDimPadded / FINALIZE_ELEM_PER_THREAD; + int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD; + bool const useScale = params.expertWeightsPtr != nullptr; + + __shared__ ScalePackedType scaleArrSmem[MaxTopK / TopKUnrollFactor]; + __shared__ IdxPackedType permutedIdxArrSmem[MaxTopK / TopKUnrollFactor]; + + for (int kChunkIdx = threadIdx.x; kChunkIdx < params.topK / TopKUnrollFactor; + kChunkIdx += blockDim.x) { + int const expandedIdx = tokenIdx * params.topK + kChunkIdx * TopKUnrollFactor; + auto permutedIdxPacked = reinterpret_cast( + params.expandedIdxToPermutedIdx)[expandedIdx / TopKUnrollFactor]; + auto scalePacked = useScale ? reinterpret_cast( + params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor] + : ScalePackedType{TypeExpW(1.f)}; + + scaleArrSmem[kChunkIdx] = scalePacked; + permutedIdxArrSmem[kChunkIdx] = permutedIdxPacked; + } + + auto const offset = tokenIdx * params.hiddenDim; + Type* outputPtr = params.outPtr + offset; + auto* outElemPtr = reinterpret_cast(outputPtr); + auto const* inElemPtr = reinterpret_cast(params.inPtr); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // wait on primary kernel when using PDL + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } +#endif + __syncthreads(); + + for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride) { + ComputeElem threadOutput; + threadOutput.fill(0); + for (int kChunkIdx = 0; kChunkIdx < params.topK / TopKUnrollFactor; kChunkIdx++) { + auto permutedIdxArr = *reinterpret_cast(&permutedIdxArrSmem[kChunkIdx]); + InputElem inputElemArr[TopKUnrollFactor]; +#pragma unroll + for (int ki = 0; ki < TopKUnrollFactor; ++ki) { + auto const permutedIdx = permutedIdxArr[ki]; + if (permutedIdx == -1) { + continue; + } + + auto 
const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol; + + float4 input = + vectorizedLoadPtx(reinterpret_cast(&inputPermutedPtr[elemIndex])); + inputElemArr[ki] = *reinterpret_cast(&input); + } + auto scaleArr = *reinterpret_cast(&scaleArrSmem[kChunkIdx]); + auto const scaleFloatArr = + arrayConvert>(scaleArr); + +#pragma unroll + for (int ki = 0; ki < TopKUnrollFactor; ++ki) { + auto const permutedIdx = permutedIdxArr[ki]; + if (permutedIdx == -1) { + continue; + } + auto scale = useScale ? scaleFloatArr[ki] : 1.0f; + ComputeElem expertResult = arrayConvert(inputElemArr[ki]); + threadOutput = threadOutput + scale * expertResult; + } + } + OutputElem outputElem = arrayConvert(threadOutput); + outElemPtr[elemIndex] = outputElem; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void finalizeDeepSeekKernel(KernelParams params) { + using Type = typename KernelParams::Type; + using BlockReduce = cub::BlockReduce; + + __shared__ float s_scaleOut; + __shared__ typename BlockReduce::TempStorage temp_storage; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // wait on primary kernel when using PDL + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } +#endif + + for (int tokenIdx = blockIdx.y; tokenIdx < params.numTokens; tokenIdx += gridDim.y) { + // Loop over hidden dim + for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.hiddenDim; + hiddenIdx += blockDim.x * gridDim.x) { + // Accumulate chunk of token into registers + float acc = 0.0f; + + for (int k = 0; k < params.topK; k++) { + int const expandedIdx = tokenIdx * params.topK + k; + int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx]; + if (permutedIdx == -1) { + continue; + } + int const totalNumPaddedTokens = params.totalNumPaddedTokens[0]; + int const scaleIdx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128); + float const 
blockScale = params.inDqSfsPtr ? params.inDqSfsPtr[scaleIdx] : 1; + + float const expertProb = (float)params.expertWeightsPtr[tokenIdx * params.topK + k]; + + float const scale = expertProb * blockScale; + acc += scale * + static_cast(params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]); + } + + // The largest (finite) value that can be represented using E4m3. + float constexpr E4m3MaxVal{448.f}; + + // Compute the absolute max +#if CUDA_VERSION >= 12090 + float aMax = BlockReduce(temp_storage).Reduce(fabsf(acc), cuda::maximum<>{}); +#else + float aMax = BlockReduce(temp_storage).Reduce(fabsf(acc), cub::Max{}); +#endif + + if (threadIdx.x == 0) { + if (params.outDqSfsPtr) { + s_scaleOut = aMax / E4m3MaxVal; + int const scaleOut_idx = tokenIdx + hiddenIdx / 128 * params.numTokens; + params.outDqSfsPtr[scaleOut_idx] = aMax / E4m3MaxVal; + } else { + s_scaleOut = 1.0f; + } + } + __syncthreads(); + float const scaleOut = s_scaleOut; + __syncthreads(); + params.outPtr[tokenIdx * params.hiddenDim + hiddenIdx] = (Type)(acc / scaleOut); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +void run(Data const& data, void* stream) { + if (data.mUseDeepSeekFp8) { + int const numThreads = 128; + int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads; + // Capped at rather arbitrary 8192 to avoid gridDim exceeding 65535 specified by CUDA. + int const numBlocksY = std::min(8192, data.numTokens); + dim3 numBlocks(numBlocksX, numBlocksY); + + LAUNCH_TOPK_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream); + } else { + int const numThreads = 256; + int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads; + // Capped at rather arbitrary 8192 to avoid gridDim exceeding 65535 specified by CUDA. 
+ int const numBlocksY = std::min(8192, data.numTokens); + + if (numBlocksX * numBlocksY < 1184) { + // The number 1184 comes from 148 * 8, where 148 is the number of SMs (Streaming + // Multiprocessors) in the Blackwell architecture, and the value 8 means that each Streaming + // Multiprocessor (SM) can hold up to 8 blocks for this kernel. This limitation is intended to + // ensure that when the number of waves is greater than 1, we choose to use the kernel with + // vectorized loading. + dim3 numBlocks(numBlocksX, numBlocksY); + LAUNCH_TOPK_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream); + } else { + FLASHINFER_CHECK( + data.topK <= MaxTopK, + "Finalize kernel with vectorized loading is not supported for this TopK value: %d", + data.topK); + LAUNCH_TOPK_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens, + /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace finalize + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace moe::dev + +#endif // USE_FLASHINFER diff --git a/src/kernels/src/trtllm/trtllm_fused_moe_routing_deepseek.cu b/src/kernels/src/trtllm/trtllm_fused_moe_routing_deepseek.cu new file mode 100644 index 0000000..a1eca70 --- /dev/null +++ b/src/kernels/src/trtllm/trtllm_fused_moe_routing_deepseek.cu @@ -0,0 +1,665 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(USE_FLASHINFER) && __has_include("trtllm/gen/CudaRunner.h") && __has_include("tensorrt_llm/common/logger.h") + +#include + +#include "flashinfer/exception.h" +#include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" + +namespace moe::dev::routing { + +namespace routingDeepSeek { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; +static constexpr int NumTopGroupScores = 2; +static constexpr int MaxNumTopExperts = 8; +static constexpr int MaxNumTopGroups = 4; +static constexpr int MaxNumGroups = 8; + +template +__global__ void routingMainKernel(KernelParams params) { + // declare types + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + + // declare shared memory structure + // number of experts is bounded by number of threads + __shared__ float __attribute((aligned(128))) smemScoreSigmoid[KernelParams::MaxNumExperts]; + __shared__ float __attribute((aligned(128))) smemScoreBias[KernelParams::MaxNumExperts]; + // number of expert groups is bounded by number of warps + __shared__ float __attribute((aligned(128))) smemGroupScores[MaxNumGroups]; + + // needed for warp reduce + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + // for the final reduction of weight norm, only some lanes need to participate + int32_t laneIdx = threadIdx.x % WarpSize; + int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / 
WarpSize, 0); + // warps outside the range of expert groups do not participate + if constexpr (KernelParams::UseGroups) { + if (warpIdx >= params.mNumExpertGroups) { + return; + } + } + + // note that for invalid scores, we use negative infinity, + // needed for GLM-style routing where bias can be negative + static constexpr float invalidScoreFloat = -float(INFINITY); + const OutputT invalidScore = OutputT{invalidScoreFloat}; + + // load bias already; each warp represents one expert group + auto threadExpert = threadIdx.x; + bool expertSelected = threadExpert < params.mNumExperts; + if constexpr (KernelParams::UseGroups) { + threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx; + expertSelected = laneIdx < params.mNumExpertsPerGroup; + } + auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert; + auto biasVal = + expertSelected ? static_cast(params.mPtrRoutingBias[threadExpert]) : invalidScoreFloat; + // initialize the mPtrExpertCounts + if (params.mPtrExpertCounts) { + int32_t globalThreadIdx = blockIdx.x * blockDim.x + threadIdx.x; + int32_t globalThreadStride = gridDim.x * blockDim.x; + int32_t expertCountsNum = 2 * params.mNumExperts; + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + } + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // trigger the secondary kernel when using PDL, then wait on primary + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + cudaGridDependencySynchronize(); + } +#endif + + if (params.mPtrScores != nullptr) { + // get our assigned thread score; each warp represents one expert group + float score = + expertSelected ? 
static_cast(params.mPtrScores[scoreIdx]) : invalidScoreFloat; + // get the sigmoid score + // note that for invalid values, we simply use a negative value: + // sigmoig scores are always strictly positive + auto scoreSigmoid = sigmoid_accurate(score); + // write the sigmoid score to shared for later use + if (expertSelected) { + smemScoreSigmoid[threadExpert] = scoreSigmoid; + } + // get the score with bias + // note: with invalid values, invalidScoreFloat ensures values are always smaller than valid + // ones + auto scoreBias = float{scoreSigmoid + float{biasVal}}; + + if (expertSelected) { + smemScoreBias[threadExpert] = scoreBias; + } + + // registers for top group score reduction + float topExpGroupScores[NumTopGroupScores]; + [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; + float topGroups[MaxNumTopGroups]; // bound of params.mNumLimitedGroups + int32_t topGroupIdx[MaxNumTopGroups]; + float expertScoreGroup[MaxNumTopGroups]; + int32_t expertIdxGroup[MaxNumTopGroups]; + float topScores[MaxNumTopExperts]; // bound of params.mTopK + int32_t topExperts[MaxNumTopExperts]; + + if constexpr (KernelParams::UseGroups) { + topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, + /* minValue */ invalidScoreFloat); + // get the final group score and write it to shared + if (cute::elect_one_sync()) { + auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; + smemGroupScores[warpIdx] = groupScore; + } + } + + // make group scores available to all warps + __syncthreads(); + + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + if constexpr (KernelParams::UseGroups) { // a single warp performs the selection of top groups, + // and goes on to select the final experts + if (warpIdx == 0) { + float groupScore = + laneIdx < params.mNumExpertGroups ? 
smemGroupScores[laneIdx] : invalidScoreFloat; + topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, + /* minValue */ invalidScoreFloat); + // final expert selection: get relevant indexes and scores from shared + +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of params.mNumLimitedGroups + auto groupIdx = topGroupIdx[ii]; + expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx; + // note: expertSelected implies laneIdx < params.mNumExpertsPerGroup. + // we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups, + // thus groupIdx <= params.mNumExpertGroups - 1 => + // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - + // params.mNumExpertsPerGroup + // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, + // so the access is safe here + expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected + ? smemScoreBias[expertIdxGroup[ii]] + : invalidScoreFloat; + } + + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } else if constexpr (KernelParams::MaxNumExperts > topk::MaxNumExpertsUnit) { + // without groups, each thread just takes `MaxNumTopGroups` experts + int constexpr NumExpertWarps = + (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; + int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; + __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; + __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; + if (warpIdx < NumExpertWarps) { + int offset = warpIdx * WarpSize * MaxNumTopGroups; +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WarpSize + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = offset + expertIdx < params.mNumExperts + ? 
smemScoreBias[offset + expertIdx] + : invalidScoreFloat; + } + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + + if (laneIdx < params.mTopK) { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + } + } + __syncthreads(); + if (warpIdx == 0) { + int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1; + float intermidiateScore[NumInterTopKPerThread]; + int32_t intermidiateExpert[NumInterTopKPerThread]; + for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) { + int ii = i / WarpSize; + if (i < NumInterTopK) { + intermidiateScore[ii] = smemInterTopScores[i]; + intermidiateExpert[ii] = smemInterTopExperts[i]; + } else { + intermidiateScore[ii] = invalidScoreFloat; + intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1; + } + } + topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } else { + if (warpIdx == 0) { + // without groups, each thread just takes `MaxNumTopGroups` experts +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WarpSize + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] = + expertIdx < params.mNumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; + } + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } + + if (warpIdx == 0) { + // determine our lane's expert index and write to output + int32_t expertIdx = 0; +#pragma unroll + for (int ii = 0; ii < params.mTopK; ++ii) { // bound of params.mTopK + expertIdx = laneIdx == ii ? 
topExperts[ii] : expertIdx; + } + // determine whether our expert is local to this GPU + auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && + (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + + float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; + + // write expert idx out already + auto idxTopK = blockIdx.x * params.mTopK + laneIdx; + if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) { + PackedScoreIdx packedScore{static_cast(finalScore), + static_cast(expertIdx)}; + params.mPtrTopKPacked[idxTopK] = packedScore; + } + + if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && + params.mPtrTopKIds == nullptr) { + params.mPtrTopKWeights[idxTopK] = finalScore; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) + __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesClusterKernel(KernelParams params) { + using OutputT = typename KernelParams::OutputT; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const clusterBlockRank = blockIdx.x; + + //@todo: try to move it into routingPermutation + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } + routingPermutation(params, nullptr, warpIdx, clusterBlockRank); +} +#else +__global__ void routingIndicesClusterKernel(KernelParams params) { + assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesCoopKernel(KernelParams params) { + // number of experts is bounded by number of threads + int constexpr NumThreads = KernelParams::MaxNumExperts; + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads]; + __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads]; + // needed for the exclusive sum of token offsets + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage tempStorage; + // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory. + static constexpr int MaxExpandedIdxPerThread = 64; + + // Initialize grid. + cg::grid_group grid = cg::this_grid(); + // Note: the following is more efficient than grid.block_index() because we don't use y and z. + int32_t const gridBlockIdx = blockIdx.x; + int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x; + int32_t const numBlocks = gridDim.x; + int32_t const numThreadsPerGrid = numBlocks * NumThreads; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + + auto expandedIdxSize = params.mNumTokens * params.mTopK; + + // pre-fill the counts with 0 + smemExpertCount[threadIdx.x] = 0; + __syncthreads(); + + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } + + // each thread keeps has some number of "expanded indexes" assigned to it + // for each of these, we keep the associated expert and offset within expert in registers + int32_t expertIndexes[MaxExpandedIdxPerThread]; + int32_t expertOffsets[MaxExpandedIdxPerThread]; + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a + // time, and branch between a fast path without bound checks and a slow path with bound 
checks. + int constexpr IterStride = 4; + static_assert(MaxExpandedIdxPerThread % IterStride == 0); + + // Define a lambda to avoid code duplication in both branches. + auto loopBody = [&](int ii, int expandedIdx) { + int32_t expertIdx = params.mPtrTopKIds != nullptr ? params.mPtrTopKIds[expandedIdx] + : params.mPtrTopKPacked[expandedIdx].idx; + expertIndexes[ii] = expertIdx; + // check whether this expert is local to our GPU at all and ignore if not + auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && + (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIdx, 1) : 0; + }; + +#pragma unroll + for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) { + // Whether it's safe to do multiple iterations without bound checks. + bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize; + if (takeFastPath) { +#pragma unroll + for (int32_t jj = 0; jj < IterStride; jj++) { + int const ii = ii0 + jj; + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + loopBody(ii, expandedIdx); + } + } else { + bool doBreak = false; +#pragma unroll + for (int32_t jj = 0; jj < IterStride; jj++) { + int const ii = ii0 + jj; + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + if (expandedIdx >= expandedIdxSize) { + doBreak = true; + break; + } + loopBody(ii, expandedIdx); + } + if (doBreak) { + break; + } + } + } + + // Make histogram (token counts per expert) available to all threads in the block. + __syncthreads(); + + // + // Each thread now represents one expert + // + + // Add the local bin count to the common bin count and get a per-CTA offset. 
+ int32_t const localExpertCount = smemExpertCount[threadIdx.x]; + + int32_t blockExpertOffset = 0; + if (threadIdx.x < params.mNumExperts) { + blockExpertOffset = atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); + } + + // Sync to wait for completion of the histogram reduction. + grid.sync(); + + // Get total count for this expert. + int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0; + + // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency. + + // Compute the runtime config for projections + // Whether or not an expert is local is taken into account when smemExpertCount is computed + // so we do not need to take it into account here. + + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(count, params.mPaddingLog2); + } else { + numCta = divUpTileN(count, params.mTileTokensDim); + } + + int32_t ctaOffset; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); + + for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks) { + const int32_t localExpertIdx = + (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); + } + + // get the padded offset associated with this expert + int32_t offset; + if constexpr (KernelParams::isPow2) { + offset = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + offset = mulTileN(ctaOffset, params.mTileTokensDim); + } + int32_t permutedIdxSize; + 
if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } + + // write out padded count + if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) { + params.mPtrPermutedIdxSize[0] = permutedIdxSize; + params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; + } + + // write expert offsets to shared + smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; + + // make expert offsets available to all threads + __syncthreads(); + + // trigger the secondary kernel when using PDL + // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, + // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens + // TODO: this is not sufficient to ensure visibility in the next kernel! + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } + +// each thread has the same "expanded indexes" assigned to it as above +// at this point, we know the final offsets of experts and the offsets within +// experts, which allows writing the final index values +#pragma unroll + for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) { + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + if (expandedIdx >= expandedIdxSize) { + break; + } + auto expertIdx = expertIndexes[ii]; + // check whether this expert is local to our GPU at all + auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && + (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + auto tokenIdx = expandedIdx / params.mTopK; + auto permutedIdx = + isLocalExpert ? 
int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; + } + if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) { + params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; + } + } +} +#else +__global__ void routingIndicesCoopKernel(KernelParams params) { + assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures"); +} +#endif + +int constexpr getMaxNumExperts(int32_t numExperts) { + if (numExperts <= topk::MaxNumExpertsUnit) { + return topk::MaxNumExpertsUnit; + } else if (numExperts <= NumDeepseekExperts) { + return NumDeepseekExperts; + } else if (numExperts <= NumKimiK2Experts) { + return NumKimiK2Experts; + } else { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#define LAUNCH_ROUTING_DEEPSEEK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag) \ + if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ + LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, topk::MaxNumExpertsUnit); \ + } else if (data.mNumExperts <= NumDeepseekExperts) { \ + LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, NumDeepseekExperts); \ + } else if (data.mNumExperts <= NumKimiK2Experts) { \ + LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, NumKimiK2Experts); \ + } else { \ + TLLM_LOG_ERROR("Unsupported numExperts"); \ + } + +void runImpl(Data& data, void* stream) { + FLASHINFER_CHECK( + data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + if (data.mPtrTopKIds != nullptr) { + 
FLASHINFER_CHECK(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " + "DeepSeek routing."); + } + if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr) + FLASHINFER_CHECK( + (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize, + "If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required"); + FLASHINFER_CHECK(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet"); + FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, + "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, + data.mNumLimitedGroups); + FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, + "Routing kernel expects topK experts <= %d, got %d", MaxNumTopExperts, + data.mTopK); + FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", + data.mTopK); + FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize, + "Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", + data.mTopK, data.mNumLimitedGroups); + FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts, + "Routing kernel expects %d to be at most #experts %d", MaxNumTopExperts, + data.mNumExperts); + FLASHINFER_CHECK(data.mNumExperts <= NumKimiK2Experts, + "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, + NumKimiK2Experts); + FLASHINFER_CHECK(data.mNumExpertGroups >= data.mNumLimitedGroups, + "Routing kernel expects top groups %d to be limited by #expert groups %d", + data.mNumLimitedGroups, data.mNumExpertGroups); + if (data.mNumExpertGroups > 1) { + FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups, + "Routing kernel expects #experts groups %d to be <= #warps %d", + data.mNumExpertGroups, MaxNumGroups); + FLASHINFER_CHECK(data.mNumExperts % data.mNumExpertGroups == 0, + "Routing kernel expects #experts %d to be a multiple of #expert groups %d", + 
data.mNumExperts, data.mNumExpertGroups); + FLASHINFER_CHECK( + data.mNumExperts / data.mNumExpertGroups <= WarpSize, + "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d", + data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups); + } else { + FLASHINFER_CHECK(data.mTopK <= topk::MaxNumTopK, + "Routing kernel expects top K %d to be <= #warps %d", data.mTopK, + topk::MaxNumTopK); + } + FLASHINFER_CHECK(data.mNumExperts % 4 == 0, + "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + + int const numBlocks = data.mNumTokens; + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + + bool const useSingleCluster = data.mNumTokens <= 1024; + if (!useSingleCluster) { + // Reset the global histograms (not used in single-cluster code path). + // Cover both for the cooperative and two-kernel code paths. + FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, + "When #tokens is large, `mPtrExpertCounts` is a required input."); + } else { + data.mPtrExpertCounts = + nullptr; // Set it to nullptr for single-cluster code path, as it won't be used + } + + // Number of blocks we can use in the cooperative kernel + // The number of blocks must be: + // >= ⌈(numTokens * topK) / (MaxExpandedIdxPerThread * NumThreads)⌉ + // <= numSms, assuming an occupancy of 1 block/SM + // + // If too small for the given numTokens, fall back to the less performant two-step method. + // + // The upper bound is a strict requirement. The number of blocks should be determined by querying + // the device properties, or conservatively low. + // /!\ The following number is not portable!! (but works on H100 and B200) + int const numBlocksCoop = 128; + + // Maximum number of tokens supported by the kernel using a cooperative launch. NOTE(review): the literal 64 below presumably mirrors MaxExpandedIdxPerThread from the lower bound stated above - confirm against the kernel's definition. + int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + if (data.mPtrTopKIds == nullptr) { + int const numThreadsMain = + data.mNumExperts < NumDeepseekExperts ? 
NumDeepseekExperts : NumKimiK2Experts; + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1); + } else { + // Reset the global histograms. + LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1); + } + + if (data.mPtrPermutedIdxSize != nullptr) { + if (useSingleCluster) { + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesClusterKernel, + NumBlocksPerCluster, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1); + } else if (data.mNumTokens <= maxTokensCoop) { + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1); + } else { + const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; + const int32_t histogramEltsPerBlock = 8 * numThreadsHist; + const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + + // Limit grid size (both kernels use a grid-stride loop). 
+ const int32_t maxNumBlocks = 1024; + + int const numBlocksHistogram = std::min( + (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets = + std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesHistogramKernel, + numBlocksHistogram, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1); + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run(Data& data, void* stream) { runImpl(data, stream); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingDeepSeek +} // namespace moe::dev::routing + +#endif // USE_FLASHINFER diff --git a/src/kernels/src/trtllm/trtllm_fused_moe_routing_llama4.cu b/src/kernels/src/trtllm/trtllm_fused_moe_routing_llama4.cu new file mode 100644 index 0000000..01e5240 --- /dev/null +++ b/src/kernels/src/trtllm/trtllm_fused_moe_routing_llama4.cu @@ -0,0 +1,581 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#if defined(USE_FLASHINFER) && __has_include("trtllm/gen/CudaRunner.h") && __has_include("tensorrt_llm/common/logger.h") + +#include "flashinfer/exception.h" +#include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" + +namespace moe::dev::routing { +namespace routingLlama4 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr int NumThreads = 1024; +static constexpr int NumWarps = NumThreads / WarpSize; +static constexpr int MaxNumTopExperts = 1; +static constexpr int NumExpertsLimit = 128; +static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; +static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; +static constexpr int WarpKernelSmemStride = 33; +// with further optimization to `routingIndicesWarpKernel`, this limit may +// increase. For now, it is a good cut-off point for when the block-wise +// operations are more efficient end-to-end. +static constexpr int WarpKernelMaxNumTokens = 4; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile const& warp, + DataType (&warpMaxScore)[MaxNumTopExperts], + int32_t (&warpMaxExpertIdx)[MaxNumTopExperts], + int32_t const laneIdx, int32_t const numExperts, + DataType const* ptrScores) { + DataType minScore = DataType{-INFINITY}; + DataType maxScore = minScore; + int32_t maxExpertIdx{0}; + using DataTypeVec = std::conditional_t; + + // Non-vectorized loading: directly access ptrScores with expertIdx + for (int i = 0; i < VecSize; ++i) { + auto expertIdx = i * WarpSize + laneIdx; + auto newScore = expertIdx < numExperts ? ptrScores[expertIdx] : minScore; + // note: use `>=` s.t. 
highest index always wins, just like in `reduceTopK`. NOTE(review): the comparison below is a strict `>`, so on equal scores the earlier (lower-index) candidate is kept - confirm which tie-break is intended. + if (newScore > maxScore) { + maxScore = newScore; + maxExpertIdx = expertIdx; + } + } + + topk::reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParams params) { + // types used in this kernel + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using TypePacked = PackedScoreIdx; + // use the default cub warp-scan, with shfl + using Scan = cub::WarpScan; + __shared__ typename Scan::TempStorage tempStorage; + + // each thread encodes 4 experts in one `int32_t`. The assumption is that + // we don't have more than 127 tokens, but `WarpKernelMaxNumTokens` must be + // smaller than that because other approaches will be more efficient for + // 127 tokens. + static constexpr int ExpertsPerThread = sizeof(int32_t); + static_assert(WarpKernelMaxNumTokens <= 127); + // this is a full table of which token is routed to which expert. + // the assumption here is that there are no more than 128 experts. + // we use a stride of 33 instead of 32 to avoid shared memory bank conflicts. 
+ __shared__ int32_t __attribute(( + aligned(128))) smemExpertTokenCountFull[WarpKernelMaxNumTokens][WarpKernelSmemStride]; + static_assert(WarpKernelSmemStride == WarpSize + 1); + static_assert(KernelParams::MaxNumExperts / sizeof(int32_t) <= WarpSize); + + // values needed for the top-1 reduction, if required + InputT minScore = InputT{-INFINITY}; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + +#pragma unroll + for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx) { + // reset full shared memory field to 0 + smemExpertTokenCountFull[tokenIdx][threadIdx.x] = 0; + } + __syncwarp(); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } +#endif + + if (params.mPtrScores != nullptr && params.mPtrTopKIds == nullptr) { + // if we use `mPtrScores` as input, we need to perform the top-1 reduction + // for each token, we load the scores then use `reduceTopK` for this. 
+ // each thread works on 4 experts, so a local reduction is done before + for (int tokenIdx = 0; tokenIdx < params.mNumTokens; ++tokenIdx) { + auto scoreOffset = tokenIdx * params.mNumExperts; + int32_t warpMaxExpertIdx[MaxNumTopExperts]; + InputT warpMaxScore[MaxNumTopExperts]; + + // Use routingTopKExperts function instead of inline logic + routingTopKExperts(warp, warpMaxScore, warpMaxExpertIdx, + threadIdx.x, params.mNumExperts, + params.mPtrScores + scoreOffset); + + if (cute::elect_one_sync()) { + // one thread updates the count linking token to chosen expert + auto expertTokenCount = 0; + setBits(expertTokenCount, 1, warpMaxExpertIdx[0] % ExpertsPerThread); + smemExpertTokenCountFull[tokenIdx][warpMaxExpertIdx[0] / ExpertsPerThread] = + expertTokenCount; + // we also compute the final score here and write it out if required + auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[tokenIdx] = finalScore; + } + } + } + } else { + // if we do not have `mPtrScores` as input, we expect that `params.mPtrTopKPacked` or + // `params.mPtrTopKIds` and `params.mPtrTopKWeights` contains the top-1 packed score and index + // already. 
Each thread represents a token here, and we extract the relevant score. The + // assumption is that the #tokens is limited by warp-size + static_assert(WarpKernelMaxNumTokens <= WarpSize); + TypePacked scoreIdx = TypePacked{}; + if (params.mPtrTopKIds != nullptr) { + if (threadIdx.x < params.mNumTokens) { + scoreIdx = TypePacked{static_cast(params.mPtrTopKWeights[threadIdx.x]), + static_cast(params.mPtrTopKIds[threadIdx.x])}; + } + } else { + if (threadIdx.x < params.mNumTokens) { + scoreIdx = TypePacked{static_cast(params.mPtrTopKPacked[threadIdx.x].score), + static_cast(params.mPtrTopKPacked[threadIdx.x].idx)}; + if (params.mPtrTopKWeights != nullptr) { + // we also compute the final score here and write it out if required + auto finalScore = OutputT{sigmoid_accurate(float{scoreIdx.score})}; + params.mPtrTopKWeights[threadIdx.x] = finalScore; + } + } + } + + int32_t expertTokenCount = 0; + setBits(expertTokenCount, 1, scoreIdx.idx % ExpertsPerThread); + if (threadIdx.x < params.mNumTokens) { + smemExpertTokenCountFull[threadIdx.x][scoreIdx.idx / ExpertsPerThread] = expertTokenCount; + } + } + + // make the full table available to all threads + __syncwarp(); + + // at this point, each thread keeps a count of its 4 assigned experts in + // `expertCount`, as well as the offsets for all tokens w.r.t. these 4 experts + // in `expertOffset`. + int32_t expertCount = 0; + int32_t expertOffset[WarpKernelMaxNumTokens + 1]; +#pragma unroll + for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens + 1; ++tokenIdx) { + if (tokenIdx > params.mNumTokens) break; + // simple reduction for `expertCount`, and scan for `expertOffset` + auto expertTokenCount = + tokenIdx < params.mNumTokens ? 
smemExpertTokenCountFull[tokenIdx][threadIdx.x] : 0; + expertOffset[tokenIdx] = expertCount; + expertCount += expertTokenCount; + } + + // at this point, we are ready for the scan across all experts to get the + // thread-wise offsets across experts + // first, we need to reduce across our 4 experts into `numCta` + int32_t numCta = 0; +#pragma unroll + for (int ii = 0; ii < ExpertsPerThread; ++ii) { + auto count = getBits(expertCount, ii); + int32_t num; + if constexpr (KernelParams::isPow2) { + num = divUpLog2(count, params.mPaddingLog2); + } else { + num = divUpTileN(count, params.mTileTokensDim); + } + numCta += num; + } + // second, we perform the exclusive sum across the warp + int32_t ctaOffset; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); + + // finally, we perform a scan across our local experts, starting with the + // warp-wide scan result (`ctaOffset`) + auto ctaOffsetExp = ctaOffset; +#pragma unroll + for (int ii = 0; ii < ExpertsPerThread; ++ii) { + auto count = getBits(expertCount, ii); + int32_t finalNumCta; + if constexpr (KernelParams::isPow2) { + finalNumCta = divUpLog2(count, params.mPaddingLog2); + } else { + finalNumCta = divUpTileN(count, params.mTileTokensDim); + } + auto expertIdx = threadIdx.x * ExpertsPerThread + ii; + // during the scan for expert offsets, we can already write out + // both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit` + for (int cta = 0; cta < finalNumCta; ++cta) { + params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx; + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffsetExp + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffsetExp, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffsetExp + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffsetExp, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta] = min(mnLimit1, 
mnLimit2); + } + ctaOffsetExp += finalNumCta; + } + + // at this point, we can write out padded count from the warp-aggregate + if (cute::elect_one_sync()) { + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } + params.mPtrPermutedIdxSize[0] = permutedIdxSize; + params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // we can trigger the next kernel at this point + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + + // at this point, all values for offsets are ready, except the final offsets + // within the padded index (`permutedIdx`) + // for this, we perform a scan similar to the one directly after the warp-scan: + // here, we keep the local offset for each of the thread's experts in a field + // of registers + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + int32_t finalExpertOffset[ExpertsPerThread]; + if constexpr (KernelParams::isPow2) { + finalExpertOffset[0] = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + finalExpertOffset[0] = mulTileN(ctaOffset, params.mTileTokensDim); + } +#pragma unroll + for (int ii = 1; ii < ExpertsPerThread; ++ii) { + int32_t tmp; + if constexpr (KernelParams::isPow2) { + tmp = divUpMulLog2(getBits(expertCount, ii - 1), params.mPaddingLog2); + } else { + tmp = divUpMulTileN(getBits(expertCount, ii - 1), params.mTileTokensDim); + } + finalExpertOffset[ii] = finalExpertOffset[ii - 1] + tmp; + } + +#pragma unroll + for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx) { + // at this point, we can calculate the final index: + // we simply loop over all tokens, and all experts assigned to this thread. 
+ // For each pair, we determine whether that token was routed to that expert + // based on whether the offset for that token changed. + // we can then easily compute the final `expertIdx` and `permutedIdx` relative + // to this token and expert, and write them out. + if (tokenIdx >= params.mNumTokens) break; + +#pragma unroll + for (int ii = 0; ii < ExpertsPerThread; ++ii) { + // determine whether the offset for this expert and token changes + auto localOffsetToken = getBits(expertOffset[tokenIdx], ii); + auto isTokenRouted = getBits(expertOffset[tokenIdx + 1], ii) > localOffsetToken; + // the expert index of this expert + auto expertIdx = threadIdx.x * ExpertsPerThread + ii; + auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && + (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + // the permuted index: we add the local offset relative to this expert and token + // to the global offset from the scan for this expert + auto permutedIdx = isLocalExpert ? 
finalExpertOffset[ii] + localOffsetToken : int32_t{-1}; + // write out `mPtrExpandedIdxToPermutedIdx` if required + if (params.mPtrExpandedIdxToPermutedIdx != nullptr && isTokenRouted) { + params.mPtrExpandedIdxToPermutedIdx[tokenIdx] = permutedIdx; + } + // write out `mPtrPermutedIdxToTokenIdx` if required + if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert && isTokenRouted) { + params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; + } + } + } +} +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) + routingIndicesClusterKernel(KernelParams params) { + // number of tokens/expanded idx is bounded by total number of warps + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using TypePacked = PackedScoreIdx; + __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps]; + + uint32_t const clusterBlockRank = blockIdx.x; + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const laneIdx = cutlass::arch::LaneId(); + + // TODO(mjoux): expand to more tokens (possibly) + auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; + auto scoreOffset = warpTokenIdx * params.mNumExperts; + bool validToken = warpTokenIdx < params.mNumTokens; + InputT minScore = InputT{-INFINITY}; + + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } + + if (params.mPtrTopKIds != nullptr) { + if (validToken) { + TypePacked packedScore{static_cast(params.mPtrTopKWeights[warpTokenIdx]), + static_cast(params.mPtrTopKIds[warpTokenIdx])}; + smemPackedScoreIdx[warpIdx] = packedScore; + } + } else if (params.mPtrScores != nullptr) { + // in 
this case, each warp represents a token + // we then exchange all token max scores, s.t. afterwards, each thread + // represents a token + InputT warpMaxScore[MaxNumTopExperts]; + int32_t warpMaxExpertIdx[MaxNumTopExperts]; + + if (validToken) { + routingTopKExperts( + warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts, + params.mPtrScores + scoreOffset); + if (cute::elect_one_sync()) { + auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; + TypePacked packedScore{finalScore, static_cast(warpMaxExpertIdx[0])}; + smemPackedScoreIdx[warpIdx] = packedScore; + } + } + } else { + if (validToken) { + smemPackedScoreIdx[warpIdx] = params.mPtrTopKPacked[warpTokenIdx]; + } + } + + // make packed scores available to all threads in cluster + __cluster_barrier_arrive(); + __cluster_barrier_wait(); + + if (params.mPtrTopKIds != nullptr || params.mPtrScores != nullptr) { + routingPermutation(params, smemPackedScoreIdx, warpIdx, + clusterBlockRank); + } else { + routingPermutation(params, smemPackedScoreIdx, warpIdx, + clusterBlockRank); + } +} +#else +__global__ void routingIndicesClusterKernel(KernelParams params) { + assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// this kernel is needed in case we have scores as input for the histogram kernel +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesHistogramScoresKernel(KernelParams params) { + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using TypePacked = PackedScoreIdx; + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + // we assume that #experts is a multiple of 4, so VecSize must be 4. 
+ static_assert(VecSize == 4); + + int32_t const laneIdx = cutlass::arch::LaneId(); + int32_t const warpIdx = threadIdx.x / WarpSize; + int32_t const globalWarpIdx = blockIdx.x * KernelParams::MaxNumExperts / WarpSize + warpIdx; + int32_t const globalWarpStride = gridDim.x * KernelParams::MaxNumExperts / WarpSize; + InputT minScore = InputT{-INFINITY}; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + // initialize the mPtrExpertCounts + int32_t expertCountsNum = 2 * params.mNumExperts; + int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x; + int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts; + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Wait on primary grid and trigger secondary kernel. + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + + // in this case, each warp represents a token, and we use a grid-stride loop + // over all warps/tokens + for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) { + auto scoreOffset = tokenIdx * params.mNumExperts; + int32_t warpMaxExpertIdx[MaxNumTopExperts]; + InputT warpMaxScore[MaxNumTopExperts]; + + if (params.mPtrTopKIds != nullptr) { + if (laneIdx < MaxNumTopExperts) { + warpMaxExpertIdx[laneIdx] = params.mPtrTopKIds[tokenIdx]; + warpMaxScore[laneIdx] = static_cast(params.mPtrTopKWeights[tokenIdx]); + } + } else if (params.mPtrScores != nullptr) { + routingTopKExperts( + warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts, + params.mPtrScores + scoreOffset); + } else { + if (laneIdx < MaxNumTopExperts) { + warpMaxExpertIdx[laneIdx] = params.mPtrTopKPacked[tokenIdx].idx; + warpMaxScore[laneIdx] = params.mPtrTopKPacked[tokenIdx].score; + } + } + + if (cute::elect_one_sync()) { + auto finalScore = 
OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; + TypePacked packedScore{finalScore, static_cast(warpMaxExpertIdx[0])}; + params.mPtrTopKPacked[tokenIdx] = packedScore; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +int constexpr getMaxNumExperts(int32_t numExperts) { + if (numExperts <= topk::MaxNumExpertsUnit) { + return topk::MaxNumExpertsUnit; + } else { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void runImpl(Data const& data, void* stream) { + FLASHINFER_CHECK( + data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + if (data.mPtrTopKIds != nullptr) { + FLASHINFER_CHECK( + data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for Llama4 routing."); + } + FLASHINFER_CHECK( + data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && + data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, + "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"); + FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, + "Routing kernel expects topK experts <= %d, got %d", MaxNumTopExperts, + data.mTopK); + FLASHINFER_CHECK(data.mNumExperts <= NumExpertsLimit, + "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, + NumExpertsLimit); + FLASHINFER_CHECK(data.mNumExperts % 4 == 0, + "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + + bool const useSingleWarp = + (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens) || + data.mNumTokens < WarpKernelMaxNumTokens; + bool const useSingleCluster = + data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) + ? 
MaxNumTokensSingleClusterScores + : MaxNumTokensSingleCluster); + if (!useSingleCluster) { + FLASHINFER_CHECK( + (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr), + "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."); + FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, + "When #tokens is large, `mPtrExpertCounts` is a required input."); + } + + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + if (useSingleWarp) { + LAUNCH_ROUTING_LLAMA4(data, + /*coopLaunch=*/false, routingIndicesWarpKernel, 1, WarpSize, + /*smemSize=*/0, // No dynamic smem + stream); + } else if (useSingleCluster) { + LAUNCH_ROUTING_LLAMA4(data, + /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, + NumThreads, + /*smemSize=*/0, // No dynamic smem + stream); + } else { + const uint32_t expandedIdxSize = data.mNumTokens * data.mTopK; + + const uint32_t histogramEltsPerBlock = 8 * numThreadsHist; + const uint32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + + // Limit grid size (all kernels use a grid-stride loop). + const uint32_t maxNumBlocks = 1024; + + int const numBlocksHistogram = std::min( + (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets = + std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) { + LAUNCH_ROUTING_LLAMA4(data, + /*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); + } else { + // Reset the global histograms. 
+ LAUNCH_ROUTING_LLAMA4(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); + } + LAUNCH_ROUTING_LLAMA4(data, + /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); + LAUNCH_ROUTING_LLAMA4(data, + /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); + } +} + +void run(Data const& data, void* stream) { + FLASHINFER_CHECK( + data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + FLASHINFER_CHECK( + data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && + data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, + "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"); + FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, + "Routing kernel expects topK experts <= ", MaxNumTopExperts, ", got ", + data.mTopK); + FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got ", + data.mPaddingLog2); + + runImpl(data, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingLlama4 +} // namespace moe::dev::routing + +#endif // USE_FLASHINFER diff --git a/src/kernels/src/trtllm/trtllm_fused_moe_routing_renormalize.cu b/src/kernels/src/trtllm/trtllm_fused_moe_routing_renormalize.cu new file mode 100644 index 0000000..fcee73b --- /dev/null +++ b/src/kernels/src/trtllm/trtllm_fused_moe_routing_renormalize.cu @@ -0,0 +1,509 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(USE_FLASHINFER) && __has_include("trtllm/gen/CudaRunner.h") && __has_include("tensorrt_llm/common/logger.h") + +#include "flashinfer/exception.h" +#include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" + +namespace moe::dev::routing { +namespace routingRenormalize { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr int NumThreads = 1024; +static constexpr int NumWarps = NumThreads / WarpSize; +static constexpr int MaxNumTopExperts = 10; +static constexpr int NumExpertsLimit = 512; +static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; +static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; +static constexpr int BlockKernelMaxNumTokens = 4; + +template +__forceinline__ __device__ void routingTopKExperts( + cg::thread_block_tile const& warp, DataType (&score)[VecSize], + int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxNumTopExperts], + int32_t (&warpTopKExpertIdx)[MaxNumTopExperts], int32_t const laneIdx, int32_t const numExperts, + int32_t topK, InputType const* ptrScores, bool const normTopkProb, + bool const applySoftmaxAfterTopK) { + DataType minScore = DataType{-INFINITY}; + + for (int i = 0; i < VecSize; i++) { + auto expertIdx = i * WarpSize + laneIdx; + auto newScore = expertIdx < numExperts ? 
static_cast(ptrScores[expertIdx]) : minScore; + score[i] = newScore; + idx[i] = expertIdx; + } + if constexpr (DoSoftmaxBeforeTopK) { + calcSoftmax(warp, score); + } + + // Get the top-k scores and their corresponding expert indices + topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK); + + // Normalize the scores + if constexpr (DoSoftmaxBeforeTopK) { + float sum = float{1.f}; + if (normTopkProb) { + sum = static_cast(laneIdx < topK ? warpTopKScore[laneIdx] : 0); + sum = cg::reduce(warp, sum, cg::plus()); + } + if (laneIdx < topK) { + warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum; + } + } else { + if (applySoftmaxAfterTopK) { + auto softmaxScore = + calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK); + if (laneIdx < topK) { + warpTopKScore[laneIdx] = softmaxScore; + } + } + // If applySoftmaxAfterTopK is false, we keep the raw TopK values without softmax + } +} + +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesBlockKernel(KernelParams params) { + // types used in this kernel + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = std::conditional_t; + using TypePacked = PackedScoreIdx; + int constexpr MaxNumExperts = KernelParams::MaxNumExperts; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const laneIdx = cutlass::arch::LaneId(); + int32_t const expert = threadIdx.x; + auto scoreOffset = warpIdx * params.mNumExperts; + bool validToken = warpIdx < params.mNumTokens; + + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + static constexpr int totalExpertCounts = BlockKernelMaxNumTokens * MaxNumExperts; + __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts]; + __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts]; + + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage 
tempStorage; + + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + for (int i = threadIdx.x; i < totalExpertCounts; i += blockDim.x) { + smemOffset[i] = int8_t{-1}; + smemKIdx[i] = int8_t{-1}; + } + __syncthreads(); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } +#endif + + if (params.mPtrTopKIds != nullptr) { + if (validToken) { + if (laneIdx < params.mTopK) { + int offset = warpIdx * MaxNumExperts + params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx]; + smemKIdx[offset] = static_cast(laneIdx); + } + } + } else if (params.mPtrScores != nullptr) { + // in this case, each warp represents a token + BaseType score[VecSize]; + int32_t idx[VecSize]; + + BaseType warpTopKScore[MaxNumTopExperts]; + int32_t warpTopKExpertIdx[MaxNumTopExperts]; + + BaseType minScore = BaseType{-INFINITY}; + if (validToken) { + routingTopKExperts( + warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, + params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, + params.mApplySoftmaxAfterTopK); + + if (laneIdx < params.mTopK) { + int offset = warpIdx * MaxNumExperts + warpTopKExpertIdx[laneIdx]; + smemKIdx[offset] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = + OutputT{warpTopKScore[laneIdx]}; + } + } + } // end if (validToken) + } else if (params.mPtrTopKPacked != nullptr) { + if (validToken) { + if (laneIdx < params.mTopK) { + int offset = warpIdx * MaxNumExperts + + static_cast(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].idx); + smemKIdx[offset] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = + static_cast(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].score); + } + } + } + } + __syncthreads(); + + // set local experts 
+ auto localExpertIdx = expert - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < params.mNumLocalExperts && + (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + // Get the count of each expert and the offset for each token + int accExpertCount = 0; + + if (isLocalExpert) { + int offset = expert; + for (int j = 0; j < BlockKernelMaxNumTokens; j++) { + if (smemKIdx[offset] >= 0) { + smemOffset[offset] = static_cast(accExpertCount); + accExpertCount++; + } + offset += MaxNumExperts; + } + } + __syncthreads(); + // Get the number of CTAs and the offset for each CTA + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(accExpertCount, params.mPaddingLog2); + } else { + numCta = divUpTileN(accExpertCount, params.mTileTokensDim); + } + int32_t ctaOffset = 0; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); + + int32_t expertScanCounts = 0; + int32_t tmpCount; + if constexpr (KernelParams::isPow2) { + tmpCount = divUpMulLog2(accExpertCount, params.mPaddingLog2); + } else { + tmpCount = divUpMulTileN(accExpertCount, params.mTileTokensDim); + } + Scan(tempStorage).ExclusiveSum(tmpCount, expertScanCounts); + __syncthreads(); + + if (isLocalExpert) { + for (int cta = 0; cta < numCta; ++cta) { + const int32_t localExpertIdx = + (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + accExpertCount; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + accExpertCount; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); + } + } + + // at this point, we can write out 
padded count + if (threadIdx.x == 0) { + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } + params.mPtrPermutedIdxSize[0] = permutedIdxSize; + params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // we can trigger the next kernel at this point + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + + for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) { + int offset = tokenIdx * MaxNumExperts + threadIdx.x; + if (smemKIdx[offset] >= 0) { + int const expandedIdx = tokenIdx * params.mTopK + smemKIdx[offset]; + int const offsetWithinExpert = static_cast(smemOffset[offset]); + int const offsetForExpert = expertScanCounts; + int const permutedIdx = isLocalExpert ? offsetForExpert + offsetWithinExpert : int32_t{-1}; + + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; + if (isLocalExpert) { + params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; + } + } + } +} + +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) + routingIndicesClusterKernel(KernelParams params) { + // number of tokens/expanded idx is bounded by total number of warps + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + + using BaseType = std::conditional_t; + using TypePacked = PackedScoreIdx; + + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + + __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxNumTopExperts]; + + uint32_t const clusterBlockRank = blockIdx.x; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const laneIdx = cutlass::arch::LaneId(); + + 
auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; + auto scoreOffset = warpTokenIdx * params.mNumExperts; + bool validToken = warpTokenIdx < params.mNumTokens; + + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } + + if (params.mPtrScores != nullptr) { + // in this case, each warp represents a token + BaseType score[VecSize]; + int32_t idx[VecSize]; + + BaseType warpTopKScore[MaxNumTopExperts]; + int32_t warpTopKExpertIdx[MaxNumTopExperts]; + + BaseType minScore = BaseType{-INFINITY}; + if (validToken) { + routingTopKExperts( + warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, + params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, + params.mApplySoftmaxAfterTopK); + + if (laneIdx < params.mTopK) { + smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] = + TypePacked{warpTopKScore[laneIdx], static_cast(warpTopKExpertIdx[laneIdx])}; + } + } // end if (validToken) + } + + // make packed scores available to all threads in cluster + __cluster_barrier_arrive(); + __cluster_barrier_wait(); + + if (params.mPtrScores != nullptr) { + routingPermutation(params, smemPackedScoreIdx, warpIdx, + clusterBlockRank); + } else { + routingPermutation(params, smemPackedScoreIdx, warpIdx, + clusterBlockRank); + } +} +#else +__global__ void __launch_bounds__(NumThreads) + routingIndicesClusterKernel(KernelParams /* params */) { + assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); +} +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// this kernel is needed in case we have scores as input for the histogram kernel +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesHistogramScoresKernel(KernelParams 
params) { + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = std::conditional_t; + + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + + int32_t const laneIdx = cutlass::arch::LaneId(); + int32_t const warpIdx = threadIdx.x / WarpSize; + int32_t const globalWarpIdx = blockIdx.x * KernelParams::MaxNumExperts / WarpSize + warpIdx; + int32_t const globalWarpStride = gridDim.x * KernelParams::MaxNumExperts / WarpSize; + BaseType minScore = BaseType{-INFINITY}; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Wait on primary grid. + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + // initialize the mPtrExpertCounts + int32_t expertCountsNum = 2 * params.mNumExperts; + int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x; + int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts; + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Trigger secondary kernel. 
+ if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + // in this case, each warp represents a token, and we use a grid-stride loop + // over all warps/tokens + BaseType allScores[VecSize]; + int32_t allExpertIdx[VecSize]; + BaseType warpTopKScore[MaxNumTopExperts]; + int32_t warpTopKExpertIdx[MaxNumTopExperts]; + for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) { + auto scoreOffset = tokenIdx * params.mNumExperts; + + routingTopKExperts( + warp, allScores, allExpertIdx, warpTopKScore, warpTopKExpertIdx, laneIdx, + params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, + params.mApplySoftmaxAfterTopK); + + if (laneIdx < params.mTopK) { + PackedScoreIdx packedScore{static_cast(warpTopKScore[laneIdx]), + static_cast(warpTopKExpertIdx[laneIdx])}; + params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int32_t constexpr getMaxNumExperts(int32_t numExperts) { + if (numExperts <= topk::MaxNumExpertsUnit) { + return topk::MaxNumExpertsUnit; + } else if (numExperts <= NumExpertsLimit) { + return NumExpertsLimit; + } else { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define LAUNCH_ROUTING_RENORNALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1) \ + if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, topk::MaxNumExpertsUnit); \ + } else if (data.mNumExperts <= NumExpertsLimit) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, 
extraFlag1, NumExpertsLimit); \ + } else { \ + TLLM_LOG_ERROR("Unsupported numExperts"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +void run(Data const& data, void* stream) { + FLASHINFER_CHECK(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || + data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + if (data.mPtrTopKIds != nullptr) { + FLASHINFER_CHECK(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " + "Renormalize routing."); + } + FLASHINFER_CHECK( + data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && + data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, + "Routing kernel expects permuted idx and grouped Gemm launch config buffers"); + FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= ", + MaxNumTopExperts, ", got ", data.mTopK); + FLASHINFER_CHECK(data.mNumExperts <= NumExpertsLimit, + "Routing kernel expects #experts <= ", NumExpertsLimit, ", got ", + data.mNumExperts); + FLASHINFER_CHECK(data.mNumExperts % 4 == 0, + "Routing kernel expects #experts to be a multiple of 4, got ", + data.mNumExperts); + + // FIXME: routingIndicesBlockKernel breaks the vllm + gpt-oss DeepEP + bool const useSingleBlock = + data.mNumTokens <= BlockKernelMaxNumTokens && data.mPtrTopKPacked == nullptr; + + bool const useSingleCluster = + data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) + ? 
MaxNumTokensSingleClusterScores + : MaxNumTokensSingleCluster); + + if (!useSingleCluster && !useSingleBlock) { + FLASHINFER_CHECK(data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr, + "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is required."); + FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, + "When #tokens is large, `mPtrExpertCounts` is required."); + } + uint32_t const numThreadsHist = getMaxNumExperts(data.mNumExperts); + if (useSingleBlock) { + //@TODO: For now we use the single block kernel for cases with token number no larger than 4. + // We will future tune this threshold based on the performance. + LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesBlockKernel, 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); + } else if (useSingleCluster) { + LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, + NumThreads, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); + } else { + uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK; + uint32_t const histogramEltsPerBlock = 8 * numThreadsHist; + uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + + // Limit grid size (all kernels use a grid-stride loop). + uint32_t const maxNumBlocks = 1024; + + int const numBlocksHistogram = std::min( + (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets = + std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) { + LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); + } else { + // Reset the global histograms. 
+ LAUNCH_ROUTING_RENORNALIZE(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); + } + LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesHistogramKernel, numBlocksHistogram, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); + LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingRenormalize +} // namespace moe::dev::routing + +#endif // USE_FLASHINFER diff --git a/src/kernels/src/trtllm/trtllm_fused_moe_runner.cu b/src/kernels/src/trtllm/trtllm_fused_moe_runner.cu new file mode 100644 index 0000000..203dd90 --- /dev/null +++ b/src/kernels/src/trtllm/trtllm_fused_moe_runner.cu @@ -0,0 +1,566 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#if defined(USE_FLASHINFER) && __has_include("trtllm/gen/CudaRunner.h") && __has_include("tensorrt_llm/common/logger.h") + +#include + +#include "flashinfer/exception.h" +#include "flashinfer/trtllm/batched_gemm/KernelRunner.h" +#include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h" +#include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h" +#include "flashinfer/trtllm/fused_moe/DevKernel.h" +#include "flashinfer/trtllm/fused_moe/RoutingKernel.h" +#include "flashinfer/trtllm/fused_moe/runner.h" + +namespace tensorrt_llm { +namespace kernels { +namespace trtllmgen_moe { + +namespace btg = batchedGemm::trtllm::gen; + +namespace Routing { +namespace { +inline int32_t computeLog2(int32_t val, std::string const& name = "") { + int32_t n = val; + int32_t out = 0; + while (n >>= 1) { + ++out; + } + if ((1 << out) != val) { + out = -1; + } + return out; +} +} // namespace + +Runner::Runner() {} + +Runner::Runner(int32_t tileTokensDim) : mTileTokensDim(tileTokensDim) {} + +void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int32_t numExperts, + int32_t topK, int32_t nGroup, int32_t topkGroup, int32_t localExpertOffset, + int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes, + int32_t* expertCountHistogram, int32_t* permutedIdxSize, + int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx, + int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert, + int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, + int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias, + bool useRoutingScalesOnInput, bool useDeepSeekFp8, + RoutingMethodType routingMethodType, cudaStream_t stream) { + if (routingMethodType == RoutingMethodType::DeepSeekV3) { + FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8"); + FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup 
<= 4"); + moe::dev::routing::routingDeepSeek::Data routingData; + routingData.mDtypeExpW = + btg::Dtype::Bfloat16; // for DeepSeek, the expW is currently always bfloat16 + routingData.mDtypeBias = dtypeBias; // for DeepSeek, the bias can be bfloat16 or fp32 + + routingData.mDtypeScore = btg::Dtype::Fp32; // for DeepSeek, the score is currently always fp32 + routingData.mUsePdl = true; + + // output: + routingData.mPtrTopKPacked = routingExpertIndexes; + routingData.mPtrExpertCounts = expertCountHistogram; + routingData.mPtrPermutedIdxSize = permutedIdxSize; + routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; + routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; + routingData.mPtrTopKWeights = expertWeights; + + routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; + routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; + routingData.mPtrNumNonExitingCtas = numNonExitingCtas; + + // input: + routingData.mPtrRoutingBias = routingBias; + routingData.mPtrScores = reinterpret_cast(routingLogits); + routingData.mNumTokens = numTokens; + routingData.mNumExperts = numExperts; + routingData.mNumExpertGroups = nGroup; + routingData.mNumLimitedGroups = topkGroup; + routingData.mTopK = topK; + routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; + routingData.mLocalExpertsStartIdx = localExpertOffset; + routingData.mLocalExpertsStrideLog2 = 0; + routingData.mNumLocalExperts = localNumExperts; + routingData.mRouteScale = routedScalingFactor; + routingData.mUseRoutingSoftmax = false; + moe::dev::routing::routingDeepSeek::run(routingData, stream); + } else if (routingMethodType == RoutingMethodType::Llama4) { + FLASHINFER_CHECK(topK == 1, "For Llama routing method, must have topK == 1"); + if (nGroup > 0 || topkGroup > 0) { + FLASHINFER_WARN("For Llama routing method, nGroup/topkGroup is ignored, got ", nGroup, "/", + topkGroup); + } + moe::dev::routing::routingLlama4::Data routingData; + 
routingData.mDtypeExpW = btg::Dtype::Bfloat16; + routingData.mUsePdl = true; + + // output: + routingData.mPtrTopKPacked = routingExpertIndexes; + routingData.mPtrExpertCounts = expertCountHistogram; + routingData.mPtrPermutedIdxSize = permutedIdxSize; + routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; + routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; + routingData.mPtrTopKWeights = expertWeights; + + routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; + routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; + routingData.mPtrNumNonExitingCtas = numNonExitingCtas; + + // input: + routingData.mPtrScores = routingLogits; + routingData.mNumTokens = numTokens; + routingData.mNumExperts = numExperts; + routingData.mTopK = topK; + routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; + routingData.mLocalExpertsStartIdx = localExpertOffset; + routingData.mLocalExpertsStrideLog2 = 0; + routingData.mNumLocalExperts = localNumExperts; + moe::dev::routing::routingLlama4::run(routingData, stream); + } else if (routingMethodType == RoutingMethodType::Renormalize /* default */ + || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */ + || routingMethodType == RoutingMethodType::TopK /* TopK only (no softmax) */) { + moe::dev::routing::routingRenormalize::Data routingData; + + // + // Config + // + + routingData.mDtypeExpW = btg::Dtype::Bfloat16; + // routingData.mDtypeElt = dtypeElt; // no-op for now as hidden_state is not input + routingData.mUsePdl = true; + routingData.mDoSoftmaxBeforeTopK = routingMethodType == RoutingMethodType::RenormalizeNaive; + routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive; + routingData.mApplySoftmaxAfterTopK = routingMethodType == RoutingMethodType::Renormalize; + + routingData.mPtrScores = routingLogits; + + // + // Outputs + // + routingData.mPtrTopKPacked = routingExpertIndexes; + 
routingData.mPtrExpertCounts = expertCountHistogram; + routingData.mPtrPermutedIdxSize = permutedIdxSize; + routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; + routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; + routingData.mPtrTopKWeights = expertWeights; + + // + // Grouped Gemm Launch Config Buffers + // + routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; + routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; + routingData.mPtrNumNonExitingCtas = numNonExitingCtas; + + // + // Inputs + // + routingData.mNumTokens = numTokens; + routingData.mNumExperts = numExperts; + routingData.mTopK = topK; + routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; + routingData.mLocalExpertsStartIdx = localExpertOffset; + routingData.mLocalExpertsStrideLog2 = 0; + routingData.mNumLocalExperts = localNumExperts; + + moe::dev::routing::routingRenormalize::run(routingData, stream); + } else { + FLASHINFER_CHECK(false, "Unimplemented routing method ", + serializeMoeRoutingMethodType(routingMethodType), " of enum ", + (int)routingMethodType); + } +} +} // namespace Routing + +namespace PermuteGemm1 { + +tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( + btg::Dtype dtypeAct, btg::Dtype dtypeWeights, int32_t tileTokensDim, bool useDeepSeekFp8, + MoE::GatedActType gatedActType, bool useShuffledMatrixA, + batchedGemm::gemm::MatrixLayout weightLayout) { + if (gatedActType == MoE::GatedActType::SwiGlu || gatedActType == MoE::GatedActType::GeGlu) { + ActType actType = + (gatedActType == MoE::GatedActType::SwiGlu) ? 
ActType::SwiGlu : ActType::GeGlu; + tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { + // Swap A and B dtypes because transposeMmaOutput is hardcoded to true + .dtypeA = dtypeWeights, + .dtypeB = dtypeAct, + .dtypeC = dtypeAct, + .actType = actType, + .deepSeekFp8 = useDeepSeekFp8, + .fusedAct = !useDeepSeekFp8, + .routeAct = true, + .staticBatch = false, + .transposeMmaOutput = true, + .tileSize = tileTokensDim, + .epilogueTileM = useDeepSeekFp8 ? 64 : 128, + .useShuffledMatrixA = useShuffledMatrixA, + .weightLayout = weightLayout}; + return options; + } else { + FLASHINFER_CHECK(false, "Unimplemented gated act type ", + MoE::serializeGatedActType(gatedActType), " of enum ", (int)gatedActType); + } +} + +Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, int tileTokensDim, + MoE::GatedActType gatedActType, bool useShuffledMatrixA, + batchedGemm::gemm::MatrixLayout weightLayout) + : mDtypeAct(dtypeAct), + mDtypeWeights(dtypeWeights), + mTileTokensDim(tileTokensDim), + mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner( + getOptions(mDtypeAct, mDtypeWeights, mTileTokensDim, useDeepSeekFp8, gatedActType, + useShuffledMatrixA, weightLayout))) {} + +void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* weightsScale, + void* expertWeights, float* outputScalesScalar, float* outputScalesGateScalar, + float* ptrBias, float* ptrAlpha, float* ptrBeta, float* ptrClampLimit, + void* output, void* outputScale, int32_t topK, int32_t hiddenSize, + int32_t intermediateSize, int32_t numExperts, int32_t numTokens, + int32_t* permutedIdxToTokenIdx, int32_t* ptrNumNonExitingCtas, + int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, + int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput, + int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl) { + auto maxNumCtasInBatchDim = + Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, 
mTileTokensDim); + mRunner.run(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, + maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale, + expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar, + ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx, + ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, + ptrNumNonExitingCtas, bmm1Workspace, stream, device, configIndex, enable_pdl); +} + +size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, + int32_t numExperts, int32_t numTokens, + int32_t configIndex) const { + auto maxNumCtasInBatchDim = + Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + return mRunner.getWorkspaceSizeInBytes(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, + numExperts, maxNumCtasInBatchDim, configIndex); +} + +int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, + int32_t intermediateSize, int32_t numExperts, + int32_t numTokens) const { + auto maxNumCtasInBatchDim = + Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + return mRunner.getDefaultValidConfigIndex(numTokens, 2 * intermediateSize, hiddenSize, {}, + numTokens, numExperts, maxNumCtasInBatchDim); +} + +bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, + int32_t intermediateSize, int32_t numExperts, + int32_t numTokens) const { + auto maxNumCtasInBatchDim = + Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + + auto const isValid = + mRunner.isValidConfigIndex(configIndex, numTokens, 2 * intermediateSize, hiddenSize, {}, + numTokens, numExperts, maxNumCtasInBatchDim); + + return isValid; +} + +std::vector Runner::getPassingConfigIndices() const { + return mRunner.getPassingConfigIndices(); +} +} // namespace PermuteGemm1 + +namespace Gemm2 { 
+tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( + btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut, int32_t tileTokensDim, + bool useDeepSeekFp8, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) { + tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { + // Swap A and B dtypes because transposeMmaOutput is hardcoded to true + .dtypeA = dtypeWeights, + .dtypeB = dtypeAct, + .dtypeC = dtypeOut, + .deepSeekFp8 = useDeepSeekFp8, + .fusedAct = false, + .routeAct = false, + .staticBatch = false, + .transposeMmaOutput = true, + .tileSize = tileTokensDim, + .epilogueTileM = useDeepSeekFp8 ? 64 : 128, + .useShuffledMatrixA = useShuffledMatrixA, + .weightLayout = weightLayout}; + return options; +} + +Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut, + bool useDeepSeekFp8, int tileTokensDim, bool useShuffledMatrixA, + batchedGemm::gemm::MatrixLayout weightLayout) + : mDtypeAct(dtypeAct), + mDtypeWeights(dtypeWeights), + mDtypeOut(dtypeOut), + mTileTokensDim(tileTokensDim), + mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner( + getOptions(dtypeAct, dtypeWeights, dtypeOut, tileTokensDim, useDeepSeekFp8, + useShuffledMatrixA, weightLayout))) {} + +void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void* weights, + void* weightsScale, float* outputScalesScalar, float* ptrBias, void* output, + void* outputScale, int32_t topK, int32_t hiddenSize, int32_t intermediateSize, + int32_t numExperts, int32_t numTokens, int32_t* ptrNumNonExitingCtas, + int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, + int32_t* ptrCtaIdxXyToMnLimit, void* bmm2Workspace, int device, + cudaStream_t stream, int32_t configIndex, bool enable_pdl) { + auto maxNumCtasInBatchDim = + Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + mRunner.run( + numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, 
maxNumCtasInBatchDim, + permutedHiddenState, permutedHiddenStateScale, weights, weightsScale, + /* perTokensSfA */ nullptr, + /* perTokensSfB */ nullptr, outputScalesScalar, /* outputScalesGateScalar */ nullptr, ptrBias, + /* ptrAlpha */ nullptr, /* ptrBeta */ nullptr, /* clampLimit */ nullptr, output, outputScale, + /* permutedIdxToTokenIdx */ nullptr, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, + ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm2Workspace, stream, device, configIndex, + enable_pdl); +} + +size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, + int32_t numExperts, int32_t numTokens, + int32_t configIndex) const { + auto maxNumCtasInBatchDim = + Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + return mRunner.getWorkspaceSizeInBytes(numTokens, hiddenSize, intermediateSize, {}, numTokens, + numExperts, maxNumCtasInBatchDim, configIndex); +} + +int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, + int32_t intermediateSize, int32_t numExperts, + int32_t numTokens) const { + auto maxNumCtasInBatchDim = + Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + return mRunner.getDefaultValidConfigIndex(numTokens, hiddenSize, intermediateSize, {}, numTokens, + numExperts, maxNumCtasInBatchDim); +} + +bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, + int32_t intermediateSize, int32_t numExperts, + int32_t numTokens) const { + auto const maxNumCtasInBatchDim = + Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + + auto const isValid = + mRunner.isValidConfigIndex(configIndex, numTokens, hiddenSize, intermediateSize, {}, + numTokens, numExperts, maxNumCtasInBatchDim); + + return isValid; +} + +std::vector Runner::getPassingConfigIndices() const { + return mRunner.getPassingConfigIndices(); +} +} // namespace Gemm2 + +namespace MoE { +Runner::Runner(btg::Dtype 
dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, + int32_t tileTokensDim, GatedActType gatedActType, bool useShuffledMatrixA, + batchedGemm::gemm::MatrixLayout weightLayout) + : mPermuteGemm1(PermuteGemm1::Runner(dtypeAct, dtypeWeights, useDeepSeekFp8, tileTokensDim, + gatedActType, useShuffledMatrixA, weightLayout)), + mGemm2(Gemm2::Runner(dtypeAct, dtypeWeights, btg::Dtype::Bfloat16, useDeepSeekFp8, + tileTokensDim, useShuffledMatrixA, weightLayout)) { + auto const& gemm1PassingIndices = mPermuteGemm1.getPassingConfigIndices(); + auto const& gemm2PassingIndices = mGemm2.getPassingConfigIndices(); + + auto const totalPassingIndices = gemm1PassingIndices.size() * gemm2PassingIndices.size(); + mPassingConfigs.reserve(totalPassingIndices); + + for (auto const& indexGemm1 : gemm1PassingIndices) { + for (auto const& indexGemm2 : gemm2PassingIndices) { + mPassingConfigs.push_back(MoEConfig{indexGemm1, indexGemm2}); + } + } + FLASHINFER_CHECK(!mPassingConfigs.empty(), + "No compatible configs found for the fp8 block scale MoE runner."); +} + +Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int32_t tileTokensDim, + bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) + : Runner(dtypeElt, dtypeElt, useDeepSeekFp8, tileTokensDim, GatedActType::SwiGlu, + useShuffledMatrixA, weightLayout) {} + +void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace, + moe::dev::convertsf::Data& convertSfData, + moe::dev::activation::Data& activationData, + moe::dev::finalize::Data& finalizeData) { + // Setup sf conversion data if needed + convertSfData.inSfPtr = args.hidden_states_scale; + convertSfData.outSfPtr = workspace.hidden_states_scale_linear; + convertSfData.hiddenDimSf = args.hidden_size / 16; + convertSfData.numTokens = args.num_tokens; + convertSfData.sfLayoutSrc = btg::SfLayout::R128c4; + convertSfData.sfLayoutDst = btg::SfLayout::Linear; + convertSfData.mUsePdl = true; + + // Setup activation data + 
activationData.mDtypeElt = args.mDtypeElt; + activationData.mUsePdl = true; + activationData.mUseDeepSeekFp8 = true; + activationData.inPtr = workspace.gemm1_output; + activationData.outPtr = workspace.activation_output; + activationData.inDqSfsPtr = workspace.gemm1_output_scale; + activationData.outDqSfsPtr = workspace.activation_output_scale; + activationData.innerDim = args.intermediate_size * 2; + activationData.topK = args.top_k; + activationData.numTokens = args.num_tokens; + activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx; + + activationData.totalNumPaddedTokens = workspace.total_num_padded_tokens; + + // Setup finalize data + if (args.do_finalize) { + // Setup finalize data + finalizeData.mDtypeElt = args.mDtypeOut; + finalizeData.mDtypeExpW = args.mDtypeExpW; + finalizeData.mUsePdl = true; + finalizeData.mUseDeepSeekFp8 = false; + finalizeData.inPtr = workspace.gemm2_output; + finalizeData.outPtr = args.output; + finalizeData.inDqSfsPtr = workspace.gemm2_output_scale; + finalizeData.outDqSfsPtr = args.output_scale; + if (args.mUseRoutingScalesOnInput) { + finalizeData.expertWeightsPtr = nullptr; + } else { + finalizeData.expertWeightsPtr = workspace.expert_weights; + } + finalizeData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx; + finalizeData.numTokens = args.num_tokens; + finalizeData.numExperts = args.num_experts; + finalizeData.topK = args.top_k; + // We want to fuse unpadding into the finalize kernel, so we need to use the output hidden size. 
+ finalizeData.hiddenDim = args.hidden_size_output.value_or(args.hidden_size); + finalizeData.hiddenDimPadded = args.hidden_size; + finalizeData.totalNumPaddedTokens = workspace.total_num_padded_tokens; + } +} + +std::tuple Runner::getWorkspaceSizeInBytes(MoERunnerArgs const& args, + int64_t configIndex) const { + auto const& config = mPassingConfigs[configIndex]; + + auto workspace_size_fc1 = static_cast(mPermuteGemm1.getWorkspaceSizeInBytes( + args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens, + config.gemm1Config)); + auto workspace_size_fc2 = static_cast( + mGemm2.getWorkspaceSizeInBytes(args.top_k, args.hidden_size, args.intermediate_size, + args.local_num_experts, args.num_tokens, config.gemm2Config)); + return std::make_tuple(workspace_size_fc1, workspace_size_fc2); +} + +std::vector Runner::getValidConfigIndices(int32_t topK, int32_t hiddenSize, + int32_t intermediateSize, + int32_t numLocalExperts, + int32_t numTokens) const { + std::vector validIndices; + + for (int i = 0; i < mPassingConfigs.size(); ++i) { + auto const& config = mPassingConfigs[i]; + + if (mPermuteGemm1.isValidConfigIndex(config.gemm1Config, topK, hiddenSize, intermediateSize, + numLocalExperts, numTokens) && + mGemm2.isValidConfigIndex(config.gemm2Config, topK, hiddenSize, intermediateSize, + numLocalExperts, numTokens)) { + validIndices.push_back(i); + } + } + + return validIndices; +} + +int64_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, + int32_t intermediateSize, int32_t numLocalExperts, + int32_t numTokens) const { + int32_t indexGemm1 = mPermuteGemm1.getDefaultValidConfigIndex(topK, hiddenSize, intermediateSize, + numLocalExperts, numTokens); + int32_t indexGemm2 = mGemm2.getDefaultValidConfigIndex(topK, hiddenSize, intermediateSize, + numLocalExperts, numTokens); + + auto it = std::find_if(mPassingConfigs.begin(), mPassingConfigs.end(), + [indexGemm1, indexGemm2](MoEConfig cfg) { + return (cfg.gemm1Config == 
indexGemm1 && cfg.gemm2Config == indexGemm2); + }); + FLASHINFER_CHECK(it != mPassingConfigs.end(), + "No compatible configs found for the block scale MoE runner."); + return std::distance(mPassingConfigs.begin(), it); +} + +void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, + cudaStream_t stream, int64_t configIndex, bool enable_pdl) { + // Setup all operation data + moe::dev::activation::Data activationData; + moe::dev::finalize::Data finalizeData; + moe::dev::convertsf::Data convertSfData; + sync_check_cuda_error(stream); + setOpsData(args, workspace, convertSfData, activationData, finalizeData); + + void* hidden_states_scale_linear{args.hidden_states_scale}; + + auto const& config = mPassingConfigs[configIndex]; + + mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights, + args.gemm1_weights_scale, workspace.token_scales, args.output1_scales_scalar, + args.output1_scales_gate_scalar, args.gemm1_bias, args.gemm1_alpha, + args.gemm1_beta, args.gemm1_clamp_limit, workspace.gemm1_output, + workspace.gemm1_output_scale, args.top_k, args.hidden_size, + args.intermediate_size, args.local_num_experts, args.num_tokens, + workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas, + workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, + workspace.cta_idx_xy_to_mn_limit, workspace.bmm1_workspace, + args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config, enable_pdl); + + // We do not fuse activation with FC1 for DeepSeek FP8 due to the weights shuffling constraint. + void* gemm2_input = workspace.gemm1_output; + void* gemm2_input_scale = workspace.gemm1_output_scale; + // We do activation only for DeepSeek FP8, as cubins do not have fused activation. 
+ if (args.mDtypeElt == btg::Dtype::E4m3 && args.mUseDeepSeekFp8) { + // Run activation + moe::dev::activation::run(activationData, stream); + gemm2_input = workspace.activation_output; + gemm2_input_scale = workspace.activation_output_scale; + } + + // Run gemm2 + mGemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale, + args.output2_scales_scalar, args.gemm2_bias, workspace.gemm2_output, + workspace.gemm2_output_scale, args.top_k, args.hidden_size, args.intermediate_size, + args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas, + workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, + workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream, + config.gemm2Config, enable_pdl); + + // Run finalize + if (args.do_finalize) { + // Run finalize + moe::dev::finalize::run(finalizeData, stream); + sync_check_cuda_error(stream); + } +} +} // namespace MoE + +} // namespace trtllmgen_moe +} // namespace kernels +} // namespace tensorrt_llm + +#endif // USE_FLASHINFER diff --git a/src/lib.rs b/src/lib.rs index 25b4d2f..5fd41e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -356,7 +356,7 @@ impl PagedAttention { let (_, key_value_heads, _, _) = key.shape().dims4()?; let group_size = attention_heads / key_value_heads; let flashinfer_prefill_group_supported = - matches!(group_size, 1 | 2 | 3 | 4 | 8 | 16 | 32 | 64); + matches!(group_size, 1 | 2 | 3 | 4 | 6 | 8 | 16 | 32 | 64); let flashinfer_decode_group_supported = flashinfer_prefill_group_supported && !(group_size == 64 && head_size > 128); let flashinfer_group_supported = if input_metadata.is_prefill { diff --git a/src/moe.rs b/src/moe.rs index 8fcc24c..2140fe1 100644 --- a/src/moe.rs +++ b/src/moe.rs @@ -1,8 +1,249 @@ +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +use candle_core::cuda_backend::cudarc::driver::DevicePtr; use candle_core::quantized::QTensor; use candle_core::{Result, Tensor}; #[cfg(feature = "cuda")] use kernels::ffi; 
+#[cfg(all(feature = "cuda", feature = "flashinfer"))] +fn cuda_dtype_code(dtype: candle_core::DType) -> Result { + match dtype { + candle_core::DType::F16 => Ok(0), + candle_core::DType::BF16 => Ok(1), + _ => candle_core::bail!("only f16/bf16 are supported for flashinfer fused moe"), + } +} + +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +fn cuda_ptr_f16bf16(t: &Tensor) -> Result<*const core::ffi::c_void> { + use candle_core as candle; + let (storage, _) = t.storage_and_layout(); + match (&*storage, t.dtype()) { + (candle::Storage::Cuda(c), candle_core::DType::F16) => { + Ok(*c.as_cuda_slice::()?.device_ptr() as *const core::ffi::c_void) + } + (candle::Storage::Cuda(c), candle_core::DType::BF16) => { + Ok(*c.as_cuda_slice::()?.device_ptr() as *const core::ffi::c_void) + } + _ => candle_core::bail!("expected CUDA f16/bf16 tensor"), + } +} + +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +fn cuda_ptr_u8(t: &Tensor) -> Result<*const u8> { + use candle_core as candle; + let (storage, _) = t.storage_and_layout(); + match (&*storage, t.dtype()) { + (candle::Storage::Cuda(c), candle_core::DType::U8) => { + Ok(*c.as_cuda_slice::()?.device_ptr() as *const u8) + } + _ => candle_core::bail!("expected CUDA u8 tensor"), + } +} + +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +fn cuda_ptr_f32(t: &Tensor) -> Result<*const f32> { + use candle_core as candle; + let (storage, _) = t.storage_and_layout(); + match (&*storage, t.dtype()) { + (candle::Storage::Cuda(c), candle_core::DType::F32) => { + Ok(*c.as_cuda_slice::()?.device_ptr() as *const f32) + } + _ => candle_core::bail!("expected CUDA f32 tensor"), + } +} + +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +fn cuda_ptr_topk_ids_i32(t: &Tensor) -> Result<*const i32> { + use candle_core as candle; + let (storage, _) = t.storage_and_layout(); + match (&*storage, t.dtype()) { + (candle::Storage::Cuda(c), candle_core::DType::U32) => { + Ok(*c.as_cuda_slice::()?.device_ptr() as *const i32) + } + _ => 
candle_core::bail!("expected CUDA u32 tensor for topk ids"), + } +} + +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +fn cuda_mut_ptr_f16bf16(t: &Tensor) -> Result<*mut core::ffi::c_void> { + use candle_core as candle; + let (storage, _) = t.storage_and_layout(); + match (&*storage, t.dtype()) { + (candle::Storage::Cuda(c), candle_core::DType::F16) => { + Ok(*c.as_cuda_slice::()?.device_ptr() as *mut core::ffi::c_void) + } + (candle::Storage::Cuda(c), candle_core::DType::BF16) => { + Ok(*c.as_cuda_slice::()?.device_ptr() as *mut core::ffi::c_void) + } + _ => candle_core::bail!("expected CUDA f16/bf16 tensor"), + } +} + +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +pub fn flashinfer_fused_moe( + input: &Tensor, + topk_ids: &Tensor, + topk_weights: &Tensor, + gate_up_weights: &Tensor, + down_weights: &Tensor, +) -> Result { + let (num_tokens, hidden_size) = input.dims2()?; + let (num_experts, gate_up_n, gate_up_k) = gate_up_weights.dims3()?; + let (down_experts, down_n, down_k) = down_weights.dims3()?; + let (topk_tokens, top_k) = topk_ids.dims2()?; + let (topk_w_tokens, topk_w_k) = topk_weights.dims2()?; + if !input.is_contiguous() + || !topk_ids.is_contiguous() + || !topk_weights.is_contiguous() + || !gate_up_weights.is_contiguous() + || !down_weights.is_contiguous() + { + candle_core::bail!("flashinfer fused moe expects contiguous tensors"); + } + if topk_tokens != num_tokens || topk_w_tokens != num_tokens || topk_w_k != top_k { + candle_core::bail!("flashinfer fused moe: invalid topk tensors"); + } + if gate_up_k != hidden_size || down_experts != num_experts || down_n != hidden_size { + candle_core::bail!("flashinfer fused moe: invalid tensor shapes for moe weights"); + } + if gate_up_n % 2 != 0 { + candle_core::bail!("flashinfer fused moe: gate_up second dim must be even"); + } + if down_k * 2 != gate_up_n { + candle_core::bail!("flashinfer fused moe: gate_up/down intermediate dims mismatch"); + } + let input_dtype = 
cuda_dtype_code(input.dtype())?; + let weight_dtype = cuda_dtype_code(gate_up_weights.dtype())?; + if input_dtype != weight_dtype { + candle_core::bail!("flashinfer fused moe: input and weight dtype must match"); + } + let dev = input.device().as_cuda_device()?; + let stream = *dev.cu_stream() as i64; + + let output = Tensor::zeros((num_tokens, hidden_size), input.dtype(), input.device())?; + let status = unsafe { + ffi::flashinfer_fused_moe_bf16( + cuda_ptr_f16bf16(input)?, + cuda_ptr_topk_ids_i32(topk_ids)?, + cuda_ptr_f32(topk_weights)?, + cuda_ptr_f16bf16(gate_up_weights)?, + cuda_ptr_f16bf16(down_weights)?, + cuda_mut_ptr_f16bf16(&output)?, + num_tokens as i32, + hidden_size as i32, + down_k as i32, + num_experts as i32, + top_k as i32, + input_dtype, + weight_dtype, + stream, + ) + }; + if status != 0 { + candle_core::bail!("flashinfer fused moe bf16 kernel failed with status {status}"); + } + Ok(output) +} + +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +pub fn flashinfer_fused_moe_fp8( + input: &Tensor, + topk_ids: &Tensor, + topk_weights: &Tensor, + gate_up_weights: &Tensor, + gate_up_scales: &Tensor, + down_weights: &Tensor, + down_scales: &Tensor, +) -> Result { + let (num_tokens, hidden_size) = input.dims2()?; + let (num_experts, gate_up_n, gate_up_k) = gate_up_weights.dims3()?; + let (down_experts, down_n, down_k) = down_weights.dims3()?; + let (gate_up_scale_experts, gate_up_scale_n, gate_up_scale_k) = gate_up_scales.dims3()?; + let (down_scale_experts, down_scale_n, down_scale_k) = down_scales.dims3()?; + let (topk_tokens, top_k) = topk_ids.dims2()?; + let (topk_w_tokens, topk_w_k) = topk_weights.dims2()?; + if !input.is_contiguous() + || !topk_ids.is_contiguous() + || !topk_weights.is_contiguous() + || !gate_up_weights.is_contiguous() + || !gate_up_scales.is_contiguous() + || !down_weights.is_contiguous() + || !down_scales.is_contiguous() + { + candle_core::bail!("flashinfer fused moe fp8 expects contiguous tensors"); + } + if 
gate_up_weights.dtype() != candle_core::DType::U8 + || down_weights.dtype() != candle_core::DType::U8 + || gate_up_scales.dtype() != candle_core::DType::F32 + || down_scales.dtype() != candle_core::DType::F32 + { + candle_core::bail!("flashinfer fused moe fp8 expects u8 weights and f32 scales"); + } + if topk_tokens != num_tokens || topk_w_tokens != num_tokens || topk_w_k != top_k { + candle_core::bail!("flashinfer fused moe fp8: invalid topk tensors"); + } + if gate_up_k != hidden_size || down_experts != num_experts || down_n != hidden_size { + candle_core::bail!("flashinfer fused moe fp8: invalid tensor shapes for moe weights"); + } + if gate_up_n % 2 != 0 || down_k * 2 != gate_up_n { + candle_core::bail!("flashinfer fused moe fp8: gate_up/down intermediate dims mismatch"); + } + if gate_up_scale_experts != num_experts || down_scale_experts != num_experts { + candle_core::bail!("flashinfer fused moe fp8: scale tensor expert dim mismatch"); + } + if hidden_size % 128 != 0 || down_k % 128 != 0 { + candle_core::bail!( + "flashinfer fused moe fp8: hidden/intermediate dims must be divisible by 128" + ); + } + let expected_gate_up_scale_n = gate_up_n / 128; + let expected_gate_up_scale_k = hidden_size / 128; + let expected_down_scale_n = hidden_size / 128; + let expected_down_scale_k = down_k / 128; + if gate_up_scale_n != expected_gate_up_scale_n + || gate_up_scale_k != expected_gate_up_scale_k + || down_scale_n != expected_down_scale_n + || down_scale_k != expected_down_scale_k + { + candle_core::bail!( + "flashinfer fused moe fp8: invalid scale tensor shapes, expected gate_up=[{num_experts}, {expected_gate_up_scale_n}, {expected_gate_up_scale_k}], down=[{num_experts}, {expected_down_scale_n}, {expected_down_scale_k}]" + ); + } + let input_dtype = cuda_dtype_code(input.dtype())?; + let dev = input.device().as_cuda_device()?; + let stream = *dev.cu_stream() as i64; + + let output = Tensor::zeros( + (num_tokens, hidden_size), + candle_core::DType::BF16, + 
input.device(), + )?; + let status = unsafe { + ffi::flashinfer_fused_moe_fp8( + cuda_ptr_f16bf16(input)?, + cuda_ptr_topk_ids_i32(topk_ids)?, + cuda_ptr_f32(topk_weights)?, + cuda_ptr_u8(gate_up_weights)?, + cuda_ptr_f32(gate_up_scales)?, + cuda_ptr_u8(down_weights)?, + cuda_ptr_f32(down_scales)?, + cuda_mut_ptr_f16bf16(&output)?, + num_tokens as i32, + hidden_size as i32, + down_k as i32, + num_experts as i32, + top_k as i32, + input_dtype, + stream, + ) + }; + if status != 0 { + candle_core::bail!("flashinfer fused moe fp8 kernel failed with status {status}"); + } + Ok(output) +} + #[cfg(feature = "cuda")] pub fn moe_gemm( input: &Tensor, From d4de938bdcaac3169e1fc11554853b410477d9f6 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Fri, 6 Mar 2026 11:33:27 +0000 Subject: [PATCH 2/7] Use flashattn.rs (v2/v3 all in one crate) --- Cargo.toml | 9 +++++---- src/lib.rs | 36 +++++++++++++++++++++++++++++------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 882b16c..0a5ad22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ license = "MIT" [dependencies] candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "1e9d1a9" } candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "1e9d1a9" } -candle-flash-attn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "1e9d1a9", optional = true } +#candle-flash-attn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "1e9d1a9", optional = true } serde = { version = "1.0.190", features = ["serde_derive"] } serde_json = "1.0.108" half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] } @@ -24,13 +24,14 @@ rayon="1.10.0" kernels = { path = "./src/kernels", version="0.4.2", optional = true} metal = { version = "0.27.0", features = ["mps"], optional = true } metal-kernels = { path = "./src/metal-kernels", version="0.1.9", optional 
= true} +flashattn-rs = { git = "https://github.com/guoqingbao/flashattn.rs.git", version="0.1.0", rev = "3955e82", optional = true } [features] cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:kernels"] graph = ["cuda", "candle-core/graph"] -flash-attn = ["dep:candle-flash-attn"] -flash-decoding = ["dep:candle-flash-attn", "candle-flash-attn/flash-decoding", "kernels/no-fp8-kvcache"] -flash-context = ["dep:candle-flash-attn", "candle-flash-attn/flash-context", "kernels/no-fp8-kvcache"] +flash-attn = ["dep:flashattn-rs"] +flash-decoding = ["dep:flashattn-rs", "flashattn-rs/flash-decoding", "kernels/no-fp8-kvcache"] +flash-context = ["dep:flashattn-rs", "flashattn-rs/flash-context", "kernels/no-fp8-kvcache"] no-marlin = ["dep:kernels", "kernels/no-marlin"] no-fp8-kvcache = ["dep:kernels", "kernels/no-fp8-kvcache"] metal = ["candle-core/metal", "candle-nn/metal", "dep:metal-kernels", "dep:metal"] diff --git a/src/lib.rs b/src/lib.rs index 5fd41e3..410be0c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -232,7 +232,7 @@ impl PagedAttention { softcapping: Option, ) -> Result { if self.sliding_window.is_some() { - candle_flash_attn::flash_attn_varlen_windowed_softcap( + flashattn_rs::flash_attn_varlen_windowed_softcap( query, key, value, @@ -247,7 +247,7 @@ impl PagedAttention { Some(0), ) } else { - candle_flash_attn::flash_attn_varlen_softcap( + flashattn_rs::flash_attn_varlen_softcap( query, key, value, @@ -303,8 +303,20 @@ impl PagedAttention { return if input_metadata.block_tables.is_none() { // prefill without kvcache self.flash_var_len(&query, &key, &value, input_metadata, softcapping) + } else if self.sliding_window.is_none() { + flashattn_rs::flash_attn_with_kvcache_full( + &query, + key_cache.as_ref().unwrap(), + value_cache.as_ref().unwrap(), + input_metadata.context_lens.as_ref().unwrap(), + input_metadata.block_tables.as_ref().unwrap(), + input_metadata.cu_seqlens_q.as_ref(), + Some(input_metadata.max_seqlen_q), + self.scale as f32, + true, + ) } else { - // 
prefill with kvcache + // Sliding-window prefill still needs the windowed varlen path. self.flash_var_len( &query, key_cache.as_ref().unwrap(), @@ -319,17 +331,27 @@ impl PagedAttention { { let block_tables = input_metadata.block_tables.as_ref().unwrap(); let context_lens = input_metadata.context_lens.as_ref().unwrap(); - candle_flash_attn::flash_attn_with_kvcache_windowed_softcap( + + flashattn_rs::flash_attn_with_kvcache( &query.unsqueeze(1)?, //(batch_size, seqlen_q, num_heads_q, head_size) key_cache.as_ref().unwrap(), value_cache.as_ref().unwrap(), context_lens, block_tables, self.scale as f32, - Some(softcapping.unwrap_or(0.0f64) as f32), - self.sliding_window, - Some(0), ) + + // flashattn_rs::flash_attn_with_kvcache_windowed_softcap( + // &query.unsqueeze(1)?, //(batch_size, seqlen_q, num_heads_q, head_size) + // key_cache.as_ref().unwrap(), + // value_cache.as_ref().unwrap(), + // context_lens, + // block_tables, + // self.scale as f32, + // Some(softcapping.unwrap_or(0.0f64) as f32), + // self.sliding_window, + // Some(0), + // ) } #[cfg(not(feature = "flash-decoding"))] candle_core::bail!("Invalid pattern for flash_forward") From 92a03c0f5045296ad72d37ae769e2dbcffaa7df2 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Mon, 9 Mar 2026 10:00:08 +0000 Subject: [PATCH 3/7] Fix softcap and sliding_window --- Cargo.toml | 2 +- src/gdn.rs | 1 + src/lib.rs | 78 ++++++++++++++++++++++++++++-------------------------- 3 files changed, 42 insertions(+), 39 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dce7548..998c565 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ rayon="1.10.0" kernels = { path = "./src/kernels", version="0.4.2", optional = true} metal = { version = "0.27.0", features = ["mps"], optional = true } metal-kernels = { path = "./src/metal-kernels", version="0.1.9", optional = true} -flashattn-rs = { git = "https://github.com/guoqingbao/flashattn.rs.git", version="0.1.0", rev = "3955e82", optional = true } +flashattn-rs = { git = 
"https://github.com/guoqingbao/flashattn.rs.git", version="0.1.0", rev = "9e3e649", optional = true } [features] cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:kernels"] diff --git a/src/gdn.rs b/src/gdn.rs index fd1a708..4db1595 100644 --- a/src/gdn.rs +++ b/src/gdn.rs @@ -1371,6 +1371,7 @@ fn gated_delta_rule_recurrence_naive( Tensor::cat(&output_refs, 1)?.to_dtype(out_dtype) } +#[cfg(not(feature = "cuda"))] fn gated_delta_rule_decode_slots_naive( q: &Tensor, k: &Tensor, diff --git a/src/lib.rs b/src/lib.rs index 410be0c..95329ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -286,6 +286,8 @@ impl PagedAttention { let value = value .transpose(1, 2)? .reshape(((), key_value_heads, head_size))?; + let softcap = Some(softcapping.unwrap_or(0.0f64) as f32); + let window_size_right = self.sliding_window.map(|_| 0); self.maybe_update_kv_scales(&key, &value)?; @@ -299,32 +301,35 @@ impl PagedAttention { &slot_mapping, )?; + if input_metadata.is_prefill && input_metadata.block_tables.is_none() { + // prefill without kvcache + return self.flash_var_len(&query, &key, &value, input_metadata, softcapping); + } + + #[cfg(feature = "flash-decoding")] if input_metadata.is_prefill { - return if input_metadata.block_tables.is_none() { - // prefill without kvcache - self.flash_var_len(&query, &key, &value, input_metadata, softcapping) - } else if self.sliding_window.is_none() { - flashattn_rs::flash_attn_with_kvcache_full( - &query, - key_cache.as_ref().unwrap(), - value_cache.as_ref().unwrap(), - input_metadata.context_lens.as_ref().unwrap(), - input_metadata.block_tables.as_ref().unwrap(), - input_metadata.cu_seqlens_q.as_ref(), - Some(input_metadata.max_seqlen_q), - self.scale as f32, - true, - ) - } else { - // Sliding-window prefill still needs the windowed varlen path. 
- self.flash_var_len( - &query, - key_cache.as_ref().unwrap(), - value_cache.as_ref().unwrap(), - input_metadata, - softcapping, - ) - }; + return flashattn_rs::flash_attn_with_kvcache_advanced( + &query, + key_cache.as_ref().unwrap(), + value_cache.as_ref().unwrap(), + input_metadata.context_lens.as_ref().unwrap(), + input_metadata.block_tables.as_ref().unwrap(), + input_metadata.cu_seqlens_q.as_ref(), + Some(input_metadata.max_seqlen_q), + self.scale as f32, + true, + self.sliding_window, + window_size_right, + None, + softcap, + 0, + None, + ); + } + + #[cfg(not(feature = "flash-decoding"))] + if input_metadata.is_prefill { + candle_core::bail!("Invalid pattern for flash_forward"); } #[cfg(feature = "flash-decoding")] @@ -332,26 +337,23 @@ impl PagedAttention { let block_tables = input_metadata.block_tables.as_ref().unwrap(); let context_lens = input_metadata.context_lens.as_ref().unwrap(); - flashattn_rs::flash_attn_with_kvcache( + flashattn_rs::flash_attn_with_kvcache_advanced( &query.unsqueeze(1)?, //(batch_size, seqlen_q, num_heads_q, head_size) key_cache.as_ref().unwrap(), value_cache.as_ref().unwrap(), context_lens, block_tables, + None, + None, self.scale as f32, + false, + self.sliding_window, + window_size_right, + None, + softcap, + 0, + None, ) - - // flashattn_rs::flash_attn_with_kvcache_windowed_softcap( - // &query.unsqueeze(1)?, //(batch_size, seqlen_q, num_heads_q, head_size) - // key_cache.as_ref().unwrap(), - // value_cache.as_ref().unwrap(), - // context_lens, - // block_tables, - // self.scale as f32, - // Some(softcapping.unwrap_or(0.0f64) as f32), - // self.sliding_window, - // Some(0), - // ) } #[cfg(not(feature = "flash-decoding"))] candle_core::bail!("Invalid pattern for flash_forward") From 53ab8968c5191509aacf819b334e980d7dc44a95 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 10 Mar 2026 11:19:31 +0000 Subject: [PATCH 4/7] Support deep_gemm on Hopper --- Cargo.toml | 8 +- src/fp8_linear.rs | 173 +++- src/fused_rope.rs | 1090 
++++++++++++++----------- src/kernels/Cargo.toml | 2 +- src/kernels/build.rs | 38 + src/kernels/src/ffi.rs | 198 +++++ src/kernels/src/flashinfer_bmm_fp8.cu | 159 ++++ src/kernels/src/fused_rope.cu | 744 +++++++++++++++++ src/lib.rs | 161 +++- src/metal-kernels/Cargo.toml | 2 +- 10 files changed, 2068 insertions(+), 507 deletions(-) create mode 100644 src/kernels/src/flashinfer_bmm_fp8.cu diff --git a/Cargo.toml b/Cargo.toml index 998c565..c916b4d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,9 +12,9 @@ categories = ["algorithms", "hardware-support", "science"] license = "MIT" [dependencies] -candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" } -candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" } -#candle-flash-attn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048", optional = true } +candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "5bed038" } +candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "5bed038" } +#candle-flash-attn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "5bed038", optional = true } serde = { version = "1.0.190", features = ["serde_derive"] } serde_json = "1.0.108" half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] } @@ -24,7 +24,7 @@ rayon="1.10.0" kernels = { path = "./src/kernels", version="0.4.2", optional = true} metal = { version = "0.27.0", features = ["mps"], optional = true } metal-kernels = { path = "./src/metal-kernels", version="0.1.9", optional = true} -flashattn-rs = { git = "https://github.com/guoqingbao/flashattn.rs.git", version="0.1.0", rev = "9e3e649", optional = true } +flashattn-rs = { git = "https://github.com/guoqingbao/flashattn.rs.git", version="0.1.0", rev = "bf1db0a", optional = true } [features] cuda = ["candle-core/cuda", 
"candle-nn/cuda", "dep:kernels"] diff --git a/src/fp8_linear.rs b/src/fp8_linear.rs index 27e4ec5..0be5920 100644 --- a/src/fp8_linear.rs +++ b/src/fp8_linear.rs @@ -1,14 +1,61 @@ -#[cfg(all(feature = "cuda", feature = "cutlass"))] +#[cfg(feature = "cuda")] use crate::cuda_utils; #[cfg(feature = "cuda")] use crate::kernels::ffi; #[cfg(feature = "metal")] use crate::metal_kernels; #[cfg(feature = "cuda")] +use candle_core::cuda_backend::cudarc::driver::CudaSlice; +#[cfg(feature = "cuda")] use candle_core::cuda_backend::cudarc::driver::DevicePtr; +#[cfg(feature = "cuda")] +use candle_core::cuda_backend::WrapErr; use candle_core::{DType, Device, Result, Tensor}; +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +use std::cell::RefCell; + +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +struct FlashInferFp8Workspace { + buffer: CudaSlice, + size: usize, + device_ordinal: usize, +} -#[cfg(all(feature = "cuda", feature = "cutlass"))] +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +thread_local! 
{ + static FLASHINFER_FP8_WORKSPACE: RefCell> = const { RefCell::new(None) }; +} + +#[cfg(all(feature = "cuda", feature = "flashinfer"))] +fn get_or_init_flashinfer_fp8_workspace( + dev: &candle_core::cuda_backend::CudaDevice, + required_size: usize, +) -> Result<(*mut std::ffi::c_void, usize)> { + FLASHINFER_FP8_WORKSPACE.with(|cell| { + let mut slot = cell.borrow_mut(); + let ordinal = dev.ordinal(); + + let needs_init = match slot.as_ref() { + None => true, + Some(existing) => existing.device_ordinal != ordinal || existing.size < required_size, + }; + + if needs_init { + let alloc_size = required_size.max(1); + let buffer = unsafe { dev.alloc::(alloc_size) }.w()?; + *slot = Some(FlashInferFp8Workspace { + buffer, + size: alloc_size, + device_ordinal: ordinal, + }); + } + + let ws = slot.as_ref().unwrap(); + Ok((*ws.buffer.device_ptr() as *mut std::ffi::c_void, ws.size)) + }) +} + +#[cfg(feature = "cuda")] fn get_cuda_slice< T: candle_core::cuda_backend::cudarc::driver::DeviceRepr + candle_core::cuda_backend::CudaDType, >( @@ -222,6 +269,128 @@ pub fn fp8_matmul( Ok(output) } +/// FP8 Matrix Multiplication using FlashInfer/TensorRT-LLM SM90 blockwise GEMM. +/// +/// This path expects Hopper-native blockwise scales in `[N/128, K/128]` layout and +/// relies on the underlying runner's small-`M` swapAB optimization for decode. 
+#[cfg(all(feature = "cuda", feature = "flashinfer"))] +pub fn fp8_matmul_flashinfer( + input: &Tensor, + weight: &Tensor, + weight_scale: &Tensor, +) -> Result { + let (m, k) = input.dims2()?; + let (n, k_w) = weight.dims2()?; + + if k != k_w { + candle_core::bail!( + "Shape mismatch in fp8_matmul_flashinfer: input [{}, {}], weight [{}, {}]", + m, + k, + n, + k_w + ); + } + + if input.dtype() != DType::BF16 { + candle_core::bail!("fp8_matmul_flashinfer requires bf16 input"); + } + if weight.dtype() != DType::U8 || weight_scale.dtype() != DType::F32 { + candle_core::bail!("fp8_matmul_flashinfer requires u8 weights and f32 scales"); + } + if !input.is_contiguous() { + candle_core::bail!("fp8_matmul_flashinfer requires contiguous input"); + } + if !weight.is_contiguous() { + candle_core::bail!("fp8_matmul_flashinfer requires contiguous row-major weight"); + } + if !weight_scale.is_contiguous() { + candle_core::bail!("fp8_matmul_flashinfer requires contiguous row-major weight_scale"); + } + if k % 128 != 0 { + candle_core::bail!("fp8_matmul_flashinfer requires K divisible by 128"); + } + if n % 64 != 0 { + candle_core::bail!("fp8_matmul_flashinfer requires N divisible by 64"); + } + + let expected_scale = ((n + 127) / 128, k / 128); + if weight_scale.dims2()? 
!= expected_scale { + candle_core::bail!( + "fp8_matmul_flashinfer expects weight_scale shape [{}, {}], got {:?}", + expected_scale.0, + expected_scale.1, + weight_scale.dims() + ); + } + + let dev = input.device(); + let sm_version = cuda_utils::sm_version(dev.as_cuda_device()?).unwrap_or(0) as usize; + if !(90..100).contains(&sm_version) { + candle_core::bail!("fp8_matmul_flashinfer requires Hopper (sm90)"); + } + + let cu_dev = dev.as_cuda_device()?; + let stream = *cu_dev.cu_stream() as i64; + let m_padded = (m + 4 - 1) / 4 * 4; + let out = Tensor::zeros((m, n), DType::BF16, dev)?; + let k_over_128 = k / 128; + let input_q = Tensor::zeros((m, k), DType::U8, dev)?; + // FlashInfer/DeepGEMM expects scales_a to use an M-aligned leading stride. + // Their own tests allocate [K/128, M_padded] and treat only the first M columns as live. + let input_scale = Tensor::zeros((k_over_128, m_padded), DType::F32, dev)?; + let scale_stride = input_scale.stride()[0] as i32; + let q_ptr = get_cuda_slice::(&input_q)? as *mut std::ffi::c_void; + let s_ptr = get_cuda_slice::(&input_scale)? as *mut f32; + let inp_ptr = get_cuda_slice::(input)? as *const std::ffi::c_void; + + unsafe { + let num_groups = m * k_over_128; + ffi::fp8_quantize_per_token_group_launch( + inp_ptr, + q_ptr, + s_ptr, + num_groups as i32, + 128, + k_over_128 as i32, + scale_stride, + false, + true, + stream, + ); + } + + let required_ws = + unsafe { ffi::flashinfer_fp8_blockscale_workspace_size_fp8(m as i32, n as i32, k as i32) }; + let (workspace_ptr, workspace_size) = + get_or_init_flashinfer_fp8_workspace(cu_dev, required_ws)?; + + let weight_ptr = get_cuda_slice::(weight)? as *const std::ffi::c_void; + let weight_scale_ptr = get_cuda_slice::(weight_scale)? as *const f32; + let out_ptr = get_cuda_slice::(&out)? 
as *mut std::ffi::c_void; + + let status = unsafe { + ffi::flashinfer_fp8_blockscale_fp8( + q_ptr as *const std::ffi::c_void, + s_ptr as *const f32, + weight_ptr, + weight_scale_ptr, + out_ptr, + m as i32, + n as i32, + k as i32, + workspace_ptr, + workspace_size, + stream, + ) + }; + if status != 0 { + candle_core::bail!("flashinfer fp8 blockscale gemm failed with status {status}"); + } + + Ok(out) +} + /// FP8 Matrix Multiplication using CUTLASS blockwise kernels (SM90+). /// /// # Arguments diff --git a/src/fused_rope.rs b/src/fused_rope.rs index 2e01b52..bb42d89 100644 --- a/src/fused_rope.rs +++ b/src/fused_rope.rs @@ -12,6 +12,605 @@ use candle_core::{DType, Result, Tensor}; #[cfg(feature = "cuda")] use kernels::ffi; +#[cfg(feature = "cuda")] +#[derive(Clone, Copy)] +enum RopeLayout { + BatchMajor { + q_bh: u32, + k_bh: u32, + seq_len: u32, + d: u32, + }, + TokenMajor { + num_tokens: u32, + q_heads: u32, + k_heads: u32, + d: u32, + }, +} + +#[cfg(feature = "cuda")] +impl RopeLayout { + fn positions_len(self) -> usize { + match self { + Self::BatchMajor { seq_len, .. } => seq_len as usize, + Self::TokenMajor { num_tokens, .. 
} => num_tokens as usize, + } + } +} + +#[cfg(feature = "cuda")] +fn resolve_rope_layout(q: &Tensor, k: &Tensor) -> Result { + match (q.dims().len(), k.dims().len()) { + (4, 4) => { + let (b, q_h, seq_len, d) = q.dims4()?; + let (kb, k_h, k_seq_len, kd) = k.dims4()?; + if b != kb || seq_len != k_seq_len || d != kd { + candle_core::bail!( + "Q and K batch/seq_len/head_dim must match, got Q: {:?}, K: {:?}", + q.shape(), + k.shape() + ); + } + Ok(RopeLayout::BatchMajor { + q_bh: (b * q_h) as u32, + k_bh: (b * k_h) as u32, + seq_len: seq_len as u32, + d: d as u32, + }) + } + (3, 3) => { + let (num_tokens, q_heads, d) = q.dims3()?; + let (k_num_tokens, k_heads, kd) = k.dims3()?; + if num_tokens != k_num_tokens || d != kd { + candle_core::bail!( + "Q and K num_tokens/head_dim must match, got Q: {:?}, K: {:?}", + q.shape(), + k.shape() + ); + } + Ok(RopeLayout::TokenMajor { + num_tokens: num_tokens as u32, + q_heads: q_heads as u32, + k_heads: k_heads as u32, + d: d as u32, + }) + } + _ => candle_core::bail!( + "FusedRope expects Q and K to be both 4D [batch, heads, seq, dim] or both 3D [tokens, heads, dim], got Q: {:?}, K: {:?}", + q.shape(), + k.shape() + ), + } +} + +#[cfg(feature = "cuda")] +fn launch_fused_rope( + q: &Tensor, + k: &Tensor, + cos: &Tensor, + sin: &Tensor, + positions: &Tensor, + is_interleaved: bool, +) -> Result<()> { + use candle_core::cuda_backend::cudarc::driver::DevicePtr; + use candle_core::cuda_backend::CudaStorageSlice; + + let layout = resolve_rope_layout(q, k)?; + let expected_positions_len = layout.positions_len(); + let pos_shape = positions.dims(); + if pos_shape.len() != 1 || pos_shape[0] != expected_positions_len { + candle_core::bail!( + "positions should be [{}], got {:?}", + expected_positions_len, + pos_shape + ); + } + + let positions = if positions.dtype() != DType::I64 { + positions.to_dtype(DType::I64)? 
+ } else { + positions.clone() + }; + + if !q.is_contiguous() + || !k.is_contiguous() + || !cos.is_contiguous() + || !sin.is_contiguous() + || !positions.is_contiguous() + { + candle_core::bail!("All tensors (q, k, cos, sin, positions) must be contiguous"); + } + + let dtype = q.dtype(); + if k.dtype() != dtype || cos.dtype() != dtype || sin.dtype() != dtype { + candle_core::bail!( + "Q, K, cos, sin must have same dtype, got Q: {:?}, K: {:?}, cos: {:?}, sin: {:?}", + q.dtype(), + k.dtype(), + cos.dtype(), + sin.dtype() + ); + } + + let dev = q.device().as_cuda_device()?; + let stream = *dev.cu_stream() as i64; + + let q_storage = q.storage_and_layout().0; + let k_storage = k.storage_and_layout().0; + let cos_storage = cos.storage_and_layout().0; + let sin_storage = sin.storage_and_layout().0; + let pos_storage = positions.storage_and_layout().0; + + let q_cuda = match &*q_storage { + candle_core::Storage::Cuda(s) => s, + _ => candle_core::bail!("Q must be on CUDA"), + }; + let k_cuda = match &*k_storage { + candle_core::Storage::Cuda(s) => s, + _ => candle_core::bail!("K must be on CUDA"), + }; + let cos_cuda = match &*cos_storage { + candle_core::Storage::Cuda(s) => s, + _ => candle_core::bail!("cos must be on CUDA"), + }; + let sin_cuda = match &*sin_storage { + candle_core::Storage::Cuda(s) => s, + _ => candle_core::bail!("sin must be on CUDA"), + }; + let pos_cuda = match &*pos_storage { + candle_core::Storage::Cuda(s) => s, + _ => candle_core::bail!("positions must be on CUDA"), + }; + + let pos_ptr = match &pos_cuda.slice { + CudaStorageSlice::I64(s) => *s.device_ptr() as *const i64, + _ => candle_core::bail!("positions must be I64"), + }; + + match dtype { + DType::F32 => { + let q_ptr = match &q_cuda.slice { + CudaStorageSlice::F32(s) => *s.device_ptr() as *mut f32, + _ => candle_core::bail!("Expected F32"), + }; + let k_ptr = match &k_cuda.slice { + CudaStorageSlice::F32(s) => *s.device_ptr() as *mut f32, + _ => candle_core::bail!("Expected F32"), + }; + 
let cos_ptr = match &cos_cuda.slice { + CudaStorageSlice::F32(s) => *s.device_ptr() as *const f32, + _ => candle_core::bail!("Expected F32"), + }; + let sin_ptr = match &sin_cuda.slice { + CudaStorageSlice::F32(s) => *s.device_ptr() as *const f32, + _ => candle_core::bail!("Expected F32"), + }; + + unsafe { + match layout { + RopeLayout::BatchMajor { + q_bh, + k_bh, + seq_len, + d, + } => { + if is_interleaved { + ffi::fused_rope_i_f32( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, q_bh, k_bh, seq_len, d, + stream, + ); + } else { + ffi::fused_rope_f32( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, q_bh, k_bh, seq_len, d, + stream, + ); + } + } + RopeLayout::TokenMajor { + num_tokens, + q_heads, + k_heads, + d, + } => { + if is_interleaved { + ffi::fused_rope_i_tok_major_f32( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, num_tokens, q_heads, + k_heads, d, stream, + ); + } else { + ffi::fused_rope_tok_major_f32( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, num_tokens, q_heads, + k_heads, d, stream, + ); + } + } + } + } + } + DType::F16 => { + let q_ptr = match &q_cuda.slice { + CudaStorageSlice::F16(s) => *s.device_ptr() as *mut core::ffi::c_void, + _ => candle_core::bail!("Expected F16"), + }; + let k_ptr = match &k_cuda.slice { + CudaStorageSlice::F16(s) => *s.device_ptr() as *mut core::ffi::c_void, + _ => candle_core::bail!("Expected F16"), + }; + let cos_ptr = match &cos_cuda.slice { + CudaStorageSlice::F16(s) => *s.device_ptr() as *const core::ffi::c_void, + _ => candle_core::bail!("Expected F16"), + }; + let sin_ptr = match &sin_cuda.slice { + CudaStorageSlice::F16(s) => *s.device_ptr() as *const core::ffi::c_void, + _ => candle_core::bail!("Expected F16"), + }; + + unsafe { + match layout { + RopeLayout::BatchMajor { + q_bh, + k_bh, + seq_len, + d, + } => { + if is_interleaved { + ffi::fused_rope_i_f16( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, q_bh, k_bh, seq_len, d, + stream, + ); + } else { + ffi::fused_rope_f16( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, q_bh, 
k_bh, seq_len, d, + stream, + ); + } + } + RopeLayout::TokenMajor { + num_tokens, + q_heads, + k_heads, + d, + } => { + if is_interleaved { + ffi::fused_rope_i_tok_major_f16( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, num_tokens, q_heads, + k_heads, d, stream, + ); + } else { + ffi::fused_rope_tok_major_f16( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, num_tokens, q_heads, + k_heads, d, stream, + ); + } + } + } + } + } + DType::BF16 => { + let q_ptr = match &q_cuda.slice { + CudaStorageSlice::BF16(s) => *s.device_ptr() as *mut core::ffi::c_void, + _ => candle_core::bail!("Expected BF16"), + }; + let k_ptr = match &k_cuda.slice { + CudaStorageSlice::BF16(s) => *s.device_ptr() as *mut core::ffi::c_void, + _ => candle_core::bail!("Expected BF16"), + }; + let cos_ptr = match &cos_cuda.slice { + CudaStorageSlice::BF16(s) => *s.device_ptr() as *const core::ffi::c_void, + _ => candle_core::bail!("Expected BF16"), + }; + let sin_ptr = match &sin_cuda.slice { + CudaStorageSlice::BF16(s) => *s.device_ptr() as *const core::ffi::c_void, + _ => candle_core::bail!("Expected BF16"), + }; + + unsafe { + match layout { + RopeLayout::BatchMajor { + q_bh, + k_bh, + seq_len, + d, + } => { + if is_interleaved { + ffi::fused_rope_i_bf16( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, q_bh, k_bh, seq_len, d, + stream, + ); + } else { + ffi::fused_rope_bf16( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, q_bh, k_bh, seq_len, d, + stream, + ); + } + } + RopeLayout::TokenMajor { + num_tokens, + q_heads, + k_heads, + d, + } => { + if is_interleaved { + ffi::fused_rope_i_tok_major_bf16( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, num_tokens, q_heads, + k_heads, d, stream, + ); + } else { + ffi::fused_rope_tok_major_bf16( + q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr, num_tokens, q_heads, + k_heads, d, stream, + ); + } + } + } + } + } + _ => candle_core::bail!("FusedRope only supports F32, F16, BF16, got {:?}", dtype), + } + + Ok(()) +} + +#[cfg(feature = "cuda")] +fn launch_fused_rope_partial_token_major( + 
q: &Tensor, + k: &Tensor, + cos: &Tensor, + sin: &Tensor, + positions: &Tensor, + is_interleaved: bool, + rotary_dim: usize, +) -> Result<()> { + use candle_core::cuda_backend::cudarc::driver::DevicePtr; + use candle_core::cuda_backend::CudaStorageSlice; + + let (num_tokens, q_heads, full_d) = q.dims3()?; + let (k_num_tokens, k_heads, k_d) = k.dims3()?; + if num_tokens != k_num_tokens || full_d != k_d { + candle_core::bail!( + "Q and K num_tokens/head_dim must match, got Q: {:?}, K: {:?}", + q.shape(), + k.shape() + ); + } + if rotary_dim == 0 || rotary_dim > full_d || rotary_dim % 2 != 0 { + candle_core::bail!( + "partial fused rope requires even rotary_dim in 1..={}, got {}", + full_d, + rotary_dim + ); + } + if positions.dims() != [num_tokens] { + candle_core::bail!( + "positions should be [{}], got {:?}", + num_tokens, + positions.dims() + ); + } + if cos.dims().len() != 2 || sin.dims().len() != 2 { + candle_core::bail!( + "cos/sin should be 2D full tables, got cos {:?}, sin {:?}", + cos.shape(), + sin.shape() + ); + } + + let positions = if positions.dtype() != DType::I64 { + positions.to_dtype(DType::I64)? 
+ } else { + positions.clone() + }; + + if !q.is_contiguous() + || !k.is_contiguous() + || !cos.is_contiguous() + || !sin.is_contiguous() + || !positions.is_contiguous() + { + candle_core::bail!("All tensors (q, k, cos, sin, positions) must be contiguous"); + } + + let dtype = q.dtype(); + if k.dtype() != dtype || cos.dtype() != dtype || sin.dtype() != dtype { + candle_core::bail!( + "Q, K, cos, sin must have same dtype, got Q: {:?}, K: {:?}, cos: {:?}, sin: {:?}", + q.dtype(), + k.dtype(), + cos.dtype(), + sin.dtype() + ); + } + + let dev = q.device().as_cuda_device()?; + let stream = *dev.cu_stream() as i64; + + let q_storage = q.storage_and_layout().0; + let k_storage = k.storage_and_layout().0; + let cos_storage = cos.storage_and_layout().0; + let sin_storage = sin.storage_and_layout().0; + let pos_storage = positions.storage_and_layout().0; + + let q_cuda = match &*q_storage { + candle_core::Storage::Cuda(s) => s, + _ => candle_core::bail!("Q must be on CUDA"), + }; + let k_cuda = match &*k_storage { + candle_core::Storage::Cuda(s) => s, + _ => candle_core::bail!("K must be on CUDA"), + }; + let cos_cuda = match &*cos_storage { + candle_core::Storage::Cuda(s) => s, + _ => candle_core::bail!("cos must be on CUDA"), + }; + let sin_cuda = match &*sin_storage { + candle_core::Storage::Cuda(s) => s, + _ => candle_core::bail!("sin must be on CUDA"), + }; + let pos_cuda = match &*pos_storage { + candle_core::Storage::Cuda(s) => s, + _ => candle_core::bail!("positions must be on CUDA"), + }; + + let pos_ptr = match &pos_cuda.slice { + CudaStorageSlice::I64(s) => *s.device_ptr() as *const i64, + _ => candle_core::bail!("positions must be I64"), + }; + + match dtype { + DType::F32 => { + let q_ptr = match &q_cuda.slice { + CudaStorageSlice::F32(s) => *s.device_ptr() as *mut f32, + _ => candle_core::bail!("Expected F32"), + }; + let k_ptr = match &k_cuda.slice { + CudaStorageSlice::F32(s) => *s.device_ptr() as *mut f32, + _ => candle_core::bail!("Expected F32"), + }; + 
let cos_ptr = match &cos_cuda.slice { + CudaStorageSlice::F32(s) => *s.device_ptr() as *const f32, + _ => candle_core::bail!("Expected F32"), + }; + let sin_ptr = match &sin_cuda.slice { + CudaStorageSlice::F32(s) => *s.device_ptr() as *const f32, + _ => candle_core::bail!("Expected F32"), + }; + unsafe { + if is_interleaved { + ffi::fused_rope_i_partial_tok_major_f32( + q_ptr, + k_ptr, + cos_ptr, + sin_ptr, + pos_ptr, + num_tokens as u32, + q_heads as u32, + k_heads as u32, + rotary_dim as u32, + full_d as u32, + stream, + ); + } else { + ffi::fused_rope_partial_tok_major_f32( + q_ptr, + k_ptr, + cos_ptr, + sin_ptr, + pos_ptr, + num_tokens as u32, + q_heads as u32, + k_heads as u32, + rotary_dim as u32, + full_d as u32, + stream, + ); + } + } + } + DType::F16 => { + let q_ptr = match &q_cuda.slice { + CudaStorageSlice::F16(s) => *s.device_ptr() as *mut core::ffi::c_void, + _ => candle_core::bail!("Expected F16"), + }; + let k_ptr = match &k_cuda.slice { + CudaStorageSlice::F16(s) => *s.device_ptr() as *mut core::ffi::c_void, + _ => candle_core::bail!("Expected F16"), + }; + let cos_ptr = match &cos_cuda.slice { + CudaStorageSlice::F16(s) => *s.device_ptr() as *const core::ffi::c_void, + _ => candle_core::bail!("Expected F16"), + }; + let sin_ptr = match &sin_cuda.slice { + CudaStorageSlice::F16(s) => *s.device_ptr() as *const core::ffi::c_void, + _ => candle_core::bail!("Expected F16"), + }; + unsafe { + if is_interleaved { + ffi::fused_rope_i_partial_tok_major_f16( + q_ptr, + k_ptr, + cos_ptr, + sin_ptr, + pos_ptr, + num_tokens as u32, + q_heads as u32, + k_heads as u32, + rotary_dim as u32, + full_d as u32, + stream, + ); + } else { + ffi::fused_rope_partial_tok_major_f16( + q_ptr, + k_ptr, + cos_ptr, + sin_ptr, + pos_ptr, + num_tokens as u32, + q_heads as u32, + k_heads as u32, + rotary_dim as u32, + full_d as u32, + stream, + ); + } + } + } + DType::BF16 => { + let q_ptr = match &q_cuda.slice { + CudaStorageSlice::BF16(s) => *s.device_ptr() as *mut 
core::ffi::c_void, + _ => candle_core::bail!("Expected BF16"), + }; + let k_ptr = match &k_cuda.slice { + CudaStorageSlice::BF16(s) => *s.device_ptr() as *mut core::ffi::c_void, + _ => candle_core::bail!("Expected BF16"), + }; + let cos_ptr = match &cos_cuda.slice { + CudaStorageSlice::BF16(s) => *s.device_ptr() as *const core::ffi::c_void, + _ => candle_core::bail!("Expected BF16"), + }; + let sin_ptr = match &sin_cuda.slice { + CudaStorageSlice::BF16(s) => *s.device_ptr() as *const core::ffi::c_void, + _ => candle_core::bail!("Expected BF16"), + }; + unsafe { + if is_interleaved { + ffi::fused_rope_i_partial_tok_major_bf16( + q_ptr, + k_ptr, + cos_ptr, + sin_ptr, + pos_ptr, + num_tokens as u32, + q_heads as u32, + k_heads as u32, + rotary_dim as u32, + full_d as u32, + stream, + ); + } else { + ffi::fused_rope_partial_tok_major_bf16( + q_ptr, + k_ptr, + cos_ptr, + sin_ptr, + pos_ptr, + num_tokens as u32, + q_heads as u32, + k_heads as u32, + rotary_dim as u32, + full_d as u32, + stream, + ); + } + } + } + _ => candle_core::bail!("FusedRope only supports F32, F16, BF16, got {:?}", dtype), + } + + Ok(()) +} + /// Fused Rotary Position Embedding /// /// Applies rotary position embedding to Q and K tensors using optimized CUDA kernels. 
@@ -25,10 +624,13 @@ impl FusedRope { /// /// # Arguments /// * `q` - Query tensor, shape [batch, num_q_heads, seq_len, head_dim] + /// or packed [num_tokens, num_q_heads, head_dim] /// * `k` - Key tensor, shape [batch, num_kv_heads, seq_len, head_dim] + /// or packed [num_tokens, num_kv_heads, head_dim] /// * `cos` - FULL cosine table, shape [max_seq_len, head_dim/2] /// * `sin` - FULL sine table, shape [max_seq_len, head_dim/2] - /// * `positions` - Position indices, shape [seq_len] (i64) + /// * `positions` - Position indices, shape [seq_len] for 4D inputs or + /// [num_tokens] for packed token-major inputs /// * `is_interleaved` - If true, uses interleaved layout (adjacent pairs) /// /// # Returns @@ -42,254 +644,7 @@ impl FusedRope { positions: &Tensor, is_interleaved: bool, ) -> Result<(Tensor, Tensor)> { - use candle_core::cuda_backend::cudarc::driver::DevicePtr; - use candle_core::cuda_backend::CudaStorageSlice; - - // Validate inputs - Q and K can have different head counts (GQA) - let (b, q_h, seq_len, d) = q.dims4()?; - let (kb, k_h, k_seq_len, kd) = k.dims4()?; - - if b != kb || seq_len != k_seq_len || d != kd { - candle_core::bail!( - "Q and K batch/seq_len/head_dim must match, got Q: {:?}, K: {:?}", - q.shape(), - k.shape() - ); - } - - // Positions should be 1D with length seq_len - let pos_shape = positions.dims(); - if pos_shape.len() != 1 || pos_shape[0] != seq_len { - candle_core::bail!( - "positions should be [seq_len], got {:?}, expected [{}]", - pos_shape, - seq_len - ); - } - - // Ensure positions is i64 - let positions = if positions.dtype() != DType::I64 { - positions.to_dtype(DType::I64)? 
- } else { - positions.clone() - }; - - // Check contiguity - bail if not contiguous (avoid hidden allocations) - if !q.is_contiguous() - || !k.is_contiguous() - || !cos.is_contiguous() - || !sin.is_contiguous() - || !positions.is_contiguous() - { - candle_core::bail!("All tensors (q, k, cos, sin, positions) must be contiguous"); - } - - // Validate dtypes match (except positions which is always i64) - let dtype = q.dtype(); - if k.dtype() != dtype || cos.dtype() != dtype || sin.dtype() != dtype { - candle_core::bail!( - "Q, K, cos, sin must have same dtype, got Q: {:?}, K: {:?}, cos: {:?}, sin: {:?}", - q.dtype(), - k.dtype(), - cos.dtype(), - sin.dtype() - ); - } - - // Get device - let dev = q.device().as_cuda_device()?; - let stream = *dev.cu_stream() as i64; - - // Calculate kernel parameters - let q_bh = (b * q_h) as u32; - let k_bh = (b * k_h) as u32; - let seq_len_u32 = seq_len as u32; - let d_u32 = d as u32; - - // Clone for output - - // Get storage - let q_out_storage = q.storage_and_layout().0; - let k_out_storage = k.storage_and_layout().0; - let cos_storage = cos.storage_and_layout().0; - let sin_storage = sin.storage_and_layout().0; - let pos_storage = positions.storage_and_layout().0; - - let q_out_cuda = match &*q_out_storage { - candle_core::Storage::Cuda(s) => s, - _ => candle_core::bail!("Q must be on CUDA"), - }; - let k_out_cuda = match &*k_out_storage { - candle_core::Storage::Cuda(s) => s, - _ => candle_core::bail!("K must be on CUDA"), - }; - let cos_cuda = match &*cos_storage { - candle_core::Storage::Cuda(s) => s, - _ => candle_core::bail!("cos must be on CUDA"), - }; - let sin_cuda = match &*sin_storage { - candle_core::Storage::Cuda(s) => s, - _ => candle_core::bail!("sin must be on CUDA"), - }; - let pos_cuda = match &*pos_storage { - candle_core::Storage::Cuda(s) => s, - _ => candle_core::bail!("positions must be on CUDA"), - }; - - // Get positions pointer - let pos_ptr = match &pos_cuda.slice { - CudaStorageSlice::I64(s) => 
*s.device_ptr() as *const i64, - _ => candle_core::bail!("positions must be I64"), - }; - - match dtype { - DType::F32 => { - let q_ptr = match &q_out_cuda.slice { - CudaStorageSlice::F32(s) => *s.device_ptr() as *mut f32, - _ => candle_core::bail!("Expected F32"), - }; - let k_ptr = match &k_out_cuda.slice { - CudaStorageSlice::F32(s) => *s.device_ptr() as *mut f32, - _ => candle_core::bail!("Expected F32"), - }; - let cos_ptr = match &cos_cuda.slice { - CudaStorageSlice::F32(s) => *s.device_ptr() as *const f32, - _ => candle_core::bail!("Expected F32"), - }; - let sin_ptr = match &sin_cuda.slice { - CudaStorageSlice::F32(s) => *s.device_ptr() as *const f32, - _ => candle_core::bail!("Expected F32"), - }; - - unsafe { - if is_interleaved { - ffi::fused_rope_i_f32( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } else { - ffi::fused_rope_f32( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } - } - } - DType::F16 => { - let q_ptr = match &q_out_cuda.slice { - CudaStorageSlice::F16(s) => *s.device_ptr() as *mut core::ffi::c_void, - _ => candle_core::bail!("Expected F16"), - }; - let k_ptr = match &k_out_cuda.slice { - CudaStorageSlice::F16(s) => *s.device_ptr() as *mut core::ffi::c_void, - _ => candle_core::bail!("Expected F16"), - }; - let cos_ptr = match &cos_cuda.slice { - CudaStorageSlice::F16(s) => *s.device_ptr() as *const core::ffi::c_void, - _ => candle_core::bail!("Expected F16"), - }; - let sin_ptr = match &sin_cuda.slice { - CudaStorageSlice::F16(s) => *s.device_ptr() as *const core::ffi::c_void, - _ => candle_core::bail!("Expected F16"), - }; - - unsafe { - if is_interleaved { - ffi::fused_rope_i_f16( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } else { - ffi::fused_rope_f16( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - 
stream, - ); - } - } - } - DType::BF16 => { - let q_ptr = match &q_out_cuda.slice { - CudaStorageSlice::BF16(s) => *s.device_ptr() as *mut core::ffi::c_void, - _ => candle_core::bail!("Expected BF16"), - }; - let k_ptr = match &k_out_cuda.slice { - CudaStorageSlice::BF16(s) => *s.device_ptr() as *mut core::ffi::c_void, - _ => candle_core::bail!("Expected BF16"), - }; - let cos_ptr = match &cos_cuda.slice { - CudaStorageSlice::BF16(s) => *s.device_ptr() as *const core::ffi::c_void, - _ => candle_core::bail!("Expected BF16"), - }; - let sin_ptr = match &sin_cuda.slice { - CudaStorageSlice::BF16(s) => *s.device_ptr() as *const core::ffi::c_void, - _ => candle_core::bail!("Expected BF16"), - }; - - unsafe { - if is_interleaved { - ffi::fused_rope_i_bf16( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } else { - ffi::fused_rope_bf16( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } - } - } - _ => candle_core::bail!("FusedRope only supports F32, F16, BF16, got {:?}", dtype), - } - + launch_fused_rope(q, k, cos, sin, positions, is_interleaved)?; Ok((q.to_owned(), k.to_owned())) } @@ -305,229 +660,22 @@ impl FusedRope { positions: &Tensor, is_interleaved: bool, ) -> Result<()> { - use candle_core::cuda_backend::cudarc::driver::DevicePtr; - use candle_core::cuda_backend::CudaStorageSlice; - - let (b, q_h, seq_len, d) = q.dims4()?; - let (kb, k_h, k_seq_len, kd) = k.dims4()?; - - if b != kb || seq_len != k_seq_len || d != kd { - candle_core::bail!( - "Q and K batch/seq_len/head_dim must match, got Q: {:?}, K: {:?}", - q.shape(), - k.shape() - ); - } - - // Check contiguity - bail if not contiguous (avoid hidden allocations) - if !q.is_contiguous() || !k.is_contiguous() || !cos.is_contiguous() || !sin.is_contiguous() - { - candle_core::bail!("All tensors (q, k, cos, sin) must be contiguous"); - } - - let positions = if positions.dtype() != DType::I64 { - 
positions.to_dtype(DType::I64)? - } else { - positions.clone() - }; - if !positions.is_contiguous() { - candle_core::bail!("positions must be contiguous"); - } - - let dtype = q.dtype(); - if k.dtype() != dtype || cos.dtype() != dtype || sin.dtype() != dtype { - candle_core::bail!("Q, K, cos, sin must have same dtype"); - } - - let dev = q.device().as_cuda_device()?; - let stream = *dev.cu_stream() as i64; - - let q_bh = (b * q_h) as u32; - let k_bh = (b * k_h) as u32; - let seq_len_u32 = seq_len as u32; - let d_u32 = d as u32; - - let q_storage = q.storage_and_layout().0; - let k_storage = k.storage_and_layout().0; - let cos_storage = cos.storage_and_layout().0; - let sin_storage = sin.storage_and_layout().0; - let pos_storage = positions.storage_and_layout().0; - - let q_cuda = match &*q_storage { - candle_core::Storage::Cuda(s) => s, - _ => candle_core::bail!("Q must be on CUDA"), - }; - let k_cuda = match &*k_storage { - candle_core::Storage::Cuda(s) => s, - _ => candle_core::bail!("K must be on CUDA"), - }; - let cos_cuda = match &*cos_storage { - candle_core::Storage::Cuda(s) => s, - _ => candle_core::bail!("cos must be on CUDA"), - }; - let sin_cuda = match &*sin_storage { - candle_core::Storage::Cuda(s) => s, - _ => candle_core::bail!("sin must be on CUDA"), - }; - let pos_cuda = match &*pos_storage { - candle_core::Storage::Cuda(s) => s, - _ => candle_core::bail!("positions must be on CUDA"), - }; - - let pos_ptr = match &pos_cuda.slice { - CudaStorageSlice::I64(s) => *s.device_ptr() as *const i64, - _ => candle_core::bail!("positions must be I64"), - }; - - match dtype { - DType::F32 => { - let q_ptr = match &q_cuda.slice { - CudaStorageSlice::F32(s) => *s.device_ptr() as *mut f32, - _ => candle_core::bail!("Expected F32"), - }; - let k_ptr = match &k_cuda.slice { - CudaStorageSlice::F32(s) => *s.device_ptr() as *mut f32, - _ => candle_core::bail!("Expected F32"), - }; - let cos_ptr = match &cos_cuda.slice { - CudaStorageSlice::F32(s) => *s.device_ptr() 
as *const f32, - _ => candle_core::bail!("Expected F32"), - }; - let sin_ptr = match &sin_cuda.slice { - CudaStorageSlice::F32(s) => *s.device_ptr() as *const f32, - _ => candle_core::bail!("Expected F32"), - }; - - unsafe { - if is_interleaved { - ffi::fused_rope_i_f32( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } else { - ffi::fused_rope_f32( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } - } - } - DType::F16 => { - let q_ptr = match &q_cuda.slice { - CudaStorageSlice::F16(s) => *s.device_ptr() as *mut core::ffi::c_void, - _ => candle_core::bail!("Expected F16"), - }; - let k_ptr = match &k_cuda.slice { - CudaStorageSlice::F16(s) => *s.device_ptr() as *mut core::ffi::c_void, - _ => candle_core::bail!("Expected F16"), - }; - let cos_ptr = match &cos_cuda.slice { - CudaStorageSlice::F16(s) => *s.device_ptr() as *const core::ffi::c_void, - _ => candle_core::bail!("Expected F16"), - }; - let sin_ptr = match &sin_cuda.slice { - CudaStorageSlice::F16(s) => *s.device_ptr() as *const core::ffi::c_void, - _ => candle_core::bail!("Expected F16"), - }; - - unsafe { - if is_interleaved { - ffi::fused_rope_i_f16( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } else { - ffi::fused_rope_f16( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } - } - } - DType::BF16 => { - let q_ptr = match &q_cuda.slice { - CudaStorageSlice::BF16(s) => *s.device_ptr() as *mut core::ffi::c_void, - _ => candle_core::bail!("Expected BF16"), - }; - let k_ptr = match &k_cuda.slice { - CudaStorageSlice::BF16(s) => *s.device_ptr() as *mut core::ffi::c_void, - _ => candle_core::bail!("Expected BF16"), - }; - let cos_ptr = match &cos_cuda.slice { - CudaStorageSlice::BF16(s) => *s.device_ptr() as *const core::ffi::c_void, - _ => candle_core::bail!("Expected 
BF16"), - }; - let sin_ptr = match &sin_cuda.slice { - CudaStorageSlice::BF16(s) => *s.device_ptr() as *const core::ffi::c_void, - _ => candle_core::bail!("Expected BF16"), - }; - - unsafe { - if is_interleaved { - ffi::fused_rope_i_bf16( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } else { - ffi::fused_rope_bf16( - q_ptr, - k_ptr, - cos_ptr, - sin_ptr, - pos_ptr, - q_bh, - k_bh, - seq_len_u32, - d_u32, - stream, - ); - } - } - } - _ => candle_core::bail!("FusedRope only supports F32, F16, BF16, got {:?}", dtype), - } + launch_fused_rope(q, k, cos, sin, positions, is_interleaved) + } - Ok(()) + /// Apply fused rotary embedding in-place to only the leading `rotary_dim` + /// channels of packed token-major Q/K tensors. + #[cfg(feature = "cuda")] + pub fn apply_inplace_partial( + q: &Tensor, + k: &Tensor, + cos: &Tensor, + sin: &Tensor, + positions: &Tensor, + is_interleaved: bool, + rotary_dim: usize, + ) -> Result<()> { + launch_fused_rope_partial_token_major(q, k, cos, sin, positions, is_interleaved, rotary_dim) } /// Convenience: non-interleaved RoPE diff --git a/src/kernels/Cargo.toml b/src/kernels/Cargo.toml index 28e1775..a4226ef 100644 --- a/src/kernels/Cargo.toml +++ b/src/kernels/Cargo.toml @@ -9,7 +9,7 @@ license = "MIT OR Apache-2.0" [dependencies] [build-dependencies] -cudaforge = "0.1.4" +cudaforge = "0.1.5" anyhow = { version = "1", features = ["backtrace"] } sha2 = "0.10" ureq = { version = "2.10", default-features = true } diff --git a/src/kernels/build.rs b/src/kernels/build.rs index 2621c77..23bae77 100644 --- a/src/kernels/build.rs +++ b/src/kernels/build.rs @@ -27,6 +27,7 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed=src/fp8_moe_cutlass.cu"); println!("cargo:rerun-if-changed=src/flashinfer_fp8_qquant.cu"); println!("cargo:rerun-if-changed=src/flashinfer_adapter_fp8.cu"); + println!("cargo:rerun-if-changed=src/flashinfer_bmm_fp8.cu"); 
println!("cargo:rerun-if-changed=src/flashinfer_moe_adapter.cu"); println!("cargo:rerun-if-changed=src/trtllm/trtllm_batched_gemm_runner.cu"); println!("cargo:rerun-if-changed=src/trtllm/trtllm_fused_moe_runner.cu"); @@ -107,8 +108,45 @@ fn main() -> Result<()> { "csrc/nv_internal/include", "csrc/nv_internal/tensorrt_llm/cutlass_extensions/include", ], + vec![ + "csrc/nv_internal/cpp/common", + "csrc/nv_internal/tensorrt_llm", + ], false, ); + + let flashinfer_root = builder.fetch_git_dependency("flashinfer")?; + let csrc_dir = flashinfer_root.join("csrc"); + let trtllm_dir = csrc_dir.join("nv_internal").join("tensorrt_llm"); + + if matches!(compute_cap, 90 | 100) && trtllm_dir.exists() { + let include_define = format!( + "-DATTENTION_RS_FLASHINFER_TRTLLM_INCLUDE_DIR=\\\"{}\\\"", + trtllm_dir.display() + ); + builder = builder + .arg("-DATTENTION_RS_USE_FLASHINFER_BLOCKSCALE") + .arg("-DCOMPILE_HOPPER_TMA_GEMMS") + .arg("-DENABLE_FP8_BLOCK_SCALE") + .arg(&include_define) + .include_path(csrc_dir.join("nv_internal/tensorrt_llm/kernels/cutlass_kernels/include")) + .include_path(csrc_dir.join("nv_internal/tensorrt_llm/kernels/cutlass_kernels")) + .source_files(vec![ + csrc_dir.join( + "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu", + ), + csrc_dir.join("nv_internal/cpp/common/envUtils.cpp"), + csrc_dir.join("nv_internal/cpp/common/logger.cpp"), + csrc_dir.join("nv_internal/cpp/common/stringUtils.cpp"), + csrc_dir.join("nv_internal/cpp/common/tllmException.cpp"), + csrc_dir.join("nv_internal/cpp/common/memoryUtils.cu"), + ]); + } else if matches!(compute_cap, 90 | 100) { + println!( + "cargo:warning=flashinfer TensorRT-LLM sources not found at {}, skipping blockscale fp8 wrapper", + trtllm_dir.display() + ); + } } // Target handling diff --git a/src/kernels/src/ffi.rs b/src/kernels/src/ffi.rs index d528675..c05b6ad 100644 --- a/src/kernels/src/ffi.rs +++ b/src/kernels/src/ffi.rs @@ -777,6 +777,169 @@ extern "C" { stream: 
i64, ); + // Token-major variants for packed [num_tokens, num_heads, head_dim] tensors. + pub fn fused_rope_tok_major_f32( + q: *mut f32, + k: *mut f32, + cos: *const f32, + sin: *const f32, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + d: u32, + stream: i64, + ); + + pub fn fused_rope_tok_major_f16( + q: *mut c_void, + k: *mut c_void, + cos: *const c_void, + sin: *const c_void, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + d: u32, + stream: i64, + ); + + pub fn fused_rope_tok_major_bf16( + q: *mut c_void, + k: *mut c_void, + cos: *const c_void, + sin: *const c_void, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + d: u32, + stream: i64, + ); + + pub fn fused_rope_i_tok_major_f32( + q: *mut f32, + k: *mut f32, + cos: *const f32, + sin: *const f32, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + d: u32, + stream: i64, + ); + + pub fn fused_rope_i_tok_major_f16( + q: *mut c_void, + k: *mut c_void, + cos: *const c_void, + sin: *const c_void, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + d: u32, + stream: i64, + ); + + pub fn fused_rope_i_tok_major_bf16( + q: *mut c_void, + k: *mut c_void, + cos: *const c_void, + sin: *const c_void, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + d: u32, + stream: i64, + ); + + pub fn fused_rope_partial_tok_major_f32( + q: *mut f32, + k: *mut f32, + cos: *const f32, + sin: *const f32, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + rotary_d: u32, + full_d: u32, + stream: i64, + ); + + pub fn fused_rope_partial_tok_major_f16( + q: *mut c_void, + k: *mut c_void, + cos: *const c_void, + sin: *const c_void, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + rotary_d: u32, + full_d: u32, + stream: i64, + ); + + pub fn fused_rope_partial_tok_major_bf16( + q: *mut c_void, + k: *mut c_void, + cos: *const 
c_void, + sin: *const c_void, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + rotary_d: u32, + full_d: u32, + stream: i64, + ); + + pub fn fused_rope_i_partial_tok_major_f32( + q: *mut f32, + k: *mut f32, + cos: *const f32, + sin: *const f32, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + rotary_d: u32, + full_d: u32, + stream: i64, + ); + + pub fn fused_rope_i_partial_tok_major_f16( + q: *mut c_void, + k: *mut c_void, + cos: *const c_void, + sin: *const c_void, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + rotary_d: u32, + full_d: u32, + stream: i64, + ); + + pub fn fused_rope_i_partial_tok_major_bf16( + q: *mut c_void, + k: *mut c_void, + cos: *const c_void, + sin: *const c_void, + positions: *const i64, + num_tokens: u32, + q_heads: u32, + k_heads: u32, + rotary_d: u32, + full_d: u32, + stream: i64, + ); + pub fn fp8_matmul_f16( input: *const c_void, // [M, K] weight: *const u8, // [N, K] @@ -837,6 +1000,41 @@ extern "C" { stream: i64, ); + #[cfg(feature = "flashinfer")] + pub fn flashinfer_fp8_blockscale_workspace_size_bf16(m: c_int, n: c_int, k: c_int) -> usize; + + #[cfg(feature = "flashinfer")] + pub fn flashinfer_fp8_blockscale_bf16( + input: *const c_void, + weight: *const c_void, + weight_scale: *const f32, + output: *mut c_void, + m: c_int, + n: c_int, + k: c_int, + workspace: *mut c_void, + workspace_size: usize, + stream: i64, + ) -> c_int; + + #[cfg(feature = "flashinfer")] + pub fn flashinfer_fp8_blockscale_workspace_size_fp8(m: c_int, n: c_int, k: c_int) -> usize; + + #[cfg(feature = "flashinfer")] + pub fn flashinfer_fp8_blockscale_fp8( + input: *const c_void, + input_scale: *const f32, + weight: *const c_void, + weight_scale: *const f32, + output: *mut c_void, + m: c_int, + n: c_int, + k: c_int, + workspace: *mut c_void, + workspace_size: usize, + stream: i64, + ) -> c_int; + pub fn moe_fp8_calculate_expert_offsets( expert_ids: *const i32, expert_counts: 
*mut i32, diff --git a/src/kernels/src/flashinfer_bmm_fp8.cu b/src/kernels/src/flashinfer_bmm_fp8.cu new file mode 100644 index 0000000..82b6f82 --- /dev/null +++ b/src/kernels/src/flashinfer_bmm_fp8.cu @@ -0,0 +1,159 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#if defined(USE_FLASHINFER) && defined(ATTENTION_RS_USE_FLASHINFER_BLOCKSCALE) && \ + defined(FLASHINFER_ENABLE_FP8_E4M3) +#include "tensorrt_llm/deep_gemm/compiler.cuh" +#include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h" + +namespace { + +using Runner = tensorrt_llm::kernels::fp8_blockscale_gemm::CutlassFp8BlockScaleGemmRunner< + __nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>; +using RunnerFp8 = tensorrt_llm::kernels::fp8_blockscale_gemm::CutlassFp8BlockScaleGemmRunner< + __nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; + +thread_local std::unique_ptr g_runner; +thread_local std::unique_ptr g_runner_fp8; +thread_local bool g_runner_initialized = false; + +Runner& get_runner() { + if (!g_runner) { + g_runner = std::make_unique(); + } + if (!g_runner_initialized) { + deep_gemm::jit::Compiler::setIncludeDirs( + {std::filesystem::path(ATTENTION_RS_FLASHINFER_TRTLLM_INCLUDE_DIR)}); + g_runner_initialized = true; + } + return *g_runner; +} + +RunnerFp8& get_runner_fp8() { + if (!g_runner_fp8) { + g_runner_fp8 = std::make_unique(); + } + if (!g_runner_initialized) { + deep_gemm::jit::Compiler::setIncludeDirs( + {std::filesystem::path(ATTENTION_RS_FLASHINFER_TRTLLM_INCLUDE_DIR)}); + g_runner_initialized = true; + } + return *g_runner_fp8; +} + +} // namespace + +extern "C" size_t flashinfer_fp8_blockscale_workspace_size_bf16(int m, int n, int k) { + try { + if (m <= 0 || n <= 0 || k <= 0) { + return 0; + } + return get_runner().getWorkspaceSize(static_cast(m), static_cast(n), + static_cast(k), 1, 1); + } catch (std::exception const& e) { + std::fprintf(stderr, "flashinfer_fp8_blockscale_workspace_size_bf16 failed: %s\n", 
e.what()); + return 0; + } +} + +extern "C" int flashinfer_fp8_blockscale_bf16(const void* input, const void* weight, + const float* weight_scale, void* output, int m, + int n, int k, void* workspace, + size_t workspace_size, int64_t stream_) { + try { + if (input == nullptr || weight == nullptr || weight_scale == nullptr || output == nullptr) { + return 1; + } + + auto& runner = get_runner(); + size_t required = + runner.getWorkspaceSize(static_cast(m), static_cast(n), static_cast(k), 1, 1); + if (workspace == nullptr || workspace_size < required) { + return 2; + } + + runner.configureWorkspace(reinterpret_cast(workspace)); + runner.gemm(output, input, weight, m, n, k, reinterpret_cast(stream_), nullptr, + weight_scale); + return 0; + } catch (std::exception const& e) { + std::fprintf(stderr, "flashinfer_fp8_blockscale_bf16 failed: %s\n", e.what()); + return -1; + } catch (...) { + std::fprintf(stderr, "flashinfer_fp8_blockscale_bf16 failed with unknown exception\n"); + return -2; + } +} + +extern "C" size_t flashinfer_fp8_blockscale_workspace_size_fp8(int m, int n, int k) { + try { + if (m <= 0 || n <= 0 || k <= 0) { + return 0; + } + return get_runner_fp8().getWorkspaceSize(static_cast(m), static_cast(n), + static_cast(k), 1, 1); + } catch (std::exception const& e) { + std::fprintf(stderr, "flashinfer_fp8_blockscale_workspace_size_fp8 failed: %s\n", e.what()); + return 0; + } +} + +extern "C" int flashinfer_fp8_blockscale_fp8(const void* input, const float* input_scale, + const void* weight, const float* weight_scale, + void* output, int m, int n, int k, void* workspace, + size_t workspace_size, int64_t stream_) { + try { + if (input == nullptr || input_scale == nullptr || weight == nullptr || weight_scale == nullptr || + output == nullptr) { + return 1; + } + + auto& runner = get_runner_fp8(); + size_t required = + runner.getWorkspaceSize(static_cast(m), static_cast(n), static_cast(k), 1, 1); + if (required > 0 && (workspace == nullptr || workspace_size < 
required)) { + return 2; + } + if (required > 0) { + runner.configureWorkspace(reinterpret_cast(workspace)); + } + + runner.gemm(reinterpret_cast<__nv_fp8_e4m3 const*>(input), k, + reinterpret_cast<__nv_fp8_e4m3 const*>(weight), k, + reinterpret_cast<__nv_bfloat16*>(output), n, m, n, k, input_scale, weight_scale, + reinterpret_cast(stream_)); + return 0; + } catch (std::exception const& e) { + std::fprintf(stderr, "flashinfer_fp8_blockscale_fp8 failed: %s\n", e.what()); + return -1; + } catch (...) { + std::fprintf(stderr, "flashinfer_fp8_blockscale_fp8 failed with unknown exception\n"); + return -2; + } +} + +#else + +extern "C" size_t flashinfer_fp8_blockscale_workspace_size_bf16(int, int, int) { return 0; } + +extern "C" int flashinfer_fp8_blockscale_bf16(const void*, const void*, const float*, void*, int, + int, int, void*, size_t, int64_t) { + return -1; +} + +extern "C" size_t flashinfer_fp8_blockscale_workspace_size_fp8(int, int, int) { return 0; } + +extern "C" int flashinfer_fp8_blockscale_fp8(const void*, const float*, const void*, const float*, + void*, int, int, int, void*, size_t, int64_t) { + return -1; +} + +#endif diff --git a/src/kernels/src/fused_rope.cu b/src/kernels/src/fused_rope.cu index bfbbed9..3e6cd45 100644 --- a/src/kernels/src/fused_rope.cu +++ b/src/kernels/src/fused_rope.cu @@ -184,6 +184,522 @@ fused_rope_i_bf16_kernel( } #endif +// ============================================================================ +// Token-major RoPE with Position Selection +// ============================================================================ + +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_i_tok_major_f32_kernel( + float2* __restrict__ q, + float2* __restrict__ k, + const float* __restrict__ cos, + const float* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t half_d +) { + const uint32_t q_pairs = num_tokens * q_heads * 
half_d; + const uint32_t k_pairs = num_tokens * k_heads * half_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / half_d) % heads; + const uint32_t token_idx = local_idx / (heads * half_d); + const uint32_t d_idx = local_idx % half_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * half_d + d_idx; + const float c = cos[cs_idx]; + const float s = sin[cs_idx]; + + float2* ptr = is_q ? q : k; + const uint32_t pair_idx = (token_idx * heads + head_idx) * half_d + d_idx; + const float2 v = ptr[pair_idx]; + + float2 result; + result.x = v.x * c - v.y * s; + result.y = v.x * s + v.y * c; + ptr[pair_idx] = result; + } +} + +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_i_tok_major_f16_kernel( + __half2* __restrict__ q, + __half2* __restrict__ k, + const __half* __restrict__ cos, + const __half* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t half_d +) { + const uint32_t q_pairs = num_tokens * q_heads * half_d; + const uint32_t k_pairs = num_tokens * k_heads * half_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? 
q_heads : k_heads; + const uint32_t head_idx = (local_idx / half_d) % heads; + const uint32_t token_idx = local_idx / (heads * half_d); + const uint32_t d_idx = local_idx % half_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * half_d + d_idx; + const float c = __half2float(cos[cs_idx]); + const float s = __half2float(sin[cs_idx]); + + __half2* ptr = is_q ? q : k; + const uint32_t pair_idx = (token_idx * heads + head_idx) * half_d + d_idx; + const __half2 v = ptr[pair_idx]; + + __half2 result; + result.x = __float2half(__half2float(v.x) * c - __half2float(v.y) * s); + result.y = __float2half(__half2float(v.x) * s + __half2float(v.y) * c); + ptr[pair_idx] = result; + } +} + +#ifndef NO_BF16_KERNEL +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_i_tok_major_bf16_kernel( + __nv_bfloat162* __restrict__ q, + __nv_bfloat162* __restrict__ k, + const __nv_bfloat16* __restrict__ cos, + const __nv_bfloat16* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t half_d +) { + const uint32_t q_pairs = num_tokens * q_heads * half_d; + const uint32_t k_pairs = num_tokens * k_heads * half_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / half_d) % heads; + const uint32_t token_idx = local_idx / (heads * half_d); + const uint32_t d_idx = local_idx % half_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * half_d + d_idx; + const __nv_bfloat16 c = cos[cs_idx]; + const __nv_bfloat16 s = sin[cs_idx]; + + __nv_bfloat162* ptr = is_q ? 
q : k; + const uint32_t pair_idx = (token_idx * heads + head_idx) * half_d + d_idx; + const __nv_bfloat162 v = ptr[pair_idx]; + + __nv_bfloat162 result; + result.x = __hsub(__hmul(v.x, c), __hmul(v.y, s)); + result.y = __hadd(__hmul(v.x, s), __hmul(v.y, c)); + ptr[pair_idx] = result; + } +} +#endif + +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_tok_major_f32_kernel( + float* __restrict__ q, + float* __restrict__ k, + const float* __restrict__ cos, + const float* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t d +) { + const uint32_t half_d = d / 2; + const uint32_t q_pairs = num_tokens * q_heads * half_d; + const uint32_t k_pairs = num_tokens * k_heads * half_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / half_d) % heads; + const uint32_t token_idx = local_idx / (heads * half_d); + const uint32_t d_idx = local_idx % half_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * half_d + d_idx; + const float c = cos[cs_idx]; + const float s = sin[cs_idx]; + + float* ptr = is_q ? 
q : k; + const uint32_t base = (token_idx * heads + head_idx) * d + d_idx; + const float x = ptr[base]; + const float y = ptr[base + half_d]; + ptr[base] = x * c - y * s; + ptr[base + half_d] = y * c + x * s; + } +} + +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_tok_major_f16_kernel( + __half* __restrict__ q, + __half* __restrict__ k, + const __half* __restrict__ cos, + const __half* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t d +) { + const uint32_t half_d = d / 2; + const uint32_t q_pairs = num_tokens * q_heads * half_d; + const uint32_t k_pairs = num_tokens * k_heads * half_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / half_d) % heads; + const uint32_t token_idx = local_idx / (heads * half_d); + const uint32_t d_idx = local_idx % half_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * half_d + d_idx; + const float c = __half2float(cos[cs_idx]); + const float s = __half2float(sin[cs_idx]); + + __half* ptr = is_q ? 
q : k; + const uint32_t base = (token_idx * heads + head_idx) * d + d_idx; + const float x = __half2float(ptr[base]); + const float y = __half2float(ptr[base + half_d]); + ptr[base] = __float2half(x * c - y * s); + ptr[base + half_d] = __float2half(y * c + x * s); + } +} + +#ifndef NO_BF16_KERNEL +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_tok_major_bf16_kernel( + __nv_bfloat16* __restrict__ q, + __nv_bfloat16* __restrict__ k, + const __nv_bfloat16* __restrict__ cos, + const __nv_bfloat16* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t d +) { + const uint32_t half_d = d / 2; + const uint32_t q_pairs = num_tokens * q_heads * half_d; + const uint32_t k_pairs = num_tokens * k_heads * half_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / half_d) % heads; + const uint32_t token_idx = local_idx / (heads * half_d); + const uint32_t d_idx = local_idx % half_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * half_d + d_idx; + const __nv_bfloat16 c = cos[cs_idx]; + const __nv_bfloat16 s = sin[cs_idx]; + + __nv_bfloat16* ptr = is_q ? 
q : k; + const uint32_t base = (token_idx * heads + head_idx) * d + d_idx; + const __nv_bfloat16 x = ptr[base]; + const __nv_bfloat16 y = ptr[base + half_d]; + ptr[base] = __hsub(__hmul(x, c), __hmul(y, s)); + ptr[base + half_d] = __hadd(__hmul(y, c), __hmul(x, s)); + } +} +#endif + +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_partial_tok_major_f32_kernel( + float* __restrict__ q, + float* __restrict__ k, + const float* __restrict__ cos, + const float* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t rotary_d, + const uint32_t full_d +) { + const uint32_t half_rotary_d = rotary_d / 2; + const uint32_t q_pairs = num_tokens * q_heads * half_rotary_d; + const uint32_t k_pairs = num_tokens * k_heads * half_rotary_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / half_rotary_d) % heads; + const uint32_t token_idx = local_idx / (heads * half_rotary_d); + const uint32_t d_idx = local_idx % half_rotary_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * half_rotary_d + d_idx; + const float c = cos[cs_idx]; + const float s = sin[cs_idx]; + + float* ptr = is_q ? 
q : k; + const uint32_t base = (token_idx * heads + head_idx) * full_d + d_idx; + const float x = ptr[base]; + const float y = ptr[base + half_rotary_d]; + ptr[base] = x * c - y * s; + ptr[base + half_rotary_d] = y * c + x * s; + } +} + +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_partial_tok_major_f16_kernel( + __half* __restrict__ q, + __half* __restrict__ k, + const __half* __restrict__ cos, + const __half* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t rotary_d, + const uint32_t full_d +) { + const uint32_t half_rotary_d = rotary_d / 2; + const uint32_t q_pairs = num_tokens * q_heads * half_rotary_d; + const uint32_t k_pairs = num_tokens * k_heads * half_rotary_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / half_rotary_d) % heads; + const uint32_t token_idx = local_idx / (heads * half_rotary_d); + const uint32_t d_idx = local_idx % half_rotary_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * half_rotary_d + d_idx; + const float c = __half2float(cos[cs_idx]); + const float s = __half2float(sin[cs_idx]); + + __half* ptr = is_q ? 
q : k; + const uint32_t base = (token_idx * heads + head_idx) * full_d + d_idx; + const float x = __half2float(ptr[base]); + const float y = __half2float(ptr[base + half_rotary_d]); + ptr[base] = __float2half(x * c - y * s); + ptr[base + half_rotary_d] = __float2half(y * c + x * s); + } +} + +#ifndef NO_BF16_KERNEL +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_partial_tok_major_bf16_kernel( + __nv_bfloat16* __restrict__ q, + __nv_bfloat16* __restrict__ k, + const __nv_bfloat16* __restrict__ cos, + const __nv_bfloat16* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t rotary_d, + const uint32_t full_d +) { + const uint32_t half_rotary_d = rotary_d / 2; + const uint32_t q_pairs = num_tokens * q_heads * half_rotary_d; + const uint32_t k_pairs = num_tokens * k_heads * half_rotary_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / half_rotary_d) % heads; + const uint32_t token_idx = local_idx / (heads * half_rotary_d); + const uint32_t d_idx = local_idx % half_rotary_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * half_rotary_d + d_idx; + const __nv_bfloat16 c = cos[cs_idx]; + const __nv_bfloat16 s = sin[cs_idx]; + + __nv_bfloat16* ptr = is_q ? 
q : k; + const uint32_t base = (token_idx * heads + head_idx) * full_d + d_idx; + const __nv_bfloat16 x = ptr[base]; + const __nv_bfloat16 y = ptr[base + half_rotary_d]; + ptr[base] = __hsub(__hmul(x, c), __hmul(y, s)); + ptr[base + half_rotary_d] = __hadd(__hmul(y, c), __hmul(x, s)); + } +} +#endif + +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_i_partial_tok_major_f32_kernel( + float2* __restrict__ q, + float2* __restrict__ k, + const float* __restrict__ cos, + const float* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t rotary_half_d, + const uint32_t full_half_d +) { + const uint32_t q_pairs = num_tokens * q_heads * rotary_half_d; + const uint32_t k_pairs = num_tokens * k_heads * rotary_half_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / rotary_half_d) % heads; + const uint32_t token_idx = local_idx / (heads * rotary_half_d); + const uint32_t d_idx = local_idx % rotary_half_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * rotary_half_d + d_idx; + const float c = cos[cs_idx]; + const float s = sin[cs_idx]; + + float2* ptr = is_q ? 
q : k; + const uint32_t pair_idx = (token_idx * heads + head_idx) * full_half_d + d_idx; + const float2 v = ptr[pair_idx]; + + float2 result; + result.x = v.x * c - v.y * s; + result.y = v.x * s + v.y * c; + ptr[pair_idx] = result; + } +} + +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_i_partial_tok_major_f16_kernel( + __half2* __restrict__ q, + __half2* __restrict__ k, + const __half* __restrict__ cos, + const __half* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t rotary_half_d, + const uint32_t full_half_d +) { + const uint32_t q_pairs = num_tokens * q_heads * rotary_half_d; + const uint32_t k_pairs = num_tokens * k_heads * rotary_half_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / rotary_half_d) % heads; + const uint32_t token_idx = local_idx / (heads * rotary_half_d); + const uint32_t d_idx = local_idx % rotary_half_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * rotary_half_d + d_idx; + const float c = __half2float(cos[cs_idx]); + const float s = __half2float(sin[cs_idx]); + + __half2* ptr = is_q ? 
q : k; + const uint32_t pair_idx = (token_idx * heads + head_idx) * full_half_d + d_idx; + const __half2 v = ptr[pair_idx]; + + __half2 result; + result.x = __float2half(__half2float(v.x) * c - __half2float(v.y) * s); + result.y = __float2half(__half2float(v.x) * s + __half2float(v.y) * c); + ptr[pair_idx] = result; + } +} + +#ifndef NO_BF16_KERNEL +__global__ void __launch_bounds__(BLOCK_SIZE) +fused_rope_i_partial_tok_major_bf16_kernel( + __nv_bfloat162* __restrict__ q, + __nv_bfloat162* __restrict__ k, + const __nv_bfloat16* __restrict__ cos, + const __nv_bfloat16* __restrict__ sin, + const int64_t* __restrict__ positions, + const uint32_t num_tokens, + const uint32_t q_heads, + const uint32_t k_heads, + const uint32_t rotary_half_d, + const uint32_t full_half_d +) { + const uint32_t q_pairs = num_tokens * q_heads * rotary_half_d; + const uint32_t k_pairs = num_tokens * k_heads * rotary_half_d; + const uint32_t total_pairs = q_pairs + k_pairs; + + for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total_pairs; + idx += gridDim.x * blockDim.x) { + const bool is_q = idx < q_pairs; + const uint32_t local_idx = is_q ? idx : (idx - q_pairs); + const uint32_t heads = is_q ? q_heads : k_heads; + const uint32_t head_idx = (local_idx / rotary_half_d) % heads; + const uint32_t token_idx = local_idx / (heads * rotary_half_d); + const uint32_t d_idx = local_idx % rotary_half_d; + + const int64_t pos = positions[token_idx]; + const uint32_t cs_idx = pos * rotary_half_d + d_idx; + const __nv_bfloat16 c = cos[cs_idx]; + const __nv_bfloat16 s = sin[cs_idx]; + + __nv_bfloat162* ptr = is_q ? 
q : k; + const uint32_t pair_idx = (token_idx * heads + head_idx) * full_half_d + d_idx; + const __nv_bfloat162 v = ptr[pair_idx]; + + __nv_bfloat162 result; + result.x = __hsub(__hmul(v.x, c), __hmul(v.y, s)); + result.y = __hadd(__hmul(v.x, s), __hmul(v.y, c)); + ptr[pair_idx] = result; + } +} +#endif + // ============================================================================ // Non-Interleaved RoPE with Position Selection // ============================================================================ @@ -481,3 +997,231 @@ extern "C" void fused_rope_i_bf16( ); #endif } + +extern "C" void fused_rope_tok_major_f32( + float* q, float* k, + const float* cos, const float* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t half_d = d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * half_d; + dim3 grid = get_optimal_grid(total); + fused_rope_tok_major_f32_kernel<<>>( + q, k, cos, sin, positions, num_tokens, q_heads, k_heads, d + ); +} + +extern "C" void fused_rope_tok_major_f16( + void* q, void* k, + const void* cos, const void* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t half_d = d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * half_d; + dim3 grid = get_optimal_grid(total); + fused_rope_tok_major_f16_kernel<<>>( + (__half*)q, (__half*)k, (const __half*)cos, (const __half*)sin, + positions, num_tokens, q_heads, k_heads, d + ); +} + +extern "C" void fused_rope_tok_major_bf16( + void* q, void* k, + const void* cos, const void* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t half_d = d / 
2; + const uint32_t total = num_tokens * (q_heads + k_heads) * half_d; +#ifndef NO_BF16_KERNEL + dim3 grid = get_optimal_grid(total); + fused_rope_tok_major_bf16_kernel<<>>( + (__nv_bfloat16*)q, (__nv_bfloat16*)k, + (const __nv_bfloat16*)cos, (const __nv_bfloat16*)sin, + positions, num_tokens, q_heads, k_heads, d + ); +#endif +} + +extern "C" void fused_rope_i_tok_major_f32( + float* q, float* k, + const float* cos, const float* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t half_d = d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * half_d; + dim3 grid = get_optimal_grid(total); + fused_rope_i_tok_major_f32_kernel<<>>( + (float2*)q, (float2*)k, cos, sin, positions, num_tokens, q_heads, k_heads, half_d + ); +} + +extern "C" void fused_rope_i_tok_major_f16( + void* q, void* k, + const void* cos, const void* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t half_d = d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * half_d; + dim3 grid = get_optimal_grid(total); + fused_rope_i_tok_major_f16_kernel<<>>( + (__half2*)q, (__half2*)k, (const __half*)cos, (const __half*)sin, + positions, num_tokens, q_heads, k_heads, half_d + ); +} + +extern "C" void fused_rope_i_tok_major_bf16( + void* q, void* k, + const void* cos, const void* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t half_d = d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * half_d; +#ifndef NO_BF16_KERNEL + dim3 grid = get_optimal_grid(total); + fused_rope_i_tok_major_bf16_kernel<<>>( + (__nv_bfloat162*)q, 
(__nv_bfloat162*)k, + (const __nv_bfloat16*)cos, (const __nv_bfloat16*)sin, + positions, num_tokens, q_heads, k_heads, half_d + ); +#endif +} + +extern "C" void fused_rope_partial_tok_major_f32( + float* q, float* k, + const float* cos, const float* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t rotary_d, uint32_t full_d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t half_rotary_d = rotary_d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * half_rotary_d; + dim3 grid = get_optimal_grid(total); + fused_rope_partial_tok_major_f32_kernel<<>>( + q, k, cos, sin, positions, num_tokens, q_heads, k_heads, rotary_d, full_d + ); +} + +extern "C" void fused_rope_partial_tok_major_f16( + void* q, void* k, + const void* cos, const void* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t rotary_d, uint32_t full_d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t half_rotary_d = rotary_d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * half_rotary_d; + dim3 grid = get_optimal_grid(total); + fused_rope_partial_tok_major_f16_kernel<<>>( + (__half*)q, (__half*)k, (const __half*)cos, (const __half*)sin, + positions, num_tokens, q_heads, k_heads, rotary_d, full_d + ); +} + +extern "C" void fused_rope_partial_tok_major_bf16( + void* q, void* k, + const void* cos, const void* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t rotary_d, uint32_t full_d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t half_rotary_d = rotary_d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * half_rotary_d; +#ifndef NO_BF16_KERNEL + dim3 grid = get_optimal_grid(total); + fused_rope_partial_tok_major_bf16_kernel<<>>( + (__nv_bfloat16*)q, (__nv_bfloat16*)k, + (const 
__nv_bfloat16*)cos, (const __nv_bfloat16*)sin, + positions, num_tokens, q_heads, k_heads, rotary_d, full_d + ); +#endif +} + +extern "C" void fused_rope_i_partial_tok_major_f32( + float* q, float* k, + const float* cos, const float* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t rotary_d, uint32_t full_d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t rotary_half_d = rotary_d / 2; + const uint32_t full_half_d = full_d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * rotary_half_d; + dim3 grid = get_optimal_grid(total); + fused_rope_i_partial_tok_major_f32_kernel<<>>( + (float2*)q, (float2*)k, cos, sin, positions, + num_tokens, q_heads, k_heads, rotary_half_d, full_half_d + ); +} + +extern "C" void fused_rope_i_partial_tok_major_f16( + void* q, void* k, + const void* cos, const void* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t rotary_d, uint32_t full_d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t rotary_half_d = rotary_d / 2; + const uint32_t full_half_d = full_d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * rotary_half_d; + dim3 grid = get_optimal_grid(total); + fused_rope_i_partial_tok_major_f16_kernel<<>>( + (__half2*)q, (__half2*)k, (const __half*)cos, (const __half*)sin, + positions, num_tokens, q_heads, k_heads, rotary_half_d, full_half_d + ); +} + +extern "C" void fused_rope_i_partial_tok_major_bf16( + void* q, void* k, + const void* cos, const void* sin, + const int64_t* positions, + uint32_t num_tokens, uint32_t q_heads, uint32_t k_heads, + uint32_t rotary_d, uint32_t full_d, + int64_t stream_ptr +) { + cudaStream_t stream = (cudaStream_t)stream_ptr; + const uint32_t rotary_half_d = rotary_d / 2; + const uint32_t full_half_d = full_d / 2; + const uint32_t total = num_tokens * (q_heads + k_heads) * rotary_half_d; 
+#ifndef NO_BF16_KERNEL + dim3 grid = get_optimal_grid(total); + fused_rope_i_partial_tok_major_bf16_kernel<<>>( + (__nv_bfloat162*)q, (__nv_bfloat162*)k, + (const __nv_bfloat16*)cos, (const __nv_bfloat16*)sin, + positions, num_tokens, q_heads, k_heads, rotary_half_d, full_half_d + ); +#endif +} diff --git a/src/lib.rs b/src/lib.rs index 95329ae..10f2c40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,6 +75,131 @@ pub struct PagedAttention { } impl PagedAttention { + fn batch_major_qkv( + query: &Tensor, + key: &Tensor, + value: &Tensor, + ) -> Result<(Tensor, Tensor, Tensor, usize, usize, usize)> { + match (query.dims().len(), key.dims().len(), value.dims().len()) { + (4, 4, 4) => { + let (_, attention_heads, seq_len, head_size) = query.shape().dims4()?; + let (_, key_value_heads, key_seq_len, key_head_size) = key.shape().dims4()?; + let (_, value_heads, value_seq_len, value_head_size) = value.shape().dims4()?; + if key_seq_len != seq_len + || value_seq_len != seq_len + || key_head_size != head_size + || value_head_size != head_size + || value_heads != key_value_heads + { + candle_core::bail!( + "Q/K/V layout mismatch, got Q {:?}, K {:?}, V {:?}", + query.shape(), + key.shape(), + value.shape() + ); + } + Ok(( + query.clone(), + key.clone(), + value.clone(), + attention_heads, + key_value_heads, + head_size, + )) + } + (3, 3, 3) => { + let (seq_len, attention_heads, head_size) = query.shape().dims3()?; + let (key_seq_len, key_value_heads, key_head_size) = key.shape().dims3()?; + let (value_seq_len, value_heads, value_head_size) = value.shape().dims3()?; + if key_seq_len != seq_len + || value_seq_len != seq_len + || key_head_size != head_size + || value_head_size != head_size + || value_heads != key_value_heads + { + candle_core::bail!( + "packed Q/K/V layout mismatch, got Q {:?}, K {:?}, V {:?}", + query.shape(), + key.shape(), + value.shape() + ); + } + Ok(( + query.transpose(0, 1)?.unsqueeze(0)?, + key.transpose(0, 1)?.unsqueeze(0)?, + value.transpose(0, 
1)?.unsqueeze(0)?, + attention_heads, + key_value_heads, + head_size, + )) + } + _ => candle_core::bail!( + "paged attention expects 3D packed or 4D batch-major Q/K/V, got Q {:?}, K {:?}, V {:?}", + query.shape(), + key.shape(), + value.shape() + ), + } + } + + #[cfg(any(feature = "flash-attn", feature = "flashinfer"))] + fn packed_qkv( + query: &Tensor, + key: &Tensor, + value: &Tensor, + ) -> Result<(Tensor, Tensor, Tensor, usize, usize, usize)> { + match (query.dims().len(), key.dims().len(), value.dims().len()) { + (4, 4, 4) => { + let (_, attention_heads, _, head_size) = query.shape().dims4()?; + let (_, key_value_heads, _, _) = key.shape().dims4()?; + let query = query + .transpose(1, 2)? + .reshape(((), attention_heads, head_size))?; + let key = key + .transpose(1, 2)? + .reshape(((), key_value_heads, head_size))?; + let value = value + .transpose(1, 2)? + .reshape(((), key_value_heads, head_size))?; + Ok((query, key, value, attention_heads, key_value_heads, head_size)) + } + (3, 3, 3) => { + let (_, attention_heads, head_size) = query.shape().dims3()?; + let (_, key_value_heads, key_head_size) = key.shape().dims3()?; + let (_, value_heads, value_head_size) = value.shape().dims3()?; + if key_head_size != head_size || value_head_size != head_size { + candle_core::bail!( + "packed Q/K/V head_dim mismatch, got Q {:?}, K {:?}, V {:?}", + query.shape(), + key.shape(), + value.shape() + ); + } + if value_heads != key_value_heads { + candle_core::bail!( + "packed K/V head count mismatch, got K {:?}, V {:?}", + key.shape(), + value.shape() + ); + } + Ok(( + query.clone(), + key.clone(), + value.clone(), + attention_heads, + key_value_heads, + head_size, + )) + } + _ => candle_core::bail!( + "flash attention expects 3D packed or 4D batch-major Q/K/V, got Q {:?}, K {:?}, V {:?}", + query.shape(), + key.shape(), + value.shape() + ), + } + } + fn maybe_update_kv_scales(&self, key: &Tensor, value: &Tensor) -> Result<()> { if let (Some(k_scale), Some(v_scale)) = 
(&self.k_scale, &self.v_scale) { if self.kv_updated_times.load(Ordering::Relaxed) < KV_SCALE_UPDATE_ITERATION { @@ -142,8 +267,8 @@ impl PagedAttention { input_metadata: &InputMetadata, softcapping: Option, ) -> Result { - let (_, attention_heads, _, head_size) = query.shape().dims4()?; - let (_, key_value_heads, _, _) = key.shape().dims4()?; + let (query, key, value, attention_heads, key_value_heads, head_size) = + Self::batch_major_qkv(query, key, value)?; fn repeat_kv(x: Tensor, n_rep: usize) -> Result { if n_rep == 1 { Ok(x) @@ -274,18 +399,9 @@ impl PagedAttention { input_metadata: &InputMetadata, softcapping: Option, ) -> Result { - let (_, attention_heads, _, head_size) = query.shape().dims4()?; - let (_, key_value_heads, _, _) = key.shape().dims4()?; + let (query, key, value, _attention_heads, _key_value_heads, _head_size) = + Self::packed_qkv(query, key, value)?; let slot_mapping = input_metadata.slot_mapping.flatten_all()?; - let query = query - .transpose(1, 2)? - .reshape(((), attention_heads, head_size))?; - let key = key - .transpose(1, 2)? - .reshape(((), key_value_heads, head_size))?; - let value = value - .transpose(1, 2)? - .reshape(((), key_value_heads, head_size))?; let softcap = Some(softcapping.unwrap_or(0.0f64) as f32); let window_size_right = self.sliding_window.map(|_| 0); @@ -376,8 +492,8 @@ impl PagedAttention { ) -> Result { #[cfg(feature = "flashinfer")] if let Some(fm) = input_metadata.flashinfer_metadata.as_ref() { - let (_, attention_heads, _, head_size) = query.shape().dims4()?; - let (_, key_value_heads, _, _) = key.shape().dims4()?; + let (query, key, value, attention_heads, key_value_heads, head_size) = + Self::packed_qkv(query, key, value)?; let group_size = attention_heads / key_value_heads; let flashinfer_prefill_group_supported = matches!(group_size, 1 | 2 | 3 | 4 | 6 | 8 | 16 | 32 | 64); @@ -398,17 +514,6 @@ impl PagedAttention { } } - let query = query - .transpose(1, 2)? 
- .reshape(((), attention_heads, head_size))?; - let key = key - .transpose(1, 2)? - .reshape(((), key_value_heads, head_size))?; - - let value = value - .transpose(1, 2)? - .reshape(((), key_value_heads, head_size))?; - self.maybe_update_kv_scales(&key, &value)?; if let (Some(kc), Some(vc)) = (key_cache.as_ref(), value_cache.as_ref()) { @@ -542,8 +647,8 @@ impl PagedAttention { // The following for paged attention let slot_mapping = input_metadata.slot_mapping.flatten_all()?; - let (batch_size, attention_heads, seq_len, head_size) = query.shape().dims4()?; - let (_, key_value_heads, _, _) = key.shape().dims4()?; + let (query, key, value, attention_heads, key_value_heads, head_size) = + Self::batch_major_qkv(query, key, value)?; // Write KvCache for SDP + Paged Attention let key = key diff --git a/src/metal-kernels/Cargo.toml b/src/metal-kernels/Cargo.toml index 3f14159..a3b4292 100644 --- a/src/metal-kernels/Cargo.toml +++ b/src/metal-kernels/Cargo.toml @@ -10,7 +10,7 @@ license = "MIT OR Apache-2.0" metal = { version = "0.27.0", features = ["mps"] } thiserror = "1" once_cell = "1.20.2" -candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" } +candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "5bed038" } [build-dependencies] anyhow = { version = "1", features = ["backtrace"] } From 9be89ac2f16d3f8abc52f955c992761ebeca3486 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 10 Mar 2026 22:21:13 +0800 Subject: [PATCH 5/7] Update fused rope for Metal --- src/fused_rope.rs | 278 ++++++++++++++++--------- src/metal-kernels/src/fused_rope.metal | 256 +++++++++++++++++------ src/metal-kernels/src/lib.rs | 12 +- 3 files changed, 379 insertions(+), 167 deletions(-) diff --git a/src/fused_rope.rs b/src/fused_rope.rs index bb42d89..f0aa52e 100644 --- a/src/fused_rope.rs +++ b/src/fused_rope.rs @@ -12,7 +12,6 @@ use candle_core::{DType, Result, Tensor}; #[cfg(feature = "cuda")] use 
kernels::ffi; -#[cfg(feature = "cuda")] #[derive(Clone, Copy)] enum RopeLayout { BatchMajor { @@ -29,7 +28,6 @@ enum RopeLayout { }, } -#[cfg(feature = "cuda")] impl RopeLayout { fn positions_len(self) -> usize { match self { @@ -37,9 +35,40 @@ impl RopeLayout { Self::TokenMajor { num_tokens, .. } => num_tokens as usize, } } + + fn q_bh(self) -> u32 { + match self { + Self::BatchMajor { q_bh, .. } => q_bh, + Self::TokenMajor { q_heads, .. } => q_heads, + } + } + + fn k_bh(self) -> u32 { + match self { + Self::BatchMajor { k_bh, .. } => k_bh, + Self::TokenMajor { k_heads, .. } => k_heads, + } + } + + fn seq_len(self) -> u32 { + match self { + Self::BatchMajor { seq_len, .. } => seq_len, + Self::TokenMajor { num_tokens, .. } => num_tokens, + } + } + + fn d(self) -> u32 { + match self { + Self::BatchMajor { d, .. } => d, + Self::TokenMajor { d, .. } => d, + } + } + + fn is_token_major(self) -> bool { + matches!(self, Self::TokenMajor { .. }) + } } -#[cfg(feature = "cuda")] fn resolve_rope_layout(q: &Tensor, k: &Tensor) -> Result { match (q.dims().len(), k.dims().len()) { (4, 4) => { @@ -84,6 +113,129 @@ fn resolve_rope_layout(q: &Tensor, k: &Tensor) -> Result { } } +#[cfg(not(feature = "cuda"))] +fn launch_fused_rope_metal( + q: &Tensor, + k: &Tensor, + cos: &Tensor, + sin: &Tensor, + positions: &Tensor, + is_interleaved: bool, + rotary_dim: usize, +) -> Result<()> { + use candle_core::backend::BackendStorage; + + let layout = resolve_rope_layout(q, k)?; + let expected_positions_len = layout.positions_len(); + let pos_shape = positions.dims(); + if pos_shape.len() != 1 || pos_shape[0] != expected_positions_len { + candle_core::bail!( + "positions should be [{}], got {:?}", + expected_positions_len, + pos_shape + ); + } + if rotary_dim == 0 || rotary_dim % 2 != 0 { + candle_core::bail!( + "rotary_dim must be an even positive integer, got {}", + rotary_dim + ); + } + if rotary_dim > layout.d() as usize { + candle_core::bail!( + "rotary_dim {} exceeds head_dim {} for Q 
{:?}, K {:?}", + rotary_dim, + layout.d(), + q.shape(), + k.shape() + ); + } + + let positions = if positions.dtype() != DType::I64 { + positions.to_dtype(DType::I64)? + } else { + positions.clone() + }; + + if !q.is_contiguous() + || !k.is_contiguous() + || !cos.is_contiguous() + || !sin.is_contiguous() + || !positions.is_contiguous() + { + candle_core::bail!("All tensors (q, k, cos, sin, positions) must be contiguous"); + } + + let dtype = q.dtype(); + if k.dtype() != dtype || cos.dtype() != dtype || sin.dtype() != dtype { + candle_core::bail!( + "Q, K, cos, sin must have same dtype, got Q: {:?}, K: {:?}, cos: {:?}, sin: {:?}", + q.dtype(), + k.dtype(), + cos.dtype(), + sin.dtype() + ); + } + + let (q_storage, q_layout) = q.storage_and_layout(); + let (k_storage, k_layout) = k.storage_and_layout(); + let (cos_storage, cos_layout) = cos.storage_and_layout(); + let (sin_storage, sin_layout) = sin.storage_and_layout(); + let (pos_storage, pos_layout) = positions.storage_and_layout(); + + let q_metal = match &*q_storage { + candle_core::Storage::Metal(s) => s, + _ => candle_core::bail!("Q must be on Metal device"), + }; + let k_metal = match &*k_storage { + candle_core::Storage::Metal(s) => s, + _ => candle_core::bail!("K must be on Metal device"), + }; + let cos_metal = match &*cos_storage { + candle_core::Storage::Metal(s) => s, + _ => candle_core::bail!("cos must be on Metal device"), + }; + let sin_metal = match &*sin_storage { + candle_core::Storage::Metal(s) => s, + _ => candle_core::bail!("sin must be on Metal device"), + }; + let pos_metal = match &*pos_storage { + candle_core::Storage::Metal(s) => s, + _ => candle_core::bail!("positions must be on Metal device"), + }; + + let device = q_metal.device(); + let command_buffer = device.command_buffer()?; + let kernels = metal_kernels::Kernels::default(); + + metal_kernels::call_fused_rope( + device.device(), + &*command_buffer, + kernels, + dtype, + q_metal.buffer(), + q_layout.start_offset() * 
dtype.size_in_bytes(), + k_metal.buffer(), + k_layout.start_offset() * dtype.size_in_bytes(), + cos_metal.buffer(), + cos_layout.start_offset() * dtype.size_in_bytes(), + sin_metal.buffer(), + sin_layout.start_offset() * dtype.size_in_bytes(), + pos_metal.buffer(), + pos_layout.start_offset() * std::mem::size_of::(), + layout.q_bh(), + layout.k_bh(), + layout.seq_len(), + layout.d(), + rotary_dim as u32, + is_interleaved, + layout.is_token_major(), + ) + .map_err(|e| candle_core::Error::Msg(format!("Metal fused_rope error: {:?}", e)))?; + + Ok(()) +} + #[cfg(feature = "cuda")] fn launch_fused_rope( q: &Tensor, @@ -741,107 +893,31 @@ impl FusedRope { positions: &Tensor, is_interleaved: bool, ) -> Result<()> { - use candle_core::backend::BackendStorage; - - let (b, q_h, seq_len, d) = q.dims4()?; - let (kb, k_h, k_seq_len, kd) = k.dims4()?; - - if b != kb || seq_len != k_seq_len || d != kd { - candle_core::bail!( - "Q and K batch/seq_len/head_dim must match, got Q: {:?}, K: {:?}", - q.shape(), - k.shape() - ); - } - - // Check contiguity - bail if not contiguous (avoid hidden allocations) - if !q.is_contiguous() || !k.is_contiguous() || !cos.is_contiguous() || !sin.is_contiguous() - { - candle_core::bail!("All tensors (q, k, cos, sin) must be contiguous"); - } - - let positions = if positions.dtype() != DType::I64 { - positions.to_dtype(DType::I64)? 
- } else { - positions.clone() - }; - if !positions.is_contiguous() { - candle_core::bail!("positions must be contiguous"); - } - - let dtype = q.dtype(); - if k.dtype() != dtype || cos.dtype() != dtype || sin.dtype() != dtype { - candle_core::bail!("Q, K, cos, sin must have same dtype"); - } - - let q_bh = (b * q_h) as u32; - let k_bh = (b * k_h) as u32; - let seq_len_u32 = seq_len as u32; - let d_u32 = d as u32; - - // Get Metal storage, layout, and device - let (q_storage, q_layout) = q.storage_and_layout(); - let (k_storage, k_layout) = k.storage_and_layout(); - let (cos_storage, cos_layout) = cos.storage_and_layout(); - let (sin_storage, sin_layout) = sin.storage_and_layout(); - let (pos_storage, pos_layout) = positions.storage_and_layout(); - - let q_metal = match &*q_storage { - candle_core::Storage::Metal(s) => s, - _ => candle_core::bail!("Q must be on Metal device"), - }; - let k_metal = match &*k_storage { - candle_core::Storage::Metal(s) => s, - _ => candle_core::bail!("K must be on Metal device"), - }; - let cos_metal = match &*cos_storage { - candle_core::Storage::Metal(s) => s, - _ => candle_core::bail!("cos must be on Metal device"), - }; - let sin_metal = match &*sin_storage { - candle_core::Storage::Metal(s) => s, - _ => candle_core::bail!("sin must be on Metal device"), - }; - let pos_metal = match &*pos_storage { - candle_core::Storage::Metal(s) => s, - _ => candle_core::bail!("positions must be on Metal device"), - }; - - #[cfg(feature = "metal")] - let device = q_metal.device(); - #[cfg(feature = "metal")] - let command_buffer = device.command_buffer()?; - #[cfg(feature = "metal")] - let kernels = metal_kernels::Kernels::default(); - - #[cfg(feature = "metal")] - metal_kernels::call_fused_rope( - device.device(), - &*command_buffer, - kernels, - dtype, - q_metal.buffer(), - q_layout.start_offset() * dtype.size_in_bytes(), - k_metal.buffer(), - k_layout.start_offset() * dtype.size_in_bytes(), - cos_metal.buffer(), - cos_layout.start_offset() * 
dtype.size_in_bytes(), - sin_metal.buffer(), - sin_layout.start_offset() * dtype.size_in_bytes(), - pos_metal.buffer(), - pos_layout.start_offset() * std::mem::size_of::(), - q_bh, - k_bh, - seq_len_u32, - d_u32, + let layout = resolve_rope_layout(q, k)?; + launch_fused_rope_metal( + q, + k, + cos, + sin, + positions, is_interleaved, + layout.d() as usize, ) - .map_err(|e| candle_core::Error::Msg(format!("Metal fused_rope error: {:?}", e)))?; - - // Note: We don't call commit/wait - candle's MetalDevice manages command buffer lifecycle - // The command buffer will be committed when candle synchronizes or at the end of the operation + } - Ok(()) + /// Apply fused rotary embedding in-place to only the leading `rotary_dim` + /// channels of Q/K tensors on non-CUDA backends. + #[cfg(not(feature = "cuda"))] + pub fn apply_inplace_partial( + q: &Tensor, + k: &Tensor, + cos: &Tensor, + sin: &Tensor, + positions: &Tensor, + is_interleaved: bool, + rotary_dim: usize, + ) -> Result<()> { + launch_fused_rope_metal(q, k, cos, sin, positions, is_interleaved, rotary_dim) } /// Apply fused rotary embedding (Metal version) - returns new tensors diff --git a/src/metal-kernels/src/fused_rope.metal b/src/metal-kernels/src/fused_rope.metal index bdb9737..884ed8f 100644 --- a/src/metal-kernels/src/fused_rope.metal +++ b/src/metal-kernels/src/fused_rope.metal @@ -21,6 +21,70 @@ #include using namespace metal; +constant uint ROPE_LAYOUT_BATCH_MAJOR = 0; +constant uint ROPE_LAYOUT_TOKEN_MAJOR = 1; + +inline uint rope_interleaved_pair_index( + uint local_idx, + uint num_heads, + uint seq_len, + uint full_pairs, + uint rotary_pairs, + uint layout, + thread uint& t_idx, + thread uint& d_idx +) { + if (layout == ROPE_LAYOUT_TOKEN_MAJOR) { + const uint pairs_per_token = num_heads * rotary_pairs; + t_idx = local_idx / pairs_per_token; + const uint rem = local_idx % pairs_per_token; + const uint h_idx = rem / rotary_pairs; + d_idx = rem % rotary_pairs; + return (t_idx * num_heads + h_idx) * 
full_pairs + d_idx; + } + + const uint pairs_per_bh = seq_len * rotary_pairs; + const uint bh_idx = local_idx / pairs_per_bh; + const uint rem = local_idx % pairs_per_bh; + t_idx = rem / rotary_pairs; + d_idx = rem % rotary_pairs; + return (bh_idx * seq_len + t_idx) * full_pairs + d_idx; +} + +inline void rope_non_interleaved_indices( + uint local_idx, + uint num_heads, + uint seq_len, + uint d, + uint rotary_pairs, + uint layout, + thread uint& t_idx, + thread uint& i_d, + thread uint& i1, + thread uint& i2 +) { + uint base; + + if (layout == ROPE_LAYOUT_TOKEN_MAJOR) { + const uint pairs_per_token = num_heads * rotary_pairs; + t_idx = local_idx / pairs_per_token; + const uint rem = local_idx % pairs_per_token; + const uint h_idx = rem / rotary_pairs; + i_d = rem % rotary_pairs; + base = (t_idx * num_heads + h_idx) * d; + } else { + const uint pairs_per_bh = seq_len * rotary_pairs; + const uint bh_idx = local_idx / pairs_per_bh; + const uint rem = local_idx % pairs_per_bh; + t_idx = rem / rotary_pairs; + i_d = rem % rotary_pairs; + base = (bh_idx * seq_len + t_idx) * d; + } + + i1 = base + i_d; + i2 = base + rotary_pairs + i_d; +} + // ============================================================================ // Interleaved RoPE with Position Selection // Adjacent pairs: (x0, x1), (x2, x3), ... 
@@ -38,11 +102,14 @@ kernel void fused_rope_i_f32( constant uint& k_bh [[buffer(6)]], constant uint& seq_len [[buffer(7)]], constant uint& d [[buffer(8)]], + constant uint& rotary_dim [[buffer(9)]], + constant uint& layout [[buffer(10)]], uint idx [[thread_position_in_grid]] ) { - const uint half_d = d / 2; - const uint q_num_pairs = q_bh * seq_len * half_d; - const uint k_num_pairs = k_bh * seq_len * half_d; + const uint full_pairs = d / 2; + const uint rotary_pairs = rotary_dim / 2; + const uint q_num_pairs = q_bh * seq_len * rotary_pairs; + const uint k_num_pairs = k_bh * seq_len * rotary_pairs; const uint total_pairs = q_num_pairs + k_num_pairs; if (idx >= total_pairs) return; @@ -50,22 +117,32 @@ kernel void fused_rope_i_f32( const bool is_q = (idx < q_num_pairs); const uint local_idx = is_q ? idx : (idx - q_num_pairs); - const uint d_idx = local_idx % half_d; - const uint t_idx = (local_idx / half_d) % seq_len; + uint t_idx; + uint d_idx; + const uint pair_idx = rope_interleaved_pair_index( + local_idx, + is_q ? q_bh : k_bh, + seq_len, + full_pairs, + rotary_pairs, + layout, + t_idx, + d_idx + ); const long pos = positions[t_idx]; - const uint cs_idx = pos * half_d + d_idx; + const uint cs_idx = pos * rotary_pairs + d_idx; const float c = cos[cs_idx]; const float s = sin[cs_idx]; device float2* ptr = is_q ? 
q : k; - float2 v = ptr[local_idx]; + float2 v = ptr[pair_idx]; float2 result; result.x = v.x * c - v.y * s; result.y = v.x * s + v.y * c; - ptr[local_idx] = result; + ptr[pair_idx] = result; } // F16 interleaved - uses half2 for vectorized pair access, F32 compute for precision @@ -79,11 +156,14 @@ kernel void fused_rope_i_f16( constant uint& k_bh [[buffer(6)]], constant uint& seq_len [[buffer(7)]], constant uint& d [[buffer(8)]], + constant uint& rotary_dim [[buffer(9)]], + constant uint& layout [[buffer(10)]], uint idx [[thread_position_in_grid]] ) { - const uint half_d = d / 2; - const uint q_num_pairs = q_bh * seq_len * half_d; - const uint k_num_pairs = k_bh * seq_len * half_d; + const uint full_pairs = d / 2; + const uint rotary_pairs = rotary_dim / 2; + const uint q_num_pairs = q_bh * seq_len * rotary_pairs; + const uint k_num_pairs = k_bh * seq_len * rotary_pairs; const uint total_pairs = q_num_pairs + k_num_pairs; if (idx >= total_pairs) return; @@ -91,18 +171,28 @@ kernel void fused_rope_i_f16( const bool is_q = (idx < q_num_pairs); const uint local_idx = is_q ? idx : (idx - q_num_pairs); - const uint d_idx = local_idx % half_d; - const uint t_idx = (local_idx / half_d) % seq_len; + uint t_idx; + uint d_idx; + const uint pair_idx = rope_interleaved_pair_index( + local_idx, + is_q ? q_bh : k_bh, + seq_len, + full_pairs, + rotary_pairs, + layout, + t_idx, + d_idx + ); const long pos = positions[t_idx]; - const uint cs_idx = pos * half_d + d_idx; + const uint cs_idx = pos * rotary_pairs + d_idx; // F32 compute for precision (like CUDA) const float c = float(cos[cs_idx]); const float s = float(sin[cs_idx]); device half2* ptr = is_q ? 
q : k; - half2 v = ptr[local_idx]; + half2 v = ptr[pair_idx]; float vx = float(v.x); float vy = float(v.y); @@ -111,7 +201,7 @@ kernel void fused_rope_i_f16( result.x = half(vx * c - vy * s); result.y = half(vx * s + vy * c); - ptr[local_idx] = result; + ptr[pair_idx] = result; } // BF16 interleaved - uses Bfloat2_ for vectorized pair access, native BF16 compute @@ -125,11 +215,14 @@ kernel void fused_rope_i_bf16( constant uint& k_bh [[buffer(6)]], constant uint& seq_len [[buffer(7)]], constant uint& d [[buffer(8)]], + constant uint& rotary_dim [[buffer(9)]], + constant uint& layout [[buffer(10)]], uint idx [[thread_position_in_grid]] ) { - const uint half_d = d / 2; - const uint q_num_pairs = q_bh * seq_len * half_d; - const uint k_num_pairs = k_bh * seq_len * half_d; + const uint full_pairs = d / 2; + const uint rotary_pairs = rotary_dim / 2; + const uint q_num_pairs = q_bh * seq_len * rotary_pairs; + const uint k_num_pairs = k_bh * seq_len * rotary_pairs; const uint total_pairs = q_num_pairs + k_num_pairs; if (idx >= total_pairs) return; @@ -137,24 +230,34 @@ kernel void fused_rope_i_bf16( const bool is_q = (idx < q_num_pairs); const uint local_idx = is_q ? idx : (idx - q_num_pairs); - const uint d_idx = local_idx % half_d; - const uint t_idx = (local_idx / half_d) % seq_len; + uint t_idx; + uint d_idx; + const uint pair_idx = rope_interleaved_pair_index( + local_idx, + is_q ? q_bh : k_bh, + seq_len, + full_pairs, + rotary_pairs, + layout, + t_idx, + d_idx + ); const long pos = positions[t_idx]; - const uint cs_idx = pos * half_d + d_idx; + const uint cs_idx = pos * rotary_pairs + d_idx; // Native BF16 compute (like CUDA's __hmul, __hsub, __hadd) const bfloat16_t c = cos[cs_idx]; const bfloat16_t s = sin[cs_idx]; device Bfloat2_* ptr = is_q ? 
q : k; - Bfloat2_ v = ptr[local_idx]; + Bfloat2_ v = ptr[pair_idx]; Bfloat2_ result; result.x = v.x * c - v.y * s; result.y = v.x * s + v.y * c; - ptr[local_idx] = result; + ptr[pair_idx] = result; } // ============================================================================ @@ -173,11 +276,13 @@ kernel void fused_rope_f32( constant uint& k_bh [[buffer(6)]], constant uint& seq_len [[buffer(7)]], constant uint& d [[buffer(8)]], + constant uint& rotary_dim [[buffer(9)]], + constant uint& layout [[buffer(10)]], uint idx [[thread_position_in_grid]] ) { - const uint half_d = d / 2; - const uint q_pairs = q_bh * seq_len * half_d; - const uint k_pairs = k_bh * seq_len * half_d; + const uint rotary_pairs = rotary_dim / 2; + const uint q_pairs = q_bh * seq_len * rotary_pairs; + const uint k_pairs = k_bh * seq_len * rotary_pairs; const uint total_pairs = q_pairs + k_pairs; if (idx >= total_pairs) return; @@ -185,21 +290,28 @@ kernel void fused_rope_f32( const bool is_q = (idx < q_pairs); const uint local_idx = is_q ? idx : (idx - q_pairs); - const uint pairs_per_bh = seq_len * half_d; - const uint i_bh = local_idx / pairs_per_bh; - const uint remainder = local_idx % pairs_per_bh; - const uint i_t = remainder / half_d; - const uint i_d = remainder % half_d; + uint i_t; + uint i_d; + uint i1; + uint i2; + rope_non_interleaved_indices( + local_idx, + is_q ? q_bh : k_bh, + seq_len, + d, + rotary_pairs, + layout, + i_t, + i_d, + i1, + i2 + ); const long pos = positions[i_t]; - const uint cs_idx = pos * half_d + i_d; + const uint cs_idx = pos * rotary_pairs + i_d; const float c = cos[cs_idx]; const float s = sin[cs_idx]; - const uint td = seq_len * d; - const uint i1 = i_bh * td + i_t * d + i_d; - const uint i2 = i1 + half_d; - device float* ptr = is_q ? 
q : k; float x1 = ptr[i1]; float x2 = ptr[i2]; @@ -220,11 +332,13 @@ kernel void fused_rope_f16( constant uint& k_bh [[buffer(6)]], constant uint& seq_len [[buffer(7)]], constant uint& d [[buffer(8)]], + constant uint& rotary_dim [[buffer(9)]], + constant uint& layout [[buffer(10)]], uint idx [[thread_position_in_grid]] ) { - const uint half_d = d / 2; - const uint q_pairs = q_bh * seq_len * half_d; - const uint k_pairs = k_bh * seq_len * half_d; + const uint rotary_pairs = rotary_dim / 2; + const uint q_pairs = q_bh * seq_len * rotary_pairs; + const uint k_pairs = k_bh * seq_len * rotary_pairs; const uint total_pairs = q_pairs + k_pairs; if (idx >= total_pairs) return; @@ -232,23 +346,30 @@ kernel void fused_rope_f16( const bool is_q = (idx < q_pairs); const uint local_idx = is_q ? idx : (idx - q_pairs); - const uint pairs_per_bh = seq_len * half_d; - const uint i_bh = local_idx / pairs_per_bh; - const uint remainder = local_idx % pairs_per_bh; - const uint i_t = remainder / half_d; - const uint i_d = remainder % half_d; + uint i_t; + uint i_d; + uint i1; + uint i2; + rope_non_interleaved_indices( + local_idx, + is_q ? q_bh : k_bh, + seq_len, + d, + rotary_pairs, + layout, + i_t, + i_d, + i1, + i2 + ); const long pos = positions[i_t]; - const uint cs_idx = pos * half_d + i_d; + const uint cs_idx = pos * rotary_pairs + i_d; // F32 compute for precision const float c = float(cos[cs_idx]); const float s = float(sin[cs_idx]); - const uint td = seq_len * d; - const uint i1 = i_bh * td + i_t * d + i_d; - const uint i2 = i1 + half_d; - device half* ptr = is_q ? 
q : k; float x1 = float(ptr[i1]); float x2 = float(ptr[i2]); @@ -269,11 +390,13 @@ kernel void fused_rope_bf16( constant uint& k_bh [[buffer(6)]], constant uint& seq_len [[buffer(7)]], constant uint& d [[buffer(8)]], + constant uint& rotary_dim [[buffer(9)]], + constant uint& layout [[buffer(10)]], uint idx [[thread_position_in_grid]] ) { - const uint half_d = d / 2; - const uint q_pairs = q_bh * seq_len * half_d; - const uint k_pairs = k_bh * seq_len * half_d; + const uint rotary_pairs = rotary_dim / 2; + const uint q_pairs = q_bh * seq_len * rotary_pairs; + const uint k_pairs = k_bh * seq_len * rotary_pairs; const uint total_pairs = q_pairs + k_pairs; if (idx >= total_pairs) return; @@ -281,23 +404,30 @@ kernel void fused_rope_bf16( const bool is_q = (idx < q_pairs); const uint local_idx = is_q ? idx : (idx - q_pairs); - const uint pairs_per_bh = seq_len * half_d; - const uint i_bh = local_idx / pairs_per_bh; - const uint remainder = local_idx % pairs_per_bh; - const uint i_t = remainder / half_d; - const uint i_d = remainder % half_d; + uint i_t; + uint i_d; + uint i1; + uint i2; + rope_non_interleaved_indices( + local_idx, + is_q ? q_bh : k_bh, + seq_len, + d, + rotary_pairs, + layout, + i_t, + i_d, + i1, + i2 + ); const long pos = positions[i_t]; - const uint cs_idx = pos * half_d + i_d; + const uint cs_idx = pos * rotary_pairs + i_d; // Native BF16 compute const bfloat16_t c = cos[cs_idx]; const bfloat16_t s = sin[cs_idx]; - const uint td = seq_len * d; - const uint i1 = i_bh * td + i_t * d + i_d; - const uint i2 = i1 + half_d; - device bfloat16_t* ptr = is_q ? 
q : k; bfloat16_t x1 = ptr[i1]; bfloat16_t x2 = ptr[i2]; diff --git a/src/metal-kernels/src/lib.rs b/src/metal-kernels/src/lib.rs index 5417ace..9841f20 100644 --- a/src/metal-kernels/src/lib.rs +++ b/src/metal-kernels/src/lib.rs @@ -1069,7 +1069,9 @@ pub fn call_update_scales_per_head( /// * `k_bh` - batch * num_kv_heads /// * `seq_len` - sequence length /// * `d` - head_dim +/// * `rotary_dim` - number of channels to rotate /// * `is_interleaved` - if true, use interleaved RoPE layout +/// * `is_token_major` - if true, Q/K layout is [tokens, heads, dim] #[allow(clippy::too_many_arguments)] pub fn call_fused_rope( device: &Device, @@ -1090,7 +1092,9 @@ pub fn call_fused_rope( k_bh: u32, seq_len: u32, d: u32, + rotary_dim: u32, is_interleaved: bool, + is_token_major: bool, ) -> Result<(), MetalKernelError> { let type_name = match ty { DType::F32 => "f32", @@ -1126,13 +1130,15 @@ pub fn call_fused_rope( q_bh, k_bh, seq_len, - d + d, + rotary_dim, + if is_token_major { 1u32 } else { 0u32 } ) ); // Calculate total number of pairs - let half_d = d / 2; - let total_pairs = ((q_bh + k_bh) * seq_len * half_d) as u64; + let rotary_pairs = rotary_dim / 2; + let total_pairs = ((q_bh + k_bh) * seq_len * rotary_pairs) as u64; // Dispatch with 256 threads per threadgroup let threads_per_threadgroup = MTLSize { From ec54281d5ffac1fc7d274e7d26e6d37b6a247163 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 11 Mar 2026 06:17:48 +0000 Subject: [PATCH 6/7] Fix build for sm_80 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c916b4d..3b93ec7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ rayon="1.10.0" kernels = { path = "./src/kernels", version="0.4.2", optional = true} metal = { version = "0.27.0", features = ["mps"], optional = true } metal-kernels = { path = "./src/metal-kernels", version="0.1.9", optional = true} -flashattn-rs = { git = "https://github.com/guoqingbao/flashattn.rs.git", version="0.1.0", 
rev = "bf1db0a", optional = true } +flashattn-rs = { git = "https://github.com/guoqingbao/flashattn.rs.git", version="0.1.0", rev = "e2b967a", optional = true } [features] cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:kernels"] From 005066057bab85803729a99c4421dc22b29c1bce Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 11 Mar 2026 08:08:00 +0000 Subject: [PATCH 7/7] Simplify flashattn feature --- Cargo.toml | 7 +- ReadMe.md | 2 +- src/fp8_linear.rs | 2 +- src/fused_rope.rs | 1 + src/kernels/src/paged_attention_v1.cu | 2 +- src/kernels/src/paged_attention_v2.cu | 2 +- src/kernels/src/prefill_paged_attn.cu | 2 +- src/kernels/src/prefill_paged_attn_opt.cu | 2 +- src/kernels/src/reshape_and_cache_kernel.cu | 4 +- src/lib.rs | 78 +++++++-------------- src/paged_attention.rs | 12 ++-- 11 files changed, 44 insertions(+), 70 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3b93ec7..b3605ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "attention-rs" -version = "0.4.2" +version = "0.4.3" edition = "2021" description = "High-performance LLM attention kernels and operations (PagedAttention, Flahinfer, Mamba, MoE, RoPE) for Candle, optimized for CUDA and Metal." 
repository = "https://github.com/guoqingbao/attention.rs" @@ -24,14 +24,13 @@ rayon="1.10.0" kernels = { path = "./src/kernels", version="0.4.2", optional = true} metal = { version = "0.27.0", features = ["mps"], optional = true } metal-kernels = { path = "./src/metal-kernels", version="0.1.9", optional = true} -flashattn-rs = { git = "https://github.com/guoqingbao/flashattn.rs.git", version="0.1.0", rev = "e2b967a", optional = true } +flashattn-rs = { git = "https://github.com/guoqingbao/flashattn.rs.git", version="0.1.0", rev = "a59e803", optional = true } [features] cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:kernels"] graph = ["cuda", "candle-core/graph"] -flash-attn = ["dep:flashattn-rs"] +flashattn = ["dep:flashattn-rs", "flashattn-rs/flash-context", "kernels/no-fp8-kvcache"] flash-decoding = ["dep:flashattn-rs", "flashattn-rs/flash-decoding", "kernels/no-fp8-kvcache"] -flash-context = ["dep:flashattn-rs", "flashattn-rs/flash-context", "kernels/no-fp8-kvcache"] no-marlin = ["dep:kernels", "kernels/no-marlin"] no-fp8-kvcache = ["dep:kernels", "kernels/no-fp8-kvcache"] metal = ["candle-core/metal", "candle-nn/metal", "dep:metal-kernels", "dep:metal"] diff --git a/ReadMe.md b/ReadMe.md index 18d086b..7764ed0 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -58,7 +58,7 @@ attention-rs = { git = "https://github.com/guoqingbao/attention.rs" } - `cuda`: Enable CUDA kernels and optimizations. - `metal`: Enable Metal kernels for Apple Silicon. -- `flash-attn`: Enable Flash Attention integration. +- `flashattn`: Enable Flash Attention integration. - `flashinfer`: Enable FlashInfer integration. - `cutlass`: Enable CUTLASS-optimized FP8 kernels (requires CUDA). 
diff --git a/src/fp8_linear.rs b/src/fp8_linear.rs index 0be5920..c8f2d92 100644 --- a/src/fp8_linear.rs +++ b/src/fp8_linear.rs @@ -4,7 +4,7 @@ use crate::cuda_utils; use crate::kernels::ffi; #[cfg(feature = "metal")] use crate::metal_kernels; -#[cfg(feature = "cuda")] +#[cfg(all(feature = "cuda", feature = "flashinfer"))] use candle_core::cuda_backend::cudarc::driver::CudaSlice; #[cfg(feature = "cuda")] use candle_core::cuda_backend::cudarc::driver::DevicePtr; diff --git a/src/fused_rope.rs b/src/fused_rope.rs index f0aa52e..d17a032 100644 --- a/src/fused_rope.rs +++ b/src/fused_rope.rs @@ -28,6 +28,7 @@ enum RopeLayout { }, } +#[allow(dead_code)] impl RopeLayout { fn positions_len(self) -> usize { match self { diff --git a/src/kernels/src/paged_attention_v1.cu b/src/kernels/src/paged_attention_v1.cu index e9c36d4..3bb4940 100644 --- a/src/kernels/src/paged_attention_v1.cu +++ b/src/kernels/src/paged_attention_v1.cu @@ -220,7 +220,7 @@ extern "C" void paged_attention_v1( #endif } #else - throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flash-attn or context-cache enabled)."); + throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flashattn or context-cache enabled)."); #endif } } diff --git a/src/kernels/src/paged_attention_v2.cu b/src/kernels/src/paged_attention_v2.cu index 3cba3d7..01b70e5 100644 --- a/src/kernels/src/paged_attention_v2.cu +++ b/src/kernels/src/paged_attention_v2.cu @@ -242,7 +242,7 @@ extern "C" void paged_attention_v2( #endif } #else - throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flash-attn or context-cache enabled)."); + throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flashattn or context-cache enabled)."); #endif } } diff --git a/src/kernels/src/prefill_paged_attn.cu b/src/kernels/src/prefill_paged_attn.cu index a58d63e..597f1c3 100644 --- a/src/kernels/src/prefill_paged_attn.cu +++ b/src/kernels/src/prefill_paged_attn.cu @@ 
-538,7 +538,7 @@ extern "C" void paged_attention_prefill( #endif } #else - throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flash-attn or context-cache enabled)."); + throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flashattn or context-cache enabled)."); #endif } else { if (dtype == 2) { diff --git a/src/kernels/src/prefill_paged_attn_opt.cu b/src/kernels/src/prefill_paged_attn_opt.cu index 1698d3e..f5cdfd9 100644 --- a/src/kernels/src/prefill_paged_attn_opt.cu +++ b/src/kernels/src/prefill_paged_attn_opt.cu @@ -582,7 +582,7 @@ extern "C" void paged_attention_prefill_opt( #endif } #else - throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flash-attn or context-cache enabled)."); + throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flashattn or context-cache enabled)."); #endif } else { if (dtype == 2) { diff --git a/src/kernels/src/reshape_and_cache_kernel.cu b/src/kernels/src/reshape_and_cache_kernel.cu index b2fccec..56122eb 100644 --- a/src/kernels/src/reshape_and_cache_kernel.cu +++ b/src/kernels/src/reshape_and_cache_kernel.cu @@ -186,7 +186,7 @@ extern "C" void call_reshape_and_cache( CALL_RESHAPE_AND_CACHE(float, uint8_t); } #else - throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flash-attn or context-cache enabled)."); + throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flashattn or context-cache enabled)."); #endif } else { if (dtype == 0){ @@ -235,7 +235,7 @@ extern "C" void call_reshape_and_cache_flash( CALL_RESHAPE_AND_CACHE_FLASH(float, uint8_t); } #else - throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flash-attn or context-cache enalbed)."); + throw std::runtime_error("Error: FP8 KV-cache is disabled (possiblly because flashattn or context-cache enalbed)."); #endif } else { if (dtype == 0){ diff --git a/src/lib.rs b/src/lib.rs index 10f2c40..104507a 100644 --- 
a/src/lib.rs +++ b/src/lib.rs @@ -142,7 +142,7 @@ impl PagedAttention { } } - #[cfg(any(feature = "flash-attn", feature = "flashinfer"))] + #[cfg(any(feature = "flashattn", feature = "flashinfer"))] fn packed_qkv( query: &Tensor, key: &Tensor, @@ -347,7 +347,7 @@ impl PagedAttention { Tensor::cat(&vec_attn, 2)?.contiguous()?.transpose(1, 2) } - #[cfg(feature = "flash-attn")] + #[cfg(feature = "flashattn")] pub fn flash_var_len( &self, query: &Tensor, @@ -388,7 +388,7 @@ impl PagedAttention { } } - #[cfg(feature = "flash-attn")] + #[cfg(feature = "flashattn")] pub fn flash_forward( &self, query: &Tensor, @@ -422,8 +422,8 @@ impl PagedAttention { return self.flash_var_len(&query, &key, &value, input_metadata, softcapping); } - #[cfg(feature = "flash-decoding")] if input_metadata.is_prefill { + // prefill with kvcache return flashattn_rs::flash_attn_with_kvcache_advanced( &query, key_cache.as_ref().unwrap(), @@ -443,36 +443,27 @@ impl PagedAttention { ); } - #[cfg(not(feature = "flash-decoding"))] - if input_metadata.is_prefill { - candle_core::bail!("Invalid pattern for flash_forward"); - } - - #[cfg(feature = "flash-decoding")] - { - let block_tables = input_metadata.block_tables.as_ref().unwrap(); - let context_lens = input_metadata.context_lens.as_ref().unwrap(); + // Decoding with kvcache + let block_tables = input_metadata.block_tables.as_ref().unwrap(); + let context_lens = input_metadata.context_lens.as_ref().unwrap(); - flashattn_rs::flash_attn_with_kvcache_advanced( - &query.unsqueeze(1)?, //(batch_size, seqlen_q, num_heads_q, head_size) - key_cache.as_ref().unwrap(), - value_cache.as_ref().unwrap(), - context_lens, - block_tables, - None, - None, - self.scale as f32, - false, - self.sliding_window, - window_size_right, - None, - softcap, - 0, - None, - ) - } - #[cfg(not(feature = "flash-decoding"))] - candle_core::bail!("Invalid pattern for flash_forward") + flashattn_rs::flash_attn_with_kvcache_advanced( + &query.unsqueeze(1)?, //(batch_size, seqlen_q, 
num_heads_q, head_size) + key_cache.as_ref().unwrap(), + value_cache.as_ref().unwrap(), + context_lens, + block_tables, + None, + None, + self.scale as f32, + false, + self.sliding_window, + window_size_right, + None, + softcap, + 0, + None, + ) } #[allow(clippy::too_many_arguments)] @@ -600,7 +591,7 @@ impl PagedAttention { } } - #[cfg(feature = "flash-decoding")] + #[cfg(feature = "flashattn")] if !input_metadata.disable_flash_attn.unwrap_or(false) { return self.flash_forward( query, @@ -613,23 +604,6 @@ impl PagedAttention { ); } - if !input_metadata.disable_flash_attn.unwrap_or(false) - && input_metadata.is_prefill - && input_metadata.block_tables.is_none() - { - // non context-cache prefill with flash-attn - #[cfg(feature = "flash-attn")] - return self.flash_forward( - query, - key, - value, - key_cache, - value_cache, - input_metadata, - softcapping, - ); - } - let mut att = if input_metadata.is_prefill && input_metadata.block_tables.is_none() { //no context cache, prefill with naive scale-dot-product attention Some(self.sdp_prefill( @@ -685,7 +659,7 @@ impl PagedAttention { //decoding with paged-attn - //if flash-decoding (flash-attn with prefill kvcache) feature not enabled, use our custom paged attention for chunked prefill + //if flashattn (flashattn with prefill kvcache) feature not enabled, use our custom paged attention for chunked prefill let cu_seqlens_q = if input_metadata.is_prefill && input_metadata.block_tables.is_some() { assert!( input_metadata.cu_seqlens_q.as_ref().is_some(), diff --git a/src/paged_attention.rs b/src/paged_attention.rs index 38e2bab..ddc337d 100644 --- a/src/paged_attention.rs +++ b/src/paged_attention.rs @@ -885,7 +885,7 @@ impl ReshapeCache { ) } - #[cfg(feature = "flash-decoding")] + #[cfg(feature = "flashattn")] if kc_rank != 4 { candle::bail!( "flash-attention expects `key_cache` tensor to be of rank 4 \ @@ -893,7 +893,7 @@ impl ReshapeCache { ) } - #[cfg(not(feature = "flash-decoding"))] + #[cfg(not(feature = 
"flashattn"))] if kc_rank != 5 { candle::bail!( "paged-attention expects `key_cache` tensor to be of rank 5 \ @@ -923,14 +923,14 @@ impl ReshapeCache { candle::bail!("shape mismatch k {:?} and v {:?}", k_l.shape(), v_l.shape()) } - #[cfg(feature = "flash-decoding")] + #[cfg(feature = "flashattn")] let (block_size, _x) = { // [num_blocks, block_size, num_heads, head_size] let (_, block_size, _, _) = kc_l.shape().dims4()?; (block_size, 1) }; - #[cfg(not(feature = "flash-decoding"))] + #[cfg(not(feature = "flashattn"))] let (block_size, x) = { let (num_blocks, num_heads_kc, head_size_kc, block_size, x) = kc_l.shape().dims5()?; if num_heads_kc != num_heads || head_size_kc != head_size / x { @@ -1005,7 +1005,7 @@ impl ReshapeCache { }; unsafe { - #[cfg(feature = "flash-decoding")] + #[cfg(feature = "flashattn")] { assert!( k_scales_ptr.is_null() && v_scales_ptr.is_null(), @@ -1036,7 +1036,7 @@ impl ReshapeCache { *dev.cu_stream() as i64, ); } - #[cfg(not(feature = "flash-decoding"))] + #[cfg(not(feature = "flashattn"))] { kernels::ffi::call_reshape_and_cache( k_ptr,