diff --git a/csrc/moe_utils_binding.cu b/csrc/moe_utils_binding.cu
index 8cfd00f3eb..bfe2610fd0 100644
--- a/csrc/moe_utils_binding.cu
+++ b/csrc/moe_utils_binding.cu
@@ -345,3 +345,68 @@ void moe_sort(
 }
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_sort, moe_sort);
+
+#ifdef ENABLE_BF16
+// ============================ fused_topk_raw_logits bindings ============================
+// Use TRTLLM routingRenormalize code path (same kernel family as non-routed TRTLLM MoE).
+//
+// Notes:
+// - This function computes top-k scores/weights from raw routing logits (mPtrScores path).
+// - mPtrTopKIds is intentionally left nullptr to force score-driven routing selection.
+// - topk_packed_ptr is required by routing kernels for large-token paths.
+void fused_topk_raw_logits_trtllm_renormalize(
+    int64_t topk_weights_ptr, int64_t topk_packed_ptr, int64_t gating_output_ptr,
+    int32_t num_tokens, int32_t num_experts, int32_t top_k, bool renormalize, bool use_pdl,
+    int32_t tile_tokens_dim, int64_t expert_counts_ptr, int64_t permuted_idx_size_ptr,
+    int64_t expanded_idx_to_permuted_idx_ptr, int64_t permuted_idx_to_token_idx_ptr,
+    int64_t cta_idx_to_batch_idx_ptr, int64_t cta_idx_to_mn_limit_ptr,
+    int64_t num_non_exiting_ctas_ptr,
+    // Optional explicit CUDA stream pointer for CUDA graph compatibility.
+    // If 0, use TVM FFI current stream.
+    int64_t cuda_stream_ptr) {
+  moe::dev::routing::routingRenormalize::Data routingData;
+
+  // Match TRTLLM non-routed MoE routing defaults.
+  routingData.mDtypeExpW = batchedGemm::trtllm::gen::Dtype::Bfloat16;
+  routingData.mDtypeElt = batchedGemm::trtllm::gen::Dtype::Bfloat16;
+  routingData.mUsePdl = use_pdl;
+  routingData.mDoSoftmaxBeforeTopK = false;
+  routingData.mNormTopkProb = false;
+  routingData.mApplySoftmaxAfterTopK = renormalize;
+
+  // Cast raw int64 addresses to the exact member pointer types via decltype so the
+  // binding stays correct even if the routing Data struct changes a field's type.
+  routingData.mPtrScores = reinterpret_cast<decltype(routingData.mPtrScores)>(gating_output_ptr);
+  routingData.mPtrTopKWeights =
+      reinterpret_cast<decltype(routingData.mPtrTopKWeights)>(topk_weights_ptr);
+  routingData.mPtrTopKIds = nullptr;
+  routingData.mPtrTopKPacked =
+      reinterpret_cast<decltype(routingData.mPtrTopKPacked)>(topk_packed_ptr);
+
+  routingData.mPtrExpertCounts = reinterpret_cast<int32_t*>(expert_counts_ptr);
+  routingData.mPtrPermutedIdxSize = reinterpret_cast<int32_t*>(permuted_idx_size_ptr);
+  routingData.mPtrExpandedIdxToPermutedIdx =
+      reinterpret_cast<int32_t*>(expanded_idx_to_permuted_idx_ptr);
+  routingData.mPtrPermutedIdxToTokenIdx = reinterpret_cast<int32_t*>(permuted_idx_to_token_idx_ptr);
+  routingData.mPtrCtaIdxXyToBatchIdx = reinterpret_cast<int32_t*>(cta_idx_to_batch_idx_ptr);
+  routingData.mPtrCtaIdxXyToMnLimit = reinterpret_cast<int32_t*>(cta_idx_to_mn_limit_ptr);
+  routingData.mPtrNumNonExitingCtas = reinterpret_cast<int32_t*>(num_non_exiting_ctas_ptr);
+
+  routingData.mNumTokens = num_tokens;
+  routingData.mNumExperts = num_experts;
+  routingData.mTopK = top_k;
+  routingData.mPaddingLog2 = computeLog2(tile_tokens_dim);
+  routingData.mTileTokensDim = tile_tokens_dim;
+  routingData.mLocalExpertsStartIdx = 0;
+  routingData.mLocalExpertsStrideLog2 = 0;
+  routingData.mNumLocalExperts = num_experts;
+
+  cudaStream_t stream = cuda_stream_ptr != 0
+                            ? reinterpret_cast<cudaStream_t>(cuda_stream_ptr)
+                            : get_current_stream();
+  moe::dev::routing::routingRenormalize::run(routingData, stream);
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_fused_topk_raw_logits_trtllm_renormalize,
+                              fused_topk_raw_logits_trtllm_renormalize);
+#endif
diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py
index e2b4cab3d6..05bc3f590b 100644
--- a/flashinfer/fused_moe/__init__.py
+++ b/flashinfer/fused_moe/__init__.py
@@ -40,6 +40,9 @@
 from .fused_routing_dsv3 import (  # noqa: F401
     fused_topk_deepseek as fused_topk_deepseek,
 )
