Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions csrc/moe_utils_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,63 @@ void moe_sort(
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_sort, moe_sort);

#ifdef ENABLE_BF16
// ============================ fused_topk_raw_logits bindings ============================
// Thin TVM-FFI wrapper around TRTLLM's routingRenormalize kernels (the same
// kernel family used by the non-routed TRTLLM MoE path).
//
// Notes:
// - Top-k scores/weights are derived from raw routing logits (mPtrScores path).
// - mPtrTopKIds stays nullptr so the kernel performs score-driven selection.
// - topk_packed_ptr is required by the routing kernels for large-token paths.
void fused_topk_raw_logits_trtllm_renormalize(
    int64_t topk_weights_ptr, int64_t topk_packed_ptr, int64_t gating_output_ptr,
    int32_t num_tokens, int32_t num_experts, int32_t top_k, bool renormalize, bool use_pdl,
    int32_t tile_tokens_dim, int64_t expert_counts_ptr, int64_t permuted_idx_size_ptr,
    int64_t expanded_idx_to_permuted_idx_ptr, int64_t permuted_idx_to_token_idx_ptr,
    int64_t cta_idx_to_batch_idx_ptr, int64_t cta_idx_to_mn_limit_ptr,
    int64_t num_non_exiting_ctas_ptr,
    // Optional explicit CUDA stream pointer for CUDA graph compatibility.
    // If 0, use TVM FFI current stream.
    int64_t cuda_stream_ptr) {
  moe::dev::routing::routingRenormalize::Data data;

  // Dtype/flag defaults follow the TRTLLM non-routed MoE routing path.
  data.mDtypeExpW = batchedGemm::trtllm::gen::Dtype::Bfloat16;
  data.mDtypeElt = batchedGemm::trtllm::gen::Dtype::Bfloat16;
  data.mUsePdl = use_pdl;
  data.mDoSoftmaxBeforeTopK = false;
  data.mNormTopkProb = false;
  data.mApplySoftmaxAfterTopK = renormalize;

  // Kernel inputs/outputs, passed in as raw device addresses.
  data.mPtrScores = reinterpret_cast<void const*>(gating_output_ptr);
  data.mPtrTopKWeights = reinterpret_cast<void*>(topk_weights_ptr);
  data.mPtrTopKIds = nullptr;  // nullptr forces score-driven routing selection
  data.mPtrTopKPacked = reinterpret_cast<void*>(topk_packed_ptr);

  // Routing metadata buffers (int32, caller-allocated).
  data.mPtrExpertCounts = reinterpret_cast<int32_t*>(expert_counts_ptr);
  data.mPtrPermutedIdxSize = reinterpret_cast<int32_t*>(permuted_idx_size_ptr);
  data.mPtrExpandedIdxToPermutedIdx =
      reinterpret_cast<int32_t*>(expanded_idx_to_permuted_idx_ptr);
  data.mPtrPermutedIdxToTokenIdx = reinterpret_cast<int32_t*>(permuted_idx_to_token_idx_ptr);
  data.mPtrCtaIdxXyToBatchIdx = reinterpret_cast<int32_t*>(cta_idx_to_batch_idx_ptr);
  data.mPtrCtaIdxXyToMnLimit = reinterpret_cast<int32_t*>(cta_idx_to_mn_limit_ptr);
  data.mPtrNumNonExitingCtas = reinterpret_cast<int32_t*>(num_non_exiting_ctas_ptr);

  // Problem sizes; every expert is treated as local (start 0, stride 1).
  data.mNumTokens = num_tokens;
  data.mNumExperts = num_experts;
  data.mTopK = top_k;
  data.mPaddingLog2 = computeLog2(tile_tokens_dim);
  data.mTileTokensDim = tile_tokens_dim;
  data.mLocalExpertsStartIdx = 0;
  data.mLocalExpertsStrideLog2 = 0;
  data.mNumLocalExperts = num_experts;

  cudaStream_t stream =
      cuda_stream_ptr != 0 ? reinterpret_cast<cudaStream_t>(cuda_stream_ptr) : get_current_stream();
  moe::dev::routing::routingRenormalize::run(data, stream);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_fused_topk_raw_logits_trtllm_renormalize,
                              fused_topk_raw_logits_trtllm_renormalize);
