Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion csrc/moe_utils_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,9 @@ void moe_sort(
routingData.mPtrCtaIdxXyToMnLimit = reinterpret_cast<int32_t*>(tile_idx_to_mn_limit_ptr);
routingData.mPtrExpandedIdxToPermutedIdx =
reinterpret_cast<int32_t*>(expanded_idx_to_permuted_idx_ptr);
routingData.mPtrPermutedIdxToTokenIdx =
routingData.mPtrPermutedIdxToExpandedIdx =
reinterpret_cast<int32_t*>(permuted_idx_to_expanded_idx_ptr);
routingData.mPtrPermutedIdxToTokenIdx = nullptr;
routingData.mPtrPermutedIdxSize = reinterpret_cast<int32_t*>(total_num_padded_tokens_ptr);
routingData.mPtrNumNonExitingCtas = reinterpret_cast<int32_t*>(num_non_exiting_tiles_ptr);

Expand Down
6 changes: 5 additions & 1 deletion csrc/trtllm_fused_moe_routing_deepseek.cu
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,9 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
if (params.mPtrExpandedIdxToPermutedIdx != nullptr) {
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
}
if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) {
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx;
}
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) {
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
}
Expand Down Expand Up @@ -549,7 +552,8 @@ void runImpl(Data& data, void* stream) {
"When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for "
"DeepSeek routing.");
}
if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr)
if (data.mPtrExpandedIdxToPermutedIdx != nullptr ||
data.mPtrPermutedIdxToExpandedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr)
FLASHINFER_CHECK(
(data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize,
"If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required");
Expand Down
4 changes: 4 additions & 0 deletions csrc/trtllm_fused_moe_routing_llama4.cu
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
if (params.mPtrExpandedIdxToPermutedIdx != nullptr && isTokenRouted) {
params.mPtrExpandedIdxToPermutedIdx[tokenIdx] = permutedIdx;
}
// write out `mPtrPermutedIdxToExpandedIdx` if required
if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert && isTokenRouted) {
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = tokenIdx;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For clarity and consistency with other routing kernels, it would be better to make it explicit that tokenIdx is being used as expandedIdx. While expandedIdx is equivalent to tokenIdx in this kernel (since topK=1), this is an important implementation detail. Adding an inline comment would help future maintainers understand the code's intent more easily, especially when comparing with other routing kernels that use an expandedIdx variable.

        params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = tokenIdx; // For llama4 routing, topK=1, so expandedIdx is equivalent to tokenIdx.

}
// write out `mPtrPermutedIdxToTokenIdx` if required
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert && isTokenRouted) {
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
Expand Down
3 changes: 3 additions & 0 deletions csrc/trtllm_fused_moe_routing_renormalize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)

params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
if (isLocalExpert) {
if (params.mPtrPermutedIdxToExpandedIdx != nullptr) {
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx;
}
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
}
}
Expand Down
3 changes: 3 additions & 0 deletions csrc/trtllm_fused_moe_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mPtrExpertCounts = expertCountHistogram;
routingData.mPtrPermutedIdxSize = permutedIdxSize;
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx;
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
routingData.mPtrTopKWeights = expertWeights;

Expand Down Expand Up @@ -113,6 +114,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mPtrExpertCounts = expertCountHistogram;
routingData.mPtrPermutedIdxSize = permutedIdxSize;
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx;
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
routingData.mPtrTopKWeights = expertWeights;

Expand Down Expand Up @@ -156,6 +158,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mPtrExpertCounts = expertCountHistogram;
routingData.mPtrPermutedIdxSize = permutedIdxSize;
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx;
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
routingData.mPtrTopKWeights = expertWeights;

Expand Down
6 changes: 6 additions & 0 deletions include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,9 @@ __device__ void routingPermutation(KernelParams params,
if (params.mPtrExpandedIdxToPermutedIdx != nullptr) {
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
}
if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) {
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx;
}
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) {
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
}
Expand Down Expand Up @@ -729,6 +732,9 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
if (params.mPtrExpandedIdxToPermutedIdx != nullptr) {
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
}
if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) {
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx;
}
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) {
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
}
Expand Down
5 changes: 5 additions & 0 deletions include/flashinfer/trtllm/fused_moe/RoutingKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ struct DataBase {
int32_t* mPtrExpandedIdxToPermutedIdx{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mTileTokensDim * mTopK + (mNumExperts Γ— mTileTokensDim) - mNumExperts]
int32_t* mPtrPermutedIdxToExpandedIdx{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mTileTokensDim * mTopK + (mNumExperts Γ— mTileTokensDim) - mNumExperts]
// Note: this array (mPtrPermutedIdxToTokenIdx) is uninitialized
// Any out-of-bounds values are undefined.
int32_t* mPtrPermutedIdxToTokenIdx{nullptr};
Expand Down Expand Up @@ -113,6 +116,7 @@ struct KernelParamsBase {
int32_t* mPtrExpertCounts = nullptr;
int32_t* mPtrPermutedIdxSize = nullptr;
int32_t* mPtrExpandedIdxToPermutedIdx = nullptr;
int32_t* mPtrPermutedIdxToExpandedIdx = nullptr;
int32_t* mPtrPermutedIdxToTokenIdx = nullptr;
int32_t* mPtrCtaIdxXyToBatchIdx = nullptr;
int32_t* mPtrCtaIdxXyToMnLimit = nullptr;
Expand All @@ -137,6 +141,7 @@ struct KernelParamsBase {
mPtrExpertCounts = data.mPtrExpertCounts;
mPtrPermutedIdxSize = data.mPtrPermutedIdxSize;
mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx;
mPtrPermutedIdxToExpandedIdx = data.mPtrPermutedIdxToExpandedIdx;
mPtrPermutedIdxToTokenIdx = data.mPtrPermutedIdxToTokenIdx;
mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx;
mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit;
Expand Down
12 changes: 8 additions & 4 deletions tests/moe/test_cute_dsl_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,19 @@ def create_moe_tensors(


def check_accuracy(
actual: torch.Tensor, expected: torch.Tensor, percent_threshold: float = 0.925
actual: torch.Tensor, expected: torch.Tensor, percent_threshold: float = 0.97
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any specific reason we make this change?

):
"""Check numerical accuracy with percentage-based tolerance."""
"""Check numerical accuracy with percentage-based tolerance.

Tolerances are scaled by output magnitude to account for FP4 quantization
noise growing with larger hidden dimensions.
"""
actual = actual.float()
expected = expected.float()

output_scale = max(expected.std().item(), 0.01)
atol = max(0.1, 3.0 * output_scale)
rtol = 0.85
atol = max(0.05, 1.5 * output_scale)
rtol = 0.5

abs_diff = torch.abs(actual - expected)
rel_diff = abs_diff / (torch.abs(expected) + 1e-8)
Expand Down
Loading