+from .raw_logits_topk import (
+    fused_topk_raw_logits as fused_topk_raw_logits,
+)
 
 # CuteDSL MoE APIs (conditionally imported if cute_dsl available)
 try:
@@ -74,6 +77,7 @@
     "trtllm_fp8_per_tensor_scale_moe",
     "trtllm_mxint4_block_scale_moe",
     "fused_topk_deepseek",
+    "fused_topk_raw_logits",
 ]
 
 # Add CuteDSL exports if available
diff --git a/flashinfer/fused_moe/raw_logits_topk.py b/flashinfer/fused_moe/raw_logits_topk.py
new file mode 100644
index 0000000000..6a5e052872
--- /dev/null
+++ b/flashinfer/fused_moe/raw_logits_topk.py
@@ -0,0 +1,289 @@
+"""
+Copyright (c) 2025 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from __future__ import annotations
+
+import functools
+
+import torch
+
+from flashinfer.api_logging import flashinfer_api
+from flashinfer.jit.moe_utils import gen_moe_utils_module
+from flashinfer.utils import device_support_pdl
+
+_ROUTING_TILE_TOKENS_DIM = 128
+_workspace_pool: dict[tuple[str, int | None], "_RawLogitsTopkWorkspace"] = {}
+
+
+@functools.lru_cache(maxsize=1)
+def _get_moe_utils_module():
+    spec = gen_moe_utils_module()
+    return spec.build_and_load()
+
+
+def _get_cuda_stream_ptr(device: torch.device) -> int:
+    return torch.cuda.current_stream(device=device).cuda_stream
+
+
+class _RawLogitsTopkWorkspace:
+    def __init__(self, device: torch.device) -> None:
+        self.device = device
+        self.topk_weights_bf16 = torch.empty((0,), dtype=torch.bfloat16, device=device)
+        self.topk_packed = torch.empty((0,), dtype=torch.int32, device=device)
+        self.expert_counts = torch.empty((0,), dtype=torch.int32, device=device)
+        self.permuted_idx_size = torch.empty((0,), dtype=torch.int32, device=device)
+        self.expanded_idx_to_permuted_idx = torch.empty(
+            (0,), dtype=torch.int32, device=device
+        )
+        self.permuted_idx_to_token_idx = torch.empty(
+            (0,), dtype=torch.int32, device=device
+        )
+        self.cta_idx_to_batch_idx = torch.empty((0,), dtype=torch.int32, device=device)
+        self.cta_idx_to_mn_limit = torch.empty((0,), dtype=torch.int32, device=device)
+        self.num_non_exiting_ctas = torch.empty((0,), dtype=torch.int32, device=device)
+
+    @staticmethod
+    def _ensure_capacity(tensor: torch.Tensor, numel: int) -> torch.Tensor:
+        if tensor.numel() >= numel:
+            return tensor
+        return torch.empty((numel,), dtype=tensor.dtype, device=tensor.device)
+
+    def get_views(
+        self,
+        num_tokens: int,
+        topk: int,
+        num_experts: int,
+        max_num_tiles: int,
+        max_num_permuted_tokens: int,
+    ) -> tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
+        expanded_size = num_tokens * topk
+        expert_counts_size = max(512, 2 * num_experts)
+
+        self.topk_weights_bf16 = self._ensure_capacity(
+            self.topk_weights_bf16, expanded_size
+        )
+        self.topk_packed = self._ensure_capacity(self.topk_packed, expanded_size)
+        self.expert_counts = self._ensure_capacity(
+            self.expert_counts, expert_counts_size
+        )
+        self.permuted_idx_size = self._ensure_capacity(self.permuted_idx_size, 1)
+        self.expanded_idx_to_permuted_idx = self._ensure_capacity(
+            self.expanded_idx_to_permuted_idx, expanded_size
+        )
+        self.permuted_idx_to_token_idx = self._ensure_capacity(
+            self.permuted_idx_to_token_idx, max_num_permuted_tokens
+        )
+        self.cta_idx_to_batch_idx = self._ensure_capacity(
+            self.cta_idx_to_batch_idx, max_num_tiles
+        )
+        self.cta_idx_to_mn_limit = self._ensure_capacity(
+            self.cta_idx_to_mn_limit, max_num_tiles
+        )
+        self.num_non_exiting_ctas = self._ensure_capacity(self.num_non_exiting_ctas, 1)
+
+        return (
+            self.topk_weights_bf16[:expanded_size].view(num_tokens, topk),
+            self.topk_packed[:expanded_size].view(num_tokens, topk),
+            self.expert_counts[:expert_counts_size],
+            self.permuted_idx_size[:1],
+            self.expanded_idx_to_permuted_idx[:expanded_size],
+            self.permuted_idx_to_token_idx[:max_num_permuted_tokens],
+            self.cta_idx_to_batch_idx[:max_num_tiles],
+            self.cta_idx_to_mn_limit[:max_num_tiles],
+            self.num_non_exiting_ctas[:1],
+        )
+
+
+def _get_workspace(device: torch.device) -> _RawLogitsTopkWorkspace:
+    key = (device.type, device.index)
+    ws = _workspace_pool.get(key)
+    if ws is None:
+        ws = _RawLogitsTopkWorkspace(device)
+        _workspace_pool[key] = ws
+    return ws
+
+
+def _get_max_num_tiles(
+    num_tokens: int,
+    top_k: int,
+    num_local_experts: int,
+    tile_size: int,
+) -> int:
+    # Mirrors TRTLLM GroupedGemmInputsHelper.get_max_num_tiles.
+    num_expanded_tokens = num_tokens * top_k
+    if num_expanded_tokens <= num_local_experts:
+        return num_expanded_tokens
+    num_remaining_tokens = num_expanded_tokens - num_local_experts
+    return num_local_experts + (num_remaining_tokens + tile_size - 1) // tile_size
+
+
+def _validate_args(
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    gating_output: torch.Tensor,
+    renormalize: bool,
+) -> None:
+    if topk_weights.ndim != 2 or topk_ids.ndim != 2 or gating_output.ndim != 2:
+        raise ValueError(
+            "Expected 2D tensors for topk_weights, topk_ids, and gating_output."
+        )
+    if topk_weights.shape != topk_ids.shape:
+        raise ValueError(
+            f"topk_weights/topk_ids shape mismatch: {topk_weights.shape} vs {topk_ids.shape}"
+        )
+    if topk_weights.shape[0] != gating_output.shape[0]:
+        raise ValueError(
+            "Batch size mismatch: "
+            f"{topk_weights.shape[0]} (output) vs {gating_output.shape[0]} (gating_output)"
+        )
+    if topk_weights.dtype != torch.float32:
+        raise ValueError(
+            f"Expected topk_weights dtype float32, got {topk_weights.dtype}"
+        )
+    if not topk_weights.is_contiguous():
+        raise ValueError("Expected topk_weights to be contiguous.")
+    if topk_ids.dtype not in (torch.int32, torch.int64):
+        raise ValueError(
+            f"Expected topk_ids dtype int32 or int64, got {topk_ids.dtype}"
+        )
+    if not topk_ids.is_contiguous():
+        raise ValueError("Expected topk_ids to be contiguous.")
+    if gating_output.dtype != torch.bfloat16:
+        raise ValueError(
+            "TRTLLM routingRenormalize path expects bf16 gating_output, got "
+            f"{gating_output.dtype}"
+        )
+    if (
+        topk_weights.device != gating_output.device
+        or topk_ids.device != gating_output.device
+    ):
+        raise ValueError(
+            "topk_weights, topk_ids, and gating_output must be on the same device."
+        )
+    if gating_output.device.type != "cuda":
+        raise ValueError("TRTLLM routingRenormalize path only supports CUDA tensors.")
+    if not isinstance(renormalize, bool):
+        raise ValueError(f"renormalize must be bool, got {type(renormalize)}")
+
+    topk = topk_weights.shape[1]
+    num_experts = gating_output.shape[1]
+    if topk < 1:
+        raise ValueError(f"Invalid top-k: {topk}")
+    if topk > num_experts:
+        raise ValueError(f"Invalid top-k {topk} for num_experts={num_experts}")
+    if topk > 10:
+        raise ValueError(f"TRTLLM routingRenormalize supports top-k <= 10, got {topk}")
+    if num_experts > 512:
+        raise ValueError(
+            f"TRTLLM routingRenormalize supports num_experts <= 512, got {num_experts}"
+        )
+    if num_experts % 4 != 0:
+        raise ValueError(
+            f"TRTLLM routingRenormalize expects num_experts % 4 == 0, got {num_experts}"
+        )
+
+
+@flashinfer_api
+def fused_topk_raw_logits(
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    gating_output: torch.Tensor,
+    renormalize: bool = True,
+) -> None:
+    """TRTLLM routingRenormalize raw-logits top-k.
+
+    This API intentionally uses only TRTLLM's routingRenormalize path. There is no
+    fallback path in this implementation.
+
+    Supported configuration:
+    - ``gating_output`` dtype ``torch.bfloat16``
+    - routing score mode fixed to raw-logits + optional post-topk softmax
+    - routing tile size fixed to 128
+    """
+
+    _validate_args(
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        gating_output=gating_output,
+        renormalize=renormalize,
+    )
+
+    num_tokens, num_experts = gating_output.shape
+    topk = topk_weights.shape[1]
+    max_num_tiles = _get_max_num_tiles(
+        num_tokens=num_tokens,
+        top_k=topk,
+        num_local_experts=num_experts,
+        tile_size=_ROUTING_TILE_TOKENS_DIM,
+    )
+    max_num_permuted_tokens = max_num_tiles * _ROUTING_TILE_TOKENS_DIM
+    use_pdl = device_support_pdl(gating_output.device)
+
+    (
+        topk_weights_bf16,
+        topk_packed,
+        expert_counts,
+        permuted_idx_size,
+        expanded_idx_to_permuted_idx,
+        permuted_idx_to_token_idx,
+        cta_idx_to_batch_idx,
+        cta_idx_to_mn_limit,
+        num_non_exiting_ctas,
+    ) = _get_workspace(gating_output.device).get_views(
+        num_tokens=num_tokens,
+        topk=topk,
+        num_experts=num_experts,
+        max_num_tiles=max_num_tiles,
+        max_num_permuted_tokens=max_num_permuted_tokens,
+    )
+
+    if not gating_output.is_contiguous():
+        gating_output = gating_output.contiguous()
+    # Initialize tail entries so searchsorted sees a monotonic full-length array.
+    cta_idx_to_batch_idx.zero_()
+    cta_idx_to_mn_limit.fill_(torch.iinfo(torch.int32).max)
+    _get_moe_utils_module()["flashinfer_fused_topk_raw_logits_trtllm_renormalize"](
+        topk_weights_bf16.data_ptr(),
+        topk_packed.data_ptr(),
+        gating_output.data_ptr(),
+        num_tokens,
+        num_experts,
+        topk,
+        renormalize,
+        use_pdl,
+        _ROUTING_TILE_TOKENS_DIM,
+        expert_counts.data_ptr(),
+        permuted_idx_size.data_ptr(),
+        expanded_idx_to_permuted_idx.data_ptr(),
+        permuted_idx_to_token_idx.data_ptr(),
+        cta_idx_to_batch_idx.data_ptr(),
+        cta_idx_to_mn_limit.data_ptr(),
+        num_non_exiting_ctas.data_ptr(),
+        _get_cuda_stream_ptr(gating_output.device),
+    )
+
+    # Recover expert ids entirely on-device from routing metadata. We intentionally
+    # avoid host syncs (e.g. tensor.item()) to keep this CUDA-graph safe.
+    expanded = expanded_idx_to_permuted_idx.view(num_tokens, topk)
+    cta_idx = torch.searchsorted(cta_idx_to_mn_limit, expanded, right=True)
+    topk_ids_i32 = cta_idx_to_batch_idx[cta_idx].to(torch.int32)
+    topk_weights.copy_(topk_weights_bf16.float())
+    if topk_ids.dtype == torch.int32:
+        topk_ids.copy_(topk_ids_i32)
+    else:
+        topk_ids.copy_(topk_ids_i32.to(torch.int64))
diff --git a/flashinfer/jit/moe_utils.py b/flashinfer/jit/moe_utils.py
index 9e0cc9e19c..e3ca16e0d6 100644
--- a/flashinfer/jit/moe_utils.py
+++ b/flashinfer/jit/moe_utils.py
@@ -77,6 +77,8 @@ def gen_moe_utils_module() -> JitSpec:
             jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/memoryUtils.cu",
             # Routing kernels for moe_sort
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_deepseek.cu",
+            # Routing kernels for raw-logits top-k (TRTLLM renormalize path)
+            jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_renormalize.cu",
         ],
         extra_cuda_cflags=nvcc_flags,
         extra_include_paths=[
diff --git a/tests/utils/test_topk.py b/tests/utils/test_topk.py
index 540e015621..0a70d20f63 100644
--- a/tests/utils/test_topk.py
+++ b/tests/utils/test_topk.py
@@ -20,6 +20,7 @@
 import torch
 
 import flashinfer
+from flashinfer.fused_moe import fused_topk_raw_logits
 from flashinfer.topk import can_implement_filtered_topk
 from flashinfer.utils import get_compute_capability
 
@@ -258,6 +259,65 @@
     assert accuracy >= 0.98
 
 
+# ===================== TRTLLM MoE Raw Logits TopK Tests =====================
+
+
+def _skip_if_not_sm100_or_sm103() -> None:
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if compute_capability[0] not in [10]:
+        pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
+
+
+@pytest.mark.parametrize("num_tokens", [1, 7, 31, 127, 257, 777, 1537])
+@pytest.mark.parametrize(
+    "num_experts, top_k",
+    [
+        (8, 2),
+        (16, 4),
+        (64, 8),
+        (128, 8),
+    ],
+)
+@pytest.mark.parametrize("renormalize", [False, True])
+def test_fused_topk_raw_logits_matches_torch_reference(
+    num_tokens, num_experts, top_k, renormalize
+):
+    _skip_if_not_sm100_or_sm103()
+
+    device = torch.device("cuda:0")
+    token_ids = torch.arange(num_tokens, device=device, dtype=torch.int32).unsqueeze(1)
+    expert_ids = torch.arange(num_experts, device=device, dtype=torch.int32).unsqueeze(
+        0
+    )
+    routing_logits = ((expert_ids + token_ids) % num_experts).to(torch.bfloat16)
+    topk_weights = torch.empty(num_tokens, top_k, device=device, dtype=torch.float32)
+    topk_ids = torch.empty(num_tokens, top_k, device=device, dtype=torch.int32)
+
+    fused_topk_raw_logits(
+        topk_weights, topk_ids, routing_logits, renormalize=renormalize
+    )
+
+    ref_values, ref_ids = torch.topk(
+        routing_logits.float(), k=top_k, dim=-1, sorted=False
+    )
+    ref_weights = torch.softmax(ref_values, dim=-1) if renormalize else ref_values
+    if renormalize:
+        # routingRenormalize stores selected weights in bf16 internally.
+        ref_weights = ref_weights.to(torch.bfloat16).to(torch.float32)
+    ref_ids = ref_ids.to(torch.int32)
+
+    sort_idx = torch.argsort(topk_ids, dim=-1)
+    sorted_ids = torch.gather(topk_ids, dim=-1, index=sort_idx)
+    sorted_weights = torch.gather(topk_weights, dim=-1, index=sort_idx)
+
+    ref_sort_idx = torch.argsort(ref_ids, dim=-1)
+    ref_sorted_ids = torch.gather(ref_ids, dim=-1, index=ref_sort_idx)
+    ref_sorted_weights = torch.gather(ref_weights, dim=-1, index=ref_sort_idx)
+
+    assert torch.equal(sorted_ids, ref_sorted_ids)
+    torch.testing.assert_close(sorted_weights, ref_sorted_weights, rtol=1e-2, atol=1e-3)
+
+
 # ===================== Fused TopK Transform Tests =====================