#endif
4 changes: 4 additions & 0 deletions flashinfer/fused_moe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
from .fused_routing_dsv3 import ( # noqa: F401
fused_topk_deepseek as fused_topk_deepseek,
)
from .raw_logits_topk import (
fused_topk_raw_logits as fused_topk_raw_logits,
)

# CuteDSL MoE APIs (conditionally imported if cute_dsl available)
try:
Expand Down Expand Up @@ -74,6 +77,7 @@
"trtllm_fp8_per_tensor_scale_moe",
"trtllm_mxint4_block_scale_moe",
"fused_topk_deepseek",
"fused_topk_raw_logits",
]

# Add CuteDSL exports if available
Expand Down
289 changes: 289 additions & 0 deletions flashinfer/fused_moe/raw_logits_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import annotations

import functools
import threading

import torch

from flashinfer.api_logging import flashinfer_api
from flashinfer.jit.moe_utils import gen_moe_utils_module
from flashinfer.utils import device_support_pdl

# Routing tile size expected by the TRTLLM routingRenormalize kernels; this
# module's API is fixed to 128 (see fused_topk_raw_logits docstring).
_ROUTING_TILE_TOKENS_DIM = 128
# Per-device cache of scratch buffers, keyed by (device type, device index).
# NOTE(review): shared across all threads/streams on the same device with no
# synchronization of buffer use — confirm callers serialize access.
_workspace_pool: dict[tuple[str, int | None], "_RawLogitsTopkWorkspace"] = {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The _workspace_pool global dictionary and the _get_workspace function implement a shared workspace pool that is not thread-safe or stream-safe. The _RawLogitsTopkWorkspace instance is shared across all threads and CUDA streams using the same GPU device. This workspace contains pre-allocated tensors (e.g., topk_weights_bf16, topk_packed, expert_counts) used as intermediate buffers and outputs by the CUDA kernels. Concurrent calls to fused_topk_raw_logits from different threads or streams on the same device will result in multiple kernels writing to the same memory locations simultaneously. This leads to data corruption and potential information leakage between different requests or users in a multi-tenant environment (e.g., an LLM serving platform). To remediate this, consider making the workspace pool key include the CUDA stream ID, or using thread-local storage, or implementing a locking mechanism to ensure exclusive access to the workspace during kernel execution.



@functools.lru_cache(maxsize=1)
def _get_moe_utils_module():
    """Build (at most once, via lru_cache) and return the JIT moe_utils module."""
    return gen_moe_utils_module().build_and_load()


def _get_cuda_stream_ptr(device: torch.device) -> int:
    """Return the raw CUDA stream handle of the current stream on ``device``."""
    stream = torch.cuda.current_stream(device=device)
    return stream.cuda_stream


class _RawLogitsTopkWorkspace:
def __init__(self, device: torch.device) -> None:
self.device = device
self.topk_weights_bf16 = torch.empty((0,), dtype=torch.bfloat16, device=device)
self.topk_packed = torch.empty((0,), dtype=torch.int32, device=device)
self.expert_counts = torch.empty((0,), dtype=torch.int32, device=device)
self.permuted_idx_size = torch.empty((0,), dtype=torch.int32, device=device)
self.expanded_idx_to_permuted_idx = torch.empty(
(0,), dtype=torch.int32, device=device
)
self.permuted_idx_to_token_idx = torch.empty((0,), dtype=torch.int32, device=device)
self.cta_idx_to_batch_idx = torch.empty((0,), dtype=torch.int32, device=device)
self.cta_idx_to_mn_limit = torch.empty((0,), dtype=torch.int32, device=device)
self.num_non_exiting_ctas = torch.empty((0,), dtype=torch.int32, device=device)

@staticmethod
def _ensure_capacity(tensor: torch.Tensor, numel: int) -> torch.Tensor:
if tensor.numel() >= numel:
return tensor
return torch.empty((numel,), dtype=tensor.dtype, device=tensor.device)

def get_views(
self,
num_tokens: int,
topk: int,
num_experts: int,
max_num_tiles: int,
max_num_permuted_tokens: int,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
expanded_size = num_tokens * topk
expert_counts_size = max(512, 2 * num_experts)

self.topk_weights_bf16 = self._ensure_capacity(self.topk_weights_bf16, expanded_size)
self.topk_packed = self._ensure_capacity(self.topk_packed, expanded_size)
self.expert_counts = self._ensure_capacity(self.expert_counts, expert_counts_size)
self.permuted_idx_size = self._ensure_capacity(self.permuted_idx_size, 1)
self.expanded_idx_to_permuted_idx = self._ensure_capacity(
self.expanded_idx_to_permuted_idx, expanded_size
)
self.permuted_idx_to_token_idx = self._ensure_capacity(
self.permuted_idx_to_token_idx, max_num_permuted_tokens
)
self.cta_idx_to_batch_idx = self._ensure_capacity(
self.cta_idx_to_batch_idx, max_num_tiles
)
self.cta_idx_to_mn_limit = self._ensure_capacity(self.cta_idx_to_mn_limit, max_num_tiles)
self.num_non_exiting_ctas = self._ensure_capacity(self.num_non_exiting_ctas, 1)

return (
self.topk_weights_bf16[:expanded_size].view(num_tokens, topk),
self.topk_packed[:expanded_size].view(num_tokens, topk),
self.expert_counts[:expert_counts_size],
self.permuted_idx_size[:1],
self.expanded_idx_to_permuted_idx[:expanded_size],
self.permuted_idx_to_token_idx[:max_num_permuted_tokens],
self.cta_idx_to_batch_idx[:max_num_tiles],
self.cta_idx_to_mn_limit[:max_num_tiles],
self.num_non_exiting_ctas[:1],
)


# Guards first-time insertion into _workspace_pool. Without it, concurrent
# first calls for the same device race on the check-then-insert below and can
# create duplicate workspaces (flagged in review).
_workspace_pool_lock = threading.Lock()


def _get_workspace(device: torch.device) -> _RawLogitsTopkWorkspace:
    """Return the cached per-device workspace, creating it on first use.

    Creation is double-checked under ``_workspace_pool_lock`` so concurrent
    first calls from multiple threads cannot insert two workspaces for one
    device; the fast path after creation is lock-free.

    NOTE(review): the returned workspace is still a single shared object per
    device — concurrent kernel launches from different threads/streams would
    write the same buffers. Callers must serialize its use.
    """
    key = (device.type, device.index)
    ws = _workspace_pool.get(key)
    if ws is not None:
        return ws
    with _workspace_pool_lock:
        # Re-check: another thread may have created it while we waited.
        ws = _workspace_pool.get(key)
        if ws is None:
            ws = _RawLogitsTopkWorkspace(device)
            _workspace_pool[key] = ws
    return ws
Comment on lines +112 to +118
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of _get_workspace is not thread-safe because it accesses the global _workspace_pool without synchronization. This can lead to a race condition where multiple threads create a workspace for the same device simultaneously, causing unpredictable behavior. To fix this, a lock should be used to protect the creation and insertion of new workspaces into the pool.

Suggested change
def _get_workspace(device: torch.device) -> _RawLogitsTopkWorkspace:
key = (device.type, device.index)
ws = _workspace_pool.get(key)
if ws is None:
ws = _RawLogitsTopkWorkspace(device)
_workspace_pool[key] = ws
return ws
def _get_workspace(device: torch.device) -> _RawLogitsTopkWorkspace:
if not hasattr(_get_workspace, "lock"):
import threading
_get_workspace.lock = threading.Lock()
key = (device.type, device.index)
ws = _workspace_pool.get(key)
if ws is not None:
return ws
with _get_workspace.lock:
ws = _workspace_pool.get(key)
if ws is None:
ws = _RawLogitsTopkWorkspace(device)
_workspace_pool[key] = ws
return ws



def _get_max_num_tiles(
num_tokens: int,
top_k: int,
num_local_experts: int,
tile_size: int,
) -> int:
# Mirrors TRTLLM GroupedGemmInputsHelper.get_max_num_tiles.
num_expanded_tokens = num_tokens * top_k
if num_expanded_tokens <= num_local_experts:
return num_expanded_tokens
num_remaining_tokens = num_expanded_tokens - num_local_experts
return num_local_experts + (num_remaining_tokens + tile_size - 1) // tile_size


def _validate_args(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
if topk_weights.ndim != 2 or topk_ids.ndim != 2 or gating_output.ndim != 2:
raise ValueError(
"Expected 2D tensors for topk_weights, topk_ids, and gating_output."
)
if topk_weights.shape != topk_ids.shape:
raise ValueError(
f"topk_weights/topk_ids shape mismatch: {topk_weights.shape} vs {topk_ids.shape}"
)
if topk_weights.shape[0] != gating_output.shape[0]:
raise ValueError(
"Batch size mismatch: "
f"{topk_weights.shape[0]} (output) vs {gating_output.shape[0]} (gating_output)"
)
if topk_weights.dtype != torch.float32:
raise ValueError(
f"Expected topk_weights dtype float32, got {topk_weights.dtype}"
)
if not topk_weights.is_contiguous():
raise ValueError("Expected topk_weights to be contiguous.")
if topk_ids.dtype not in (torch.int32, torch.int64):
raise ValueError(
f"Expected topk_ids dtype int32 or int64, got {topk_ids.dtype}"
)
if not topk_ids.is_contiguous():
raise ValueError("Expected topk_ids to be contiguous.")
if gating_output.dtype != torch.bfloat16:
raise ValueError(
"TRTLLM routingRenormalize path expects bf16 gating_output, got "
f"{gating_output.dtype}"
)
if (
topk_weights.device != gating_output.device
or topk_ids.device != gating_output.device
):
raise ValueError(
"topk_weights, topk_ids, and gating_output must be on the same device."
)
if gating_output.device.type != "cuda":
raise ValueError("TRTLLM routingRenormalize path only supports CUDA tensors.")
if not isinstance(renormalize, bool):
raise ValueError(f"renormalize must be bool, got {type(renormalize)}")

topk = topk_weights.shape[1]
num_experts = gating_output.shape[1]
if topk < 1:
raise ValueError(f"Invalid top-k: {topk}")
if topk > num_experts:
raise ValueError(f"Invalid top-k {topk} for num_experts={num_experts}")
if topk > 10:
raise ValueError(f"TRTLLM routingRenormalize supports top-k <= 10, got {topk}")
if num_experts > 512:
raise ValueError(
f"TRTLLM routingRenormalize supports num_experts <= 512, got {num_experts}"
)
if num_experts % 4 != 0:
raise ValueError(
f"TRTLLM routingRenormalize expects num_experts % 4 == 0, got {num_experts}"
)


@flashinfer_api
def fused_topk_raw_logits(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = True,
) -> None:
    """TRTLLM routingRenormalize raw-logits top-k.

    Computes per-token top-k routing weights and expert ids from raw routing
    logits and writes them in place into ``topk_weights`` / ``topk_ids``.

    This API intentionally uses only TRTLLM's routingRenormalize path. There is no
    fallback path in this implementation.

    Supported configuration:
    - ``gating_output`` dtype ``torch.bfloat16``
    - routing score mode fixed to raw-logits + optional post-topk softmax
    - routing tile size fixed to 128

    Args:
        topk_weights: Output, ``(num_tokens, top_k)`` float32 contiguous CUDA
            tensor; overwritten with the selected routing weights.
        topk_ids: Output, ``(num_tokens, top_k)`` int32/int64 contiguous CUDA
            tensor; overwritten with the selected expert ids.
        gating_output: Input logits, ``(num_tokens, num_experts)`` bfloat16
            CUDA tensor.
        renormalize: When True, apply softmax over the selected top-k scores
            (kernel-side ``mApplySoftmaxAfterTopK``).

    Raises:
        ValueError: On any shape/dtype/device/limit violation (see
            ``_validate_args``).
    """

    _validate_args(
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        gating_output=gating_output,
        renormalize=renormalize,
    )

    # Derive problem sizes and the tile-count upper bound used to size the
    # routing metadata buffers.
    num_tokens, num_experts = gating_output.shape
    topk = topk_weights.shape[1]
    max_num_tiles = _get_max_num_tiles(
        num_tokens=num_tokens,
        top_k=topk,
        num_local_experts=num_experts,
        tile_size=_ROUTING_TILE_TOKENS_DIM,
    )
    max_num_permuted_tokens = max_num_tiles * _ROUTING_TILE_TOKENS_DIM
    use_pdl = device_support_pdl(gating_output.device)

    # Grab sized views over the shared per-device scratch buffers.
    (
        topk_weights_bf16,
        topk_packed,
        expert_counts,
        permuted_idx_size,
        expanded_idx_to_permuted_idx,
        permuted_idx_to_token_idx,
        cta_idx_to_batch_idx,
        cta_idx_to_mn_limit,
        num_non_exiting_ctas,
    ) = _get_workspace(gating_output.device).get_views(
        num_tokens=num_tokens,
        topk=topk,
        num_experts=num_experts,
        max_num_tiles=max_num_tiles,
        max_num_permuted_tokens=max_num_permuted_tokens,
    )

    # The kernel reads raw device pointers, so the logits must be contiguous.
    if not gating_output.is_contiguous():
        gating_output = gating_output.contiguous()
    # Initialize tail entries so searchsorted sees a monotonic full-length array.
    cta_idx_to_batch_idx.zero_()
    cta_idx_to_mn_limit.fill_(torch.iinfo(torch.int32).max)
    # Launch the JIT-compiled TVM-FFI binding; all tensor arguments are passed
    # as raw device addresses, plus the current CUDA stream for graph capture.
    _get_moe_utils_module()["flashinfer_fused_topk_raw_logits_trtllm_renormalize"](
        topk_weights_bf16.data_ptr(),
        topk_packed.data_ptr(),
        gating_output.data_ptr(),
        num_tokens,
        num_experts,
        topk,
        renormalize,
        use_pdl,
        _ROUTING_TILE_TOKENS_DIM,
        expert_counts.data_ptr(),
        permuted_idx_size.data_ptr(),
        expanded_idx_to_permuted_idx.data_ptr(),
        permuted_idx_to_token_idx.data_ptr(),
        cta_idx_to_batch_idx.data_ptr(),
        cta_idx_to_mn_limit.data_ptr(),
        num_non_exiting_ctas.data_ptr(),
        _get_cuda_stream_ptr(gating_output.device),
    )

    # Recover expert ids entirely on-device from routing metadata. We intentionally
    # avoid host syncs (e.g. tensor.item()) to keep this CUDA-graph safe.
    # Each expanded (token, k) slot maps to a permuted index; locating that
    # index within the per-CTA limits yields the CTA, whose batch-idx entry
    # gives the expert id.
    # NOTE(review): assumes the kernel's CtaIdxXyToBatchIdx entries are expert
    # indices on this non-routed path — confirm against the routing kernel.
    expanded = expanded_idx_to_permuted_idx.view(num_tokens, topk)
    cta_idx = torch.searchsorted(cta_idx_to_mn_limit, expanded, right=True)
    topk_ids_i32 = cta_idx_to_batch_idx[cta_idx].to(torch.int32)
    # Kernel writes bf16 weights; widen to the caller's fp32 output buffer.
    topk_weights.copy_(topk_weights_bf16.float())
    if topk_ids.dtype == torch.int32:
        topk_ids.copy_(topk_ids_i32)
    else:
        topk_ids.copy_(topk_ids_i32.to(torch.int64))
2 changes: 2 additions & 0 deletions flashinfer/jit/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def gen_moe_utils_module() -> JitSpec:
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/memoryUtils.cu",
# Routing kernels for moe_sort
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_deepseek.cu",
# Routing kernels for raw-logits top-k (TRTLLM renormalize path)
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_renormalize.cu",
],
extra_cuda_cflags=nvcc_flags,
extra_include_paths=[
Expand Down
Loading