From 10d80fca0a5309dc0a0c12d4de0b7e37ba318533 Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Thu, 9 Apr 2026 02:31:05 -0700 Subject: [PATCH 01/20] feat: add routing_replay_out to MoE kernel launchers and routing kernels --- .../trtllm_fused_moe_routing_custom.cu | 6 + .../trtllm_fused_moe_routing_deepseek.cu | 6 + .../trtllm_fused_moe_routing_llama4.cu | 5 + csrc/trtllm_fused_moe_kernel_launcher.cu | 106 ++++++++++++++++-- csrc/trtllm_fused_moe_runner.cu | 8 +- .../trtllm/fused_moe/RoutingKernel.h | 7 ++ include/flashinfer/trtllm/fused_moe/runner.h | 3 +- 7 files changed, 131 insertions(+), 10 deletions(-) diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_custom.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_custom.cu index 2cc618ed9a..6ba5d42660 100644 --- a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_custom.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_custom.cu @@ -419,6 +419,12 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa PackedScoreIdx packedScore{static_cast(warpTopKScore[laneIdx]), static_cast(warpTopKExpertIdx[laneIdx])}; params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; + + // Routing replay: record all top-K selected expert IDs per token. + if (params.mPtrRoutingReplayOut != nullptr) { + params.mPtrRoutingReplayOut[tokenIdx * params.mTopK + laneIdx] = + static_cast(warpTopKExpertIdx[laneIdx]); + } } } diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu index a77c195cf9..e5e3d3e8f1 100644 --- a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu @@ -345,6 +345,12 @@ __global__ void routingMainKernel(KernelParams params) { params.mPtrTopKIds == nullptr) { params.mPtrTopKWeights[idxTopK] = finalScore; } + + // Routing replay: record all top-K selected expert IDs per token. + // Layout: [num_tokens, topK] -- same indexing as mPtrTopKPacked. + if (laneIdx < params.mTopK && params.mPtrRoutingReplayOut != nullptr) { + params.mPtrRoutingReplayOut[idxTopK] = static_cast(expertIdx); + } } } diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu index a7c0bed8bd..4fa0c1b5c9 100644 --- a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu @@ -464,6 +464,11 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; TypePacked packedScore{finalScore, static_cast(warpMaxExpertIdx[0])}; params.mPtrTopKPacked[tokenIdx] = packedScore; + + // Routing replay: record selected expert ID for this token. + if (params.mPtrRoutingReplayOut != nullptr) { + params.mPtrRoutingReplayOut[tokenIdx] = static_cast(warpMaxExpertIdx[0]); + } } } diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 7c5826802e..4046fb54eb 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -186,6 +186,9 @@ class FusedMoeLauncher { ActivationType activation_type{ActivationType::Swiglu}; btg::Dtype mDtypeScore{btg::Dtype::Bfloat16}; + // Optional routing replay output: [num_tokens, top_k] int16 tensor + Optional routing_replay_out; + int64_t intermediate_size_factor{2}; public: @@ -222,6 +225,10 @@ class FusedMoeLauncher { int64_t weight_layout, ActivationType activation_type, bool norm_topk_prob = true); + void set_routing_replay_out(const Optional& replay_out) { + routing_replay_out = replay_out; + } + // Routing logits [num_tokens, num_experts] void check_routing_logits() const { if (routing_logits.has_value()) { @@ -456,6 +463,11 @@ class FusedMoeLauncher { tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); cudaStream_t routing_stream = get_stream(hidden_states.device()); + int16_t* replay_ptr = nullptr; + if (routing_replay_out.has_value()) { + replay_ptr = reinterpret_cast(routing_replay_out.value().data_ptr()); + } + routing_runner.run( args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, @@ -471,7 +483,8 @@ class FusedMoeLauncher { static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream, mRoutingLogitsDtype, - norm_topk_prob); + norm_topk_prob, + replay_ptr); check_moe(); prepare_moe(moe_tactic); @@ -1190,6 +1203,11 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { bool use_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; // When using pre-computed routing, pass nullptr as routing_logits to tell the // routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes + int16_t* replay_ptr = nullptr; + if (routing_replay_out.has_value()) { + replay_ptr = reinterpret_cast(routing_replay_out.value().data_ptr()); + } + routing_runner.run( use_precomputed ? nullptr : args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, @@ -1205,7 +1223,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream, mRoutingLogitsDtype, - norm_topk_prob); + norm_topk_prob, + replay_ptr); check_moe(); prepare_moe(moe_tactic); @@ -1672,6 +1691,11 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); cudaStream_t routing_stream = get_stream(hidden_states.device()); + int16_t* replay_ptr = nullptr; + if (routing_replay_out.has_value()) { + replay_ptr = reinterpret_cast(routing_replay_out.value().data_ptr()); + } + routing_runner.run( args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, @@ -1687,7 +1711,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream, mRoutingLogitsDtype, - norm_topk_prob); + norm_topk_prob, + replay_ptr); check_moe(); prepare_moe(moe_tactic); @@ -1742,7 +1767,8 @@ Array trtllm_bf16_moe(Optional const& routing_logits, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, - bool enable_pdl, Array moe_tactic, bool norm_topk_prob) { + bool enable_pdl, Array moe_tactic, bool norm_topk_prob, + Optional routing_replay_out) { // Just some basic type validation first and leave more checks to the launcher if (routing_logits.has_value()) { TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || @@ -1756,6 +1782,17 @@ Array trtllm_bf16_moe(Optional const& routing_logits, TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_bfloat16) << "BF16 MoE: gemm2_weights must be bfloat16."; + if (routing_replay_out.has_value()) { + auto replay = routing_replay_out.value(); + TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) + << "routing_replay_out must be a CUDA tensor"; + TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) + << "routing_replay_out must be on the same device as hidden_states"; + TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; + TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) + } + auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); @@ -1793,6 +1830,7 @@ Array trtllm_bf16_moe(Optional const& routing_logits, gemm2_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, weight_layout, norm_topk_prob); + launcher->set_routing_replay_out(routing_replay_out); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1818,7 +1856,8 @@ Array trtllm_fp8_per_tensor_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, bool use_routing_scales_on_input, int64_t routing_method_type, bool do_finalize, - bool enable_pdl, Array config_index, int64_t activation_type, bool norm_topk_prob) { + bool enable_pdl, Array config_index, int64_t activation_type, bool norm_topk_prob, + Optional routing_replay_out) { // Basic type validation auto dtype = hidden_states.dtype(); auto activation = static_cast(activation_type); @@ -1841,6 +1880,17 @@ Array trtllm_fp8_per_tensor_scale_moe( TVM_FFI_ICHECK_EQ(output2_scales_scalar.dtype(), dl_float32) << "FP8 MoE: output2_scales_scalar must be float32."; + if (routing_replay_out.has_value()) { + auto replay = routing_replay_out.value(); + TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) + << "routing_replay_out must be a CUDA tensor"; + TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) + << "routing_replay_out must be on the same device as hidden_states"; + TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; + TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) + } + auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); @@ -1880,6 +1930,7 @@ Array trtllm_fp8_per_tensor_scale_moe( output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, weight_layout, use_routing_scales_on_input, activation, norm_topk_prob); + launcher->set_routing_replay_out(routing_replay_out); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1906,7 +1957,8 @@ Array trtllm_fp8_block_scale_moe( int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, bool enable_pdl, Array config_index, Fp8QuantizationType quantization_type, - int64_t act_type, bool norm_topk_prob) { + int64_t act_type, bool norm_topk_prob, + Optional routing_replay_out) { auto activation_type = validateAndCastActivationType(act_type); // DeepSeekFp8 currently uses a TRTLLM runner that hardwires Swiglu activation semantics. // Fail for any other activation to avoid silently running incorrect activation behavior. @@ -1971,6 +2023,17 @@ Array trtllm_fp8_block_scale_moe( TVM_FFI_ICHECK(weight_layout == 0) << "weight_layout must be 0 for MxFp8."; } + if (routing_replay_out.has_value()) { + auto replay = routing_replay_out.value(); + TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) + << "routing_replay_out must be a CUDA tensor"; + TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) + << "routing_replay_out must be on the same device as hidden_states"; + TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; + TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) + } + auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); @@ -2005,6 +2068,7 @@ Array trtllm_fp8_block_scale_moe( quantization_type); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, weight_layout, activation_type, norm_topk_prob); + launcher->set_routing_replay_out(routing_replay_out); launchers_map[curr_tile_N] = std::move(launcher); } @@ -2037,7 +2101,8 @@ Array trtllm_fp4_block_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t act_type, - TensorView output, Array config_index, bool norm_topk_prob) { + TensorView output, Array config_index, bool norm_topk_prob, + Optional routing_replay_out) { // Determine data types based on input format int const num_tokens = hidden_states.size(0); int hidden_size = hidden_states.size(1); @@ -2079,6 +2144,17 @@ Array trtllm_fp4_block_scale_moe( << "routing_bias has incorrect shape."; } + if (routing_replay_out.has_value()) { + auto replay = routing_replay_out.value(); + TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) + << "routing_replay_out must be a CUDA tensor"; + TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) + << "routing_replay_out must be on the same device as hidden_states"; + TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; + TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) + } + // Determine activation type TVM_FFI_ICHECK(gemm1_weights.dtype() == dl_uint8 && gemm2_weights.dtype() == dl_uint8) << "weights must be fp4 packed in uint8."; @@ -2145,6 +2221,7 @@ Array trtllm_fp4_block_scale_moe( launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, /*weight_layout=*/0, static_cast(act_type), mDtypeAct, mDtypeWeights, norm_topk_prob); + launcher->set_routing_replay_out(routing_replay_out); launchers_map[curr_tile_N] = std::move(launcher); } @@ -2170,7 +2247,8 @@ Array trtllm_mxint4_block_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool do_finalize, bool enable_pdl, TensorView output, - Array config_index, bool norm_topk_prob) { + Array config_index, bool norm_topk_prob, + Optional routing_replay_out) { // Determine data types based on input format int const num_tokens = hidden_states.size(0); int hidden_size = hidden_states.size(1); @@ -2192,6 +2270,17 @@ Array trtllm_mxint4_block_scale_moe( << "routing_bias has incorrect shape."; } + if (routing_replay_out.has_value()) { + auto replay = routing_replay_out.value(); + TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) + << "routing_replay_out must be a CUDA tensor"; + TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) + << "routing_replay_out must be on the same device as hidden_states"; + TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; + TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) + } + // Determine activation type TVM_FFI_ICHECK(gemm1_weights.dtype() == dl_uint8 && gemm2_weights.dtype() == dl_uint8) << "weights must be int4 packed in uint8."; @@ -2229,6 +2318,7 @@ Array trtllm_mxint4_block_scale_moe( routing_logits, routing_bias, hidden_states, gemm1_weights, gemm1_weights_scale, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale); launcher->init(std::move(args), curr_tile_N, routing_method_type, norm_topk_prob); + launcher->set_routing_replay_out(routing_replay_out); launchers_map[curr_tile_N] = std::move(launcher); } diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index 40150ad86d..12a7f23b7c 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -69,7 +69,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream, btg::Dtype dtypeLogits, - bool normTopkProb) { + bool normTopkProb, + int16_t* routing_replay_out) { if (routingMethodType == RoutingMethodType::DeepSeekV3 && nGroup <= 1) { // DeepSeek no-groups case: use routingCustom with SigmoidBias preprocess // and ScaledSumNormalize postprocess. This is more efficient than the full DeepSeek @@ -108,6 +109,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; + routingData.mPtrRoutingReplayOut = routing_replay_out; moe::dev::routing::routingCustom::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::MiniMax2) { @@ -149,6 +151,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; + routingData.mPtrRoutingReplayOut = routing_replay_out; moe::dev::routing::routingCustom::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::DeepSeekV3) { @@ -191,6 +194,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumLocalExperts = localNumExperts; routingData.mRouteScale = routedScalingFactor; routingData.mUseRoutingSoftmax = false; + routingData.mPtrRoutingReplayOut = routing_replay_out; moe::dev::routing::routingDeepSeek::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::Llama4) { FLASHINFER_CHECK(topK == 1, "For Llama routing method, must have topK == 1"); @@ -228,6 +232,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; + routingData.mPtrRoutingReplayOut = routing_replay_out; moe::dev::routing::routingLlama4::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::Default /* Softmax -> TopK */ || routingMethodType == RoutingMethodType::Renormalize /* TopK -> Softmax */ @@ -305,6 +310,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; + routingData.mPtrRoutingReplayOut = routing_replay_out; routingCustom::run(routingData, stream); } else { diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index c26c515781..90154c106f 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -84,6 +84,11 @@ struct DataBase { // Together with mPtrTopKWeights, they form the top-k experts for each token int32_t* mPtrTopKIds{nullptr}; + // optional: if nullptr, no routing replay recording occurs + // dim: [mNumTokens, mTopK] + // Records the selected expert IDs per token for replay + int16_t* mPtrRoutingReplayOut{nullptr}; + // optional: if `nullptr`, scores are used directly as input. // If it is given, it must represent a packed value s.t. the most significant // 16/32 bits represent the score without sigmoid activation and @@ -148,6 +153,7 @@ struct KernelParamsBase { OutputT* mPtrTopKWeights = nullptr; int32_t* mPtrTopKIds = nullptr; InputT const* mPtrScores = nullptr; + int16_t* mPtrRoutingReplayOut = nullptr; // Public scalar members int32_t mNumTokens = 0; @@ -177,6 +183,7 @@ struct KernelParamsBase { mPtrTopKWeights = static_cast(data.mPtrTopKWeights); mPtrTopKIds = static_cast(data.mPtrTopKIds); mPtrScores = (InputT const*)data.mPtrScores; + mPtrRoutingReplayOut = data.mPtrRoutingReplayOut; mNumTokens = data.mNumTokens; mNumExperts = data.mNumExperts; diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index df46aeed0b..2d2361ddd9 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -137,7 +137,8 @@ class Runner { batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream, batchedGemm::trtllm::gen::Dtype dtypeLogits, - bool normTopkProb = true); + bool normTopkProb = true, + int16_t* routing_replay_out = nullptr); private: int32_t mTileTokensDim{8}; From b39aa06b24d7b8499c43d81f7785a36b913b1152 Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Thu, 9 Apr 2026 02:46:13 -0700 Subject: [PATCH 02/20] feat: add routing_replay_out to noAuxTc DSV3 routing kernel Add optional int16_t* routing_replay_out parameter to the standalone DSV3 fused routing kernel (noAuxTcKernels). When provided, writes selected expert IDs per token during routing. Includes input validation in the entry point. --- csrc/fused_moe/noAuxTcKernels.cu | 52 ++++++++++++++----- .../trtllm/fused_moe/noAuxTcKernels.h | 3 +- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/csrc/fused_moe/noAuxTcKernels.cu b/csrc/fused_moe/noAuxTcKernels.cu index 1f57d9b57b..dd30c05075 100644 --- a/csrc/fused_moe/noAuxTcKernels.cu +++ b/csrc/fused_moe/noAuxTcKernels.cu @@ -30,7 +30,8 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx int64_t const numGroup, int64_t const topkGroup, int64_t const topk, int64_t const numExperts, int64_t const numExpertsPerGroup, - double const routedScalingFactor) { + double const routedScalingFactor, + int16_t* routingReplayOut) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif @@ -212,6 +213,10 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx if (laneIdx < topk) { topkValues[laneIdx] = static_cast(finalScore); topkIndices[laneIdx] = expertIdx; + // Routing replay: record selected expert IDs per token + if (routingReplayOut != nullptr) { + routingReplayOut[blockIdx.x * topk + laneIdx] = static_cast(expertIdx); + } } } @@ -224,7 +229,8 @@ template void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices, int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, - bool const launch_with_pdl, cudaStream_t const stream) { + bool const launch_with_pdl, cudaStream_t const stream, + int16_t* routing_replay_out) { // Check if we can use the optimized deepseek_v3_topk_kernel bool const is_single_group = (n_group == 1) && (num_experts <= NumKimiK2Experts); @@ -262,7 +268,7 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, topk_indices, bias, num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group, - routed_scaling_factor); + routed_scaling_factor, routing_replay_out); sync_check_cuda_error(stream); } else { // TODO: call the generic path (previous implementation) or signal unsupported config. @@ -279,7 +285,8 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk InputT * scores, BiasT * bias, OutputT * topk_values, IdxT * topk_indices, \ int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, \ int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, \ - bool const launch_with_pdl, cudaStream_t const stream); + bool const launch_with_pdl, cudaStream_t const stream, \ + int16_t* routing_replay_out); INSTANTIATE_NOAUX_TC(float, float, float, int32_t); INSTANTIATE_NOAUX_TC(float, half, float, int32_t); @@ -305,7 +312,7 @@ namespace flashinfer::trtllm_dsv3_fused_routing { void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_group, int64_t topk, double routed_scaling_factor, TensorView topk_values, TensorView topk_indices, - bool launch_with_pdl) { + bool launch_with_pdl, Optional routing_replay_out) { auto data_type = scores.dtype(); auto bias_type = bias.dtype(); @@ -342,6 +349,23 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code) << "topk_indices must have the same dtype as scores"; + // Validate and extract routing_replay_out + constexpr int64_t int16_code_val = encode_dlpack_dtype(DLDataType{kDLInt, 16, 1}); + int16_t* replay_ptr = nullptr; + if (routing_replay_out.has_value()) { + auto replay = routing_replay_out.value(); + TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) + << "routing_replay_out must be a CUDA tensor"; + TVM_FFI_ICHECK(replay.device().device_id == scores.device().device_id) + << "routing_replay_out must be on the same device as scores"; + TVM_FFI_ICHECK(replay.ndim() == 2) + << "routing_replay_out must be a 2D Tensor [num_tokens, topk]"; + TVM_FFI_ICHECK(replay.sizes()[1] == topk) << "routing_replay_out dim1 must equal topk"; + TVM_FFI_ICHECK(encode_dlpack_dtype(replay.dtype()) == int16_code_val) + << "routing_replay_out must be int16 dtype"; + replay_ptr = reinterpret_cast(replay.data_ptr()); + } + auto stream = get_stream(scores.device()); using namespace tensorrt_llm::kernels; switch (encode_dlpack_dtype(data_type)) { @@ -353,14 +377,14 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast(scores.data_ptr()), reinterpret_cast(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case float32_code: invokeNoAuxTc( reinterpret_cast(scores.data_ptr()), reinterpret_cast(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case bfloat16_code: invokeNoAuxTc( @@ -368,7 +392,7 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; default: throw std::invalid_argument( @@ -384,14 +408,14 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case float16_code: invokeNoAuxTc( reinterpret_cast(scores.data_ptr()), reinterpret_cast(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case bfloat16_code: invokeNoAuxTc( @@ -399,7 +423,7 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; default: throw std::invalid_argument( @@ -416,7 +440,7 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()), reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case float16_code: invokeNoAuxTc<__nv_bfloat16, half, __nv_bfloat16, int32_t>( @@ -424,7 +448,7 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast(bias.data_ptr()), reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case float32_code: invokeNoAuxTc<__nv_bfloat16, float, __nv_bfloat16, int32_t>( @@ -432,7 +456,7 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast(bias.data_ptr()), reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; default: throw std::invalid_argument( diff --git a/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h b/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h index 5af8fe39db..fa7f7f8339 100644 --- a/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h +++ b/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h @@ -28,6 +28,7 @@ template void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices, int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, - cudaStream_t const stream = 0); + bool const launch_with_pdl, cudaStream_t const stream, + int16_t* routing_replay_out = nullptr); } // namespace tensorrt_llm::kernels From fd6ca4b048f6590fc2ad94650d6142db5354629e Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Thu, 9 Apr 2026 02:58:04 -0700 Subject: [PATCH 03/20] feat: add routing_replay_out to Python MoE API Add optional routing_replay_out parameter to trtllm_fp8_block_scale_moe(), trtllm_bf16_moe(), trtllm_bf16_routed_moe(), and fused_topk_deepseek(). Relaxed dim0 validation (>= instead of ==) for CUDA graph compatibility. Thread through autotuner and launcher classes via kwargs. --- flashinfer/fused_moe/core.py | 22 +++++++++++++++++-- flashinfer/fused_moe/fused_routing_dsv3.py | 25 +++++++++++++++++++++- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index f7a1522333..b9f5e7ae85 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1108,6 +1108,7 @@ def forward( kwargs["enable_pdl"], [-1, -1] if tactic == -1 else tactic, kwargs.get("norm_topk_prob", True), + kwargs.get("routing_replay_out"), ) elif ( self.dtype_act == DtypeTrtllmGen.E4m3 @@ -1168,6 +1169,7 @@ def forward( self.fp8_quantization_type, self.activation_type, kwargs.get("norm_topk_prob", True), + kwargs.get("routing_replay_out"), ) else: # FP8 per tensor scale @@ -1266,7 +1268,7 @@ def forward( @register_custom_op( "flashinfer::trtllm_bf16_moe", - mutates_args=(""), + mutates_args=("routing_replay_out",), ) def trtllm_bf16_moe_op( routing_logits: Optional[torch.Tensor], @@ -1291,6 +1293,7 @@ def trtllm_bf16_moe_op( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: assert routing_logits is not None or topk_ids is not None, ( "either routing_logits or topk_ids must be provided" @@ -1402,6 +1405,7 @@ def trtllm_bf16_moe_op( enable_pdl, [-1, -1] if tactic == -1 else tactic, norm_topk_prob, + routing_replay_out, ) if do_finalize: return [output] @@ -1436,7 +1440,9 @@ def _fake_trtllm_bf16_moe( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: + _ = routing_replay_out seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -1613,7 +1619,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( @register_custom_op( "flashinfer::trtllm_fp8_block_scale_moe", - mutates_args=(""), + mutates_args=("routing_replay_out",), ) def trtllm_fp8_block_scale_moe_op( routing_logits: Optional[torch.Tensor], @@ -1644,6 +1650,7 @@ def trtllm_fp8_block_scale_moe_op( fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: # Determine routing mode: compute from logits or use pre-computed if routing_logits is None: @@ -1788,6 +1795,7 @@ def trtllm_fp8_block_scale_moe_op( fp8_quantization_type, activation_type, norm_topk_prob, + routing_replay_out, ) if do_finalize: @@ -1833,7 +1841,9 @@ def _fake_trtllm_fp8_block_scale_moe( fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: + _ = routing_replay_out seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -2016,6 +2026,7 @@ def trtllm_fp4_block_scale_moe_op( output, [-1, -1] if tactic == -1 else tactic, norm_topk_prob, + routing_replay_out, ) if do_finalize: return [output] @@ -2200,6 +2211,7 @@ def trtllm_mxint4_block_scale_moe_op( output, [-1, -1] if tactic == -1 else tactic, norm_topk_prob, + routing_replay_out, ) if do_finalize: return [output] @@ -2273,6 +2285,7 @@ def trtllm_bf16_moe( enable_pdl: bool = True, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """BF16 MoE operation with autotuning support. @@ -2338,6 +2351,7 @@ def trtllm_bf16_moe( enable_pdl, tune_max_num_tokens, norm_topk_prob, + routing_replay_out, ) if do_finalize: @@ -2369,6 +2383,7 @@ def trtllm_bf16_routed_moe( do_finalize: bool = True, enable_pdl: bool = True, tune_max_num_tokens: int = 8192, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """BF16 MoE operation with autotuning support. @@ -2529,6 +2544,7 @@ def trtllm_fp8_per_tensor_scale_moe( tune_max_num_tokens, activation_type, norm_topk_prob, + routing_replay_out, ) if do_finalize: @@ -2567,6 +2583,7 @@ def trtllm_fp8_block_scale_moe( fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """FP8 block scale MoE operation. @@ -2642,6 +2659,7 @@ def trtllm_fp8_block_scale_moe( fp8_quantization_type, activation_type, norm_topk_prob, + routing_replay_out, ) if do_finalize: diff --git a/flashinfer/fused_moe/fused_routing_dsv3.py b/flashinfer/fused_moe/fused_routing_dsv3.py index 5e26ca30cf..d80365284c 100644 --- a/flashinfer/fused_moe/fused_routing_dsv3.py +++ b/flashinfer/fused_moe/fused_routing_dsv3.py @@ -1,3 +1,5 @@ +from typing import Optional + from flashinfer.api_logging import flashinfer_api from flashinfer.jit import gen_dsv3_fused_routing_module import functools @@ -21,6 +23,7 @@ def _check_dsv3_fused_routing_supported( topk_values, topk_indices, launch_with_pdl, + routing_replay_out=None, ): """Validate configuration parameters for DSv3 fused routing kernel. @@ -38,6 +41,18 @@ def _check_dsv3_fused_routing_supported( Raises: ValueError: If configuration is invalid or exceeds kernel limits """ + if routing_replay_out is not None: + num_tokens = scores.shape[0] + if routing_replay_out.dtype != torch.int16: + raise ValueError( + f"routing_replay_out must be int16, got {routing_replay_out.dtype}" + ) + if routing_replay_out.shape[0] < num_tokens or routing_replay_out.shape[1] != topk: + raise ValueError( + f"routing_replay_out shape[0] must be >= {num_tokens} and shape[1] must be {topk}, " + f"got {tuple(routing_replay_out.shape)}" + ) + # Extract number of experts from scores shape num_experts = scores.shape[1] @@ -86,7 +101,7 @@ def get_dsv3_fused_routing_module(): @register_custom_op( "flashinfer::NoAuxTc", - mutates_args=["topk_values", "topk_indices"], + mutates_args=["topk_values", "topk_indices", "routing_replay_out"], ) def NoAuxTc( scores: torch.Tensor, @@ -98,6 +113,7 @@ def NoAuxTc( topk_values: torch.Tensor, topk_indices: torch.Tensor, launch_with_pdl: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> None: module.NoAuxTc( scores, @@ -109,6 +125,7 @@ def NoAuxTc( topk_values, topk_indices, launch_with_pdl, + routing_replay_out, ) return SimpleNamespace( @@ -128,6 +145,7 @@ def fused_topk_deepseek( topk_values: torch.Tensor, topk_indices: torch.Tensor, launch_with_pdl: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> None: """Fused expert routing with top-k selection for DeepSeek-V3. @@ -168,6 +186,10 @@ def fused_topk_deepseek( This tensor is mutated in-place. launch_with_pdl (bool, optional): Whether to launch the kernel using Persistent Device-side Launch. Defaults to True. + routing_replay_out (torch.Tensor, optional): Pre-allocated output tensor of shape + (num_tokens, topk) with dtype int16 for recording the selected expert IDs per + token. If None, no routing replay recording occurs (zero overhead). This tensor + is mutated in-place. Returns: None: Results are written directly to `topk_values` and `topk_indices` tensors. @@ -193,4 +215,5 @@ def fused_topk_deepseek( topk_values, topk_indices, launch_with_pdl, + routing_replay_out, ) From f6bd8cb1e95d40641f5568e44380fc10298adc5f Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Thu, 9 Apr 2026 03:09:19 -0700 Subject: [PATCH 04/20] test: add routing replay tests for FP8 and DSV3 MoE kernels Add test_routing_replay_out for DSV3 fused routing: verifies replay tensor matches topk_indices and passing None has no side effects. Add test_fp8_block_scale_moe_routing_replay for FP8 MoE: verifies replay has no effect on MoE output, expert IDs are valid, and each token has exactly top_k unique experts. --- .../test_dsv3_fused_routing.py | 179 ++++++++++++++++++ tests/moe/test_trtllm_gen_routed_fused_moe.py | 137 ++++++++++++++ 2 files changed, 316 insertions(+) diff --git a/tests/model_optimizations/test_dsv3_fused_routing.py b/tests/model_optimizations/test_dsv3_fused_routing.py index e84c9ca884..f42151de6c 100644 --- a/tests/model_optimizations/test_dsv3_fused_routing.py +++ b/tests/model_optimizations/test_dsv3_fused_routing.py @@ -499,3 +499,182 @@ def test_dsv3_fused_routing_op( # Validate values validate_values(ground_truth, sorted_vals, tokens_with_different_experts, data_type) + + +@pytest.mark.parametrize("num_tokens", [1, 8, 64]) +@pytest.mark.parametrize("num_experts", [256]) +@pytest.mark.parametrize("topk", [1, 4, 8]) +@pytest.mark.parametrize("n_group", [1, 8]) +@pytest.mark.parametrize("topk_group", [1, 4]) +@pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float16]) +def test_routing_replay_out( + num_tokens, num_experts, topk, n_group, topk_group, data_type +): + """ + Test that routing_replay_out records the same expert IDs as topk_indices. + + The routing replay feature writes selected expert IDs (int16) into an + optional output tensor during the fused routing kernel. This test verifies + that routing_replay_out matches topk_indices (as sets per token), and that + passing None produces identical routing results (no side effects). + """ + if topk_group * n_group < topk or topk_group > n_group: + pytest.skip("Invalid configuration") + if n_group > 1: + if ( + topk > 8 + or num_experts / n_group > 32 + or num_experts / n_group * topk_group > 128 + ): + pytest.skip("Exceeds kernel limits for n_group > 1") + else: + if num_experts > 384 or topk > 8: + pytest.skip("Exceeds kernel limits for n_group = 1") + + torch.manual_seed(42) + device = "cuda" + scores = torch.randn(num_tokens, num_experts, device=device, dtype=data_type) + bias = torch.randn(num_experts, device=device, dtype=data_type) + routed_scaling_factor = 1.0 + + topk_values = torch.empty(num_tokens, topk, device=device, dtype=data_type) + topk_indices = torch.zeros(num_tokens, topk, device=device, dtype=torch.int32) + routing_replay_out = torch.full( + (num_tokens, topk), -1, device=device, dtype=torch.int16 + ) + + fused_topk_deepseek( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl=True, + routing_replay_out=routing_replay_out, + ) + + # routing_replay_out should contain the same expert IDs as topk_indices (per token) + for t in range(num_tokens): + replay_set = set(routing_replay_out[t].tolist()) + indices_set = set(topk_indices[t].tolist()) + assert replay_set == indices_set, ( + f"Token {t}: routing_replay_out experts {replay_set} " + f"!= topk_indices experts {indices_set}" + ) + + # Verify None produces identical results (no side effects from replay) + topk_values_no_replay = torch.empty( + num_tokens, topk, device=device, dtype=data_type + ) + topk_indices_no_replay = torch.zeros( + num_tokens, topk, device=device, dtype=torch.int32 + ) + + fused_topk_deepseek( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values_no_replay, + topk_indices_no_replay, + launch_with_pdl=True, + routing_replay_out=None, + ) + + torch.testing.assert_close(topk_values, topk_values_no_replay) + torch.testing.assert_close(topk_indices, topk_indices_no_replay) + + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32]) +@pytest.mark.parametrize("num_experts", [256]) +@pytest.mark.parametrize("topk", [1, 4, 8]) +@pytest.mark.parametrize("n_group", [1, 8]) +@pytest.mark.parametrize("topk_group", [1, 4]) +@pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float16]) +def test_routing_replay_out( + num_tokens, num_experts, topk, n_group, topk_group, data_type +): + """ + Test that routing_replay_out records the same expert IDs as topk_indices. + + The routing replay feature writes selected expert IDs (int16) into an + optional output tensor during the fused routing kernel. This test verifies + that routing_replay_out matches topk_indices (as sets per token), and that + passing None produces identical routing results (no side effects). + """ + if topk_group * n_group < topk or topk_group > n_group: + pytest.skip("Invalid configuration") + if n_group > 1: + if ( + topk > 8 + or num_experts / n_group > 32 + or num_experts / n_group * topk_group > 128 + ): + pytest.skip("Exceeds kernel limits for n_group > 1") + else: + if num_experts > 384 or topk > 8: + pytest.skip("Exceeds kernel limits for n_group = 1") + + torch.manual_seed(42) + device = "cuda" + scores = torch.randn(num_tokens, num_experts, device=device, dtype=data_type) + bias = torch.randn(num_experts, device=device, dtype=data_type) + routed_scaling_factor = 1.0 + + topk_values = torch.empty(num_tokens, topk, device=device, dtype=data_type) + topk_indices = torch.zeros(num_tokens, topk, device=device, dtype=torch.int32) + routing_replay_out = torch.full( + (num_tokens, topk), -1, device=device, dtype=torch.int16 + ) + + fused_topk_deepseek( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl=True, + routing_replay_out=routing_replay_out, + ) + + # routing_replay_out should contain the same expert IDs as topk_indices (per token) + for t in range(num_tokens): + replay_set = set(routing_replay_out[t].tolist()) + indices_set = set(topk_indices[t].tolist()) + assert replay_set == indices_set, ( + f"Token {t}: routing_replay_out experts {replay_set} " + f"!= topk_indices experts {indices_set}" + ) + + # Verify None produces identical results (no side effects from replay) + topk_values_no_replay = torch.empty( + num_tokens, topk, device=device, dtype=data_type + ) + topk_indices_no_replay = torch.zeros( + num_tokens, topk, device=device, dtype=torch.int32 + ) + + fused_topk_deepseek( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values_no_replay, + topk_indices_no_replay, + launch_with_pdl=True, + routing_replay_out=None, + ) + + torch.testing.assert_close(topk_values, topk_values_no_replay) + torch.testing.assert_close(topk_indices, topk_indices_no_replay) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index f5fd5cc263..aea5993787 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -705,3 +705,140 @@ def test_trtllm_gen_fp8_mxfp8_routed_activation_parity(activation_type: int): close = torch.isclose(output_ref, output_routed, atol=1e-2, rtol=1e-2) mismatch_pct = (~close).float().mean().item() * 100 assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%" + + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 2048]) +@pytest.mark.parametrize("num_experts", [8, 16]) +@pytest.mark.parametrize("top_k", [2, 4]) +def test_fp8_block_scale_moe_routing_replay( + num_tokens: int, + hidden_size: int, + intermediate_size: int, + top_k: int, + num_experts: int, +): + """Test that routing_replay_out in trtllm_fp8_block_scale_moe records correct expert IDs. + + Runs the full MoE kernel twice with the same inputs: once with routing_replay_out + and once without. Verifies that: + 1. The MoE output is identical (replay has no side effects). + 2. The replay tensor contains valid expert IDs in [0, num_experts). + 3. Each token's replay IDs contain exactly top_k unique experts. + """ + compute_capability = get_compute_capability(torch.device(device="cuda")) + if compute_capability[0] not in [10]: + pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") + torch.manual_seed(42) + device = torch.device("cuda:0") + enable_pdl = device_support_pdl(device) + + routing_logits = torch.rand(num_tokens, num_experts, device=device).to( + torch.bfloat16 + ) + + hidden_states_bf16 = ( + torch.randn(num_tokens, hidden_size, device=device).to(torch.bfloat16) * 0.1 + ) + hidden_states = hidden_states_bf16.to(torch.float8_e4m3fn) + + hidden_states_scale = torch.ones( + hidden_size // 128, num_tokens, device=device, dtype=torch.float32 + ) + + gemm1_weights = torch.randn( + num_experts, 2 * intermediate_size, hidden_size, device=device + ).to(torch.float8_e4m3fn) + gemm2_weights = torch.randn( + num_experts, hidden_size, intermediate_size, device=device + ).to(torch.float8_e4m3fn) + + gemm1_weights_scale = torch.ones( + num_experts, + 2 * intermediate_size // 128, + hidden_size // 128, + device=device, + dtype=torch.float32, + ) + gemm2_weights_scale = torch.ones( + num_experts, + hidden_size // 128, + intermediate_size // 128, + device=device, + dtype=torch.float32, + ) + + routing_replay_out = torch.full( + (num_tokens, top_k), -1, device=device, dtype=torch.int16 + ) + + output_with_replay = trtllm_fp8_block_scale_moe( + routing_logits, + None, # routing_bias + hidden_states, + hidden_states_scale, + gemm1_weights, + gemm1_weights_scale, + gemm2_weights, + gemm2_weights_scale, + num_experts, + top_k, + None, # n_group + None, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + None, # routed_scaling_factor + RoutingMethodType.Renormalize.value, + False, # use_shuffled_weight + 0, # weight_layout + enable_pdl, + routing_replay_out=routing_replay_out, + ) + + output_without_replay = trtllm_fp8_block_scale_moe( + routing_logits, + None, + hidden_states, + hidden_states_scale, + gemm1_weights, + gemm1_weights_scale, + gemm2_weights, + gemm2_weights_scale, + num_experts, + top_k, + None, + None, + intermediate_size, + 0, + num_experts, + None, + RoutingMethodType.Renormalize.value, + False, + 0, + enable_pdl, + routing_replay_out=None, + ) + + # MoE output should be identical regardless of replay + torch.testing.assert_close( + output_with_replay.to(torch.float), + output_without_replay.to(torch.float), + rtol=0, + atol=0, + ) + + # All replay IDs should be valid expert indices + assert (routing_replay_out >= 0).all(), "Found negative expert IDs in replay" + assert (routing_replay_out < num_experts).all(), ( + f"Found expert IDs >= {num_experts} in replay" + ) + + # Each token should have top_k unique experts + for t in range(num_tokens): + unique_experts = routing_replay_out[t].unique() + assert unique_experts.numel() == top_k, ( + f"Token {t}: expected {top_k} unique experts, got {unique_experts.numel()}" + ) From 8d0d0f0f2e314027ed44f69b077c82382bb77911 Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Thu, 9 Apr 2026 03:10:22 -0700 Subject: [PATCH 05/20] docs: add vLLM routing replay integration guide --- docs/vllm_routing_replay_integration.md | 81 +++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 docs/vllm_routing_replay_integration.md diff --git a/docs/vllm_routing_replay_integration.md b/docs/vllm_routing_replay_integration.md new file mode 100644 index 0000000000..35390470b3 --- /dev/null +++ b/docs/vllm_routing_replay_integration.md @@ -0,0 +1,81 @@ +# vLLM Routing Replay Integration Guide + +## Overview + +FlashInfer supports an optional `routing_replay_out` parameter on its MoE kernel functions. +When provided, the CUDA routing kernel writes all top-K selected expert IDs per token directly +into this tensor during routing — inside the same fused kernel call that computes the MoE output. + +This enables **routing replay** for downstream RL training: vLLM captures which experts were +selected for each token during inference and returns them in the API response. + +## API + +### `routing_replay_out` Parameter + +Available on: +- `trtllm_fp8_block_scale_moe()` +- `trtllm_bf16_moe()` +- `trtllm_bf16_routed_moe()` +- `fused_topk_deepseek()` + +**Spec:** +``` +routing_replay_out: Optional[torch.Tensor] + dtype: torch.int16 + shape: (num_tokens_or_larger, top_k) + Layout: row-major. replay[t, k] = k-th ranked expert ID for token t + When None: zero overhead, the kernel skips the write entirely + When provided: the kernel writes expert IDs during routing +``` + +### CUDA Graph Compatibility + +The buffer may be **larger** than `num_tokens`. This is intentional: vLLM pre-allocates +the buffer at `max_num_batched_tokens` and reuses it across CUDA graph replays. The kernel +determines write extent from `routing_logits.shape[0]`, not from `routing_replay_out.shape[0]`. + +There is no strict `dim0 == num_tokens` validation — only `dim0 >= num_tokens` and +`dim1 == top_k`. + +### Memory Layout for vLLM Integration + +vLLM uses a device buffer with shape `(num_layers, max_num_batched_tokens, top_k)`: +- `buffer[layer_id]` gives a contiguous `(max_num_batched_tokens, top_k)` view +- This view is passed as `routing_replay_out` to the FlashInfer kernel +- The `(L, N, K)` layout ensures zero-copy per-layer slicing + +### Integration Pattern + +```python +# Pre-allocate once (during model initialization) +device_buffer = torch.zeros( + (num_layers, max_num_batched_tokens, top_k), + dtype=torch.int16, + device="cuda", +) + +# Per-layer forward pass +for layer_id, moe_layer in enumerate(moe_layers): + replay_slice = device_buffer[layer_id] # contiguous (N, K) view + output = trtllm_fp8_block_scale_moe( + ..., + routing_replay_out=replay_slice, + ) +``` + +### Validation + +```python +import torch + +# Allocate replay buffer +replay = torch.full((num_tokens, top_k), -1, device="cuda", dtype=torch.int16) + +# Run MoE +output = trtllm_fp8_block_scale_moe(..., routing_replay_out=replay) + +# Verify non-zero (not all -1 sentinel) +assert (replay != -1).any(), "Routing replay data is all sentinel values" +assert (replay >= 0).all() and (replay < num_experts).all(), "Invalid expert IDs" +``` From 88e18b485019a0526239c12ff50cc59749ae66ff Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Thu, 9 Apr 2026 04:18:27 -0700 Subject: [PATCH 06/20] fix: add missing int16 dtype validation and FP4/MXINT4 public API params - Add int16 dtype check to all 5 launcher validation blocks (was missing, could silently corrupt memory on wrong dtype) - Add routing_replay_out param to trtllm_fp4_block_scale_moe(), trtllm_mxint4_block_scale_moe(), trtllm_fp8_per_tensor_scale_moe() public functions --- csrc/trtllm_fused_moe_kernel_launcher.cu | 10 ++++++++++ flashinfer/fused_moe/core.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 4046fb54eb..e200c6a023 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -1790,6 +1790,8 @@ Array trtllm_bf16_moe(Optional const& routing_logits, << "routing_replay_out must be on the same device as hidden_states"; TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + TVM_FFI_ICHECK(replay.dtype() == DLDataType{kDLInt, 16, 1}) + << "routing_replay_out must be int16 dtype"; // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) } @@ -1888,6 +1890,8 @@ Array trtllm_fp8_per_tensor_scale_moe( << "routing_replay_out must be on the same device as hidden_states"; TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + TVM_FFI_ICHECK(replay.dtype() == DLDataType{kDLInt, 16, 1}) + << "routing_replay_out must be int16 dtype"; // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) } @@ -2031,6 +2035,8 @@ Array trtllm_fp8_block_scale_moe( << "routing_replay_out must be on the same device as hidden_states"; TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + TVM_FFI_ICHECK(replay.dtype() == DLDataType{kDLInt, 16, 1}) + << "routing_replay_out must be int16 dtype"; // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) } @@ -2152,6 +2158,8 @@ Array trtllm_fp4_block_scale_moe( << "routing_replay_out must be on the same device as hidden_states"; TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + TVM_FFI_ICHECK(replay.dtype() == DLDataType{kDLInt, 16, 1}) + << "routing_replay_out must be int16 dtype"; // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) } @@ -2278,6 +2286,8 @@ Array trtllm_mxint4_block_scale_moe( << "routing_replay_out must be on the same device as hidden_states"; TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + TVM_FFI_ICHECK(replay.dtype() == DLDataType{kDLInt, 16, 1}) + << "routing_replay_out must be int16 dtype"; // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) } diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index b9f5e7ae85..ce5032b1fb 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -2484,6 +2484,8 @@ def trtllm_fp8_per_tensor_scale_moe( tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + + routing_replay_out: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """FP8 per tensor scale MoE operation. @@ -2818,6 +2820,7 @@ def trtllm_fp4_block_scale_moe( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """FP4 block scale MoE operation. @@ -3080,6 +3083,7 @@ def trtllm_mxint4_block_scale_moe( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """MxInt4 block scale MoE operation. From 1fb845424964ac369ecfcca8f63f7358a53e2669 Mon Sep 17 00:00:00 2001 From: tbarnatan Date: Sun, 12 Apr 2026 15:12:33 +0300 Subject: [PATCH 07/20] =?UTF-8?q?fix:=20address=20code=20review=20feedback?= =?UTF-8?q?=20=E2=80=94=20missing=20routing=5Freplay=5Fout=20params,=20dup?= =?UTF-8?q?licate=20test,=20Llama4=20replay=20gaps?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename duplicate test_routing_replay_out → test_routing_replay_out_extended (pytest collision) - Forward routing_replay_out in trtllm_bf16_routed_moe (was silently dropped) - Add routing_replay_out to FP4/MXINT4 op signatures, fake ops, and public APIs (NameError) - Move set_routing_replay_out() from protected to public in FusedMoeLauncher - Add routing replay writes to Llama4 warp and cluster kernel paths - Add explanatory comment for intentional dim0 validation omission (CUDA graph pre-alloc) - Apply pre-commit formatting fixes (clang-format, ruff) --- csrc/fused_moe/noAuxTcKernels.cu | 6 +++-- .../trtllm_fused_moe_routing_llama4.cu | 21 ++++++++++++++++ csrc/trtllm_fused_moe_kernel_launcher.cu | 24 ++++++++----------- csrc/trtllm_fused_moe_runner.cu | 3 +-- flashinfer/fused_moe/core.py | 14 ++++++++--- flashinfer/fused_moe/fused_routing_dsv3.py | 5 +++- include/flashinfer/trtllm/fused_moe/runner.h | 3 +-- .../test_dsv3_fused_routing.py | 8 ++----- tests/moe/test_trtllm_gen_routed_fused_moe.py | 1 - 9 files changed, 54 insertions(+), 31 deletions(-) diff --git a/csrc/fused_moe/noAuxTcKernels.cu b/csrc/fused_moe/noAuxTcKernels.cu index dd30c05075..82f886bdd6 100644 --- a/csrc/fused_moe/noAuxTcKernels.cu +++ b/csrc/fused_moe/noAuxTcKernels.cu @@ -285,8 +285,7 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk InputT * scores, BiasT * bias, OutputT * topk_values, IdxT * topk_indices, \ int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, \ int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, \ - bool const launch_with_pdl, cudaStream_t const stream, \ - int16_t* routing_replay_out); + bool const launch_with_pdl, cudaStream_t const stream, int16_t* routing_replay_out); INSTANTIATE_NOAUX_TC(float, float, float, int32_t); INSTANTIATE_NOAUX_TC(float, half, float, int32_t); @@ -350,6 +349,9 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g << "topk_indices must have the same dtype as scores"; // Validate and extract routing_replay_out + // NOTE: dim0 >= num_tokens is intentionally NOT checked — with CUDA graphs the buffer + // is pre-allocated at maximum batch size and reused across steps with varying num_tokens. + // The kernel only writes to indices [0, num_tokens), so a larger buffer is safe. constexpr int64_t int16_code_val = encode_dlpack_dtype(DLDataType{kDLInt, 16, 1}); int16_t* replay_ptr = nullptr; if (routing_replay_out.has_value()) { diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu index 4fa0c1b5c9..55420baccc 100644 --- a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu @@ -132,6 +132,10 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam if (params.mPtrTopKWeights != nullptr) { params.mPtrTopKWeights[tokenIdx] = finalScore; } + // Routing replay: record selected expert ID for this token. + if (params.mPtrRoutingReplayOut != nullptr) { + params.mPtrRoutingReplayOut[tokenIdx] = static_cast(warpMaxExpertIdx[0]); + } } } } else { @@ -162,6 +166,10 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam setBits(expertTokenCount, 1, scoreIdx.idx % ExpertsPerThread); if (threadIdx.x < params.mNumTokens) { smemExpertTokenCountFull[threadIdx.x][scoreIdx.idx / ExpertsPerThread] = expertTokenCount; + // Routing replay: record selected expert ID for this token. + if (params.mPtrRoutingReplayOut != nullptr) { + params.mPtrRoutingReplayOut[threadIdx.x] = scoreIdx.idx; + } } } @@ -358,6 +366,11 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu TypePacked packedScore{static_cast(params.mPtrTopKWeights[warpTokenIdx]), static_cast(params.mPtrTopKIds[warpTokenIdx])}; smemPackedScoreIdx[warpIdx] = packedScore; + // Routing replay: record selected expert ID for this token. + if (params.mPtrRoutingReplayOut != nullptr) { + params.mPtrRoutingReplayOut[warpTokenIdx] = + static_cast(params.mPtrTopKIds[warpTokenIdx]); + } } } else if (params.mPtrScores != nullptr) { // in this case, each warp represents a token @@ -374,11 +387,19 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; TypePacked packedScore{finalScore, static_cast(warpMaxExpertIdx[0])}; smemPackedScoreIdx[warpIdx] = packedScore; + // Routing replay: record selected expert ID for this token. + if (params.mPtrRoutingReplayOut != nullptr) { + params.mPtrRoutingReplayOut[warpTokenIdx] = static_cast(warpMaxExpertIdx[0]); + } } } } else { if (validToken) { smemPackedScoreIdx[warpIdx] = params.mPtrTopKPacked[warpTokenIdx]; + // Routing replay: record selected expert ID for this token. + if (params.mPtrRoutingReplayOut != nullptr) { + params.mPtrRoutingReplayOut[warpTokenIdx] = params.mPtrTopKPacked[warpTokenIdx].idx; + } } } diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index e200c6a023..c3a280367e 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -217,6 +217,11 @@ class FusedMoeLauncher { activation_type{ActivationType::Swiglu}, intermediate_size_factor{2} {} + public: + void set_routing_replay_out(const Optional& replay_out) { + routing_replay_out = replay_out; + } + protected: // Initialize common data necessary for later. // May throw exception from TVM_FFI_ICHECK. @@ -225,10 +230,6 @@ class FusedMoeLauncher { int64_t weight_layout, ActivationType activation_type, bool norm_topk_prob = true); - void set_routing_replay_out(const Optional& replay_out) { - routing_replay_out = replay_out; - } - // Routing logits [num_tokens, num_experts] void check_routing_logits() const { if (routing_logits.has_value()) { @@ -483,8 +484,7 @@ class FusedMoeLauncher { static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream, mRoutingLogitsDtype, - norm_topk_prob, - replay_ptr); + norm_topk_prob, replay_ptr); check_moe(); prepare_moe(moe_tactic); @@ -1223,8 +1223,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream, mRoutingLogitsDtype, - norm_topk_prob, - replay_ptr); + norm_topk_prob, replay_ptr); check_moe(); prepare_moe(moe_tactic); @@ -1711,8 +1710,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream, mRoutingLogitsDtype, - norm_topk_prob, - replay_ptr); + norm_topk_prob, replay_ptr); check_moe(); prepare_moe(moe_tactic); @@ -1961,8 +1959,7 @@ Array trtllm_fp8_block_scale_moe( int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, bool enable_pdl, Array config_index, Fp8QuantizationType quantization_type, - int64_t act_type, bool norm_topk_prob, - Optional routing_replay_out) { + int64_t act_type, bool norm_topk_prob, Optional routing_replay_out) { auto activation_type = validateAndCastActivationType(act_type); // DeepSeekFp8 currently uses a TRTLLM runner that hardwires Swiglu activation semantics. // Fail for any other activation to avoid silently running incorrect activation behavior. @@ -2255,8 +2252,7 @@ Array trtllm_mxint4_block_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool do_finalize, bool enable_pdl, TensorView output, - Array config_index, bool norm_topk_prob, - Optional routing_replay_out) { + Array config_index, bool norm_topk_prob, Optional routing_replay_out) { // Determine data types based on input format int const num_tokens = hidden_states.size(0); int hidden_size = hidden_states.size(1); diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index 12a7f23b7c..43cb6e609e 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -69,8 +69,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream, btg::Dtype dtypeLogits, - bool normTopkProb, - int16_t* routing_replay_out) { + bool normTopkProb, int16_t* routing_replay_out) { if (routingMethodType == RoutingMethodType::DeepSeekV3 && nGroup <= 1) { // DeepSeek no-groups case: use routingCustom with SigmoidBias preprocess // and ScaledSumNormalize postprocess. This is more efficient than the full DeepSeek diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index ce5032b1fb..56bc10e876 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1851,7 +1851,7 @@ def _fake_trtllm_fp8_block_scale_moe( @register_custom_op( "flashinfer::trtllm_fp4_block_scale_moe", - mutates_args=(""), + mutates_args=("routing_replay_out",), ) def trtllm_fp4_block_scale_moe_op( routing_logits: Optional[torch.Tensor], @@ -1887,6 +1887,7 @@ def trtllm_fp4_block_scale_moe_op( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: if routing_logits is None: assert topk_ids is not None, ( @@ -2072,7 +2073,9 @@ def _fake_trtllm_fp4_block_scale_moe( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ): + _ = routing_replay_out seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] if output is None else output.shape[1] @@ -2080,7 +2083,7 @@ def _fake_trtllm_fp4_block_scale_moe( @register_custom_op( "flashinfer::trtllm_mxint4_block_scale_moe", - mutates_args=(""), + mutates_args=("routing_replay_out",), ) def trtllm_mxint4_block_scale_moe_op( routing_logits: torch.Tensor, @@ -2107,6 +2110,7 @@ def trtllm_mxint4_block_scale_moe_op( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: routing_dtype = routing_logits.dtype hidden_size = hidden_states.shape[-1] @@ -2248,7 +2252,9 @@ def _fake_trtllm_mxint4_block_scale_moe( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ): + _ = routing_replay_out seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -2448,6 +2454,7 @@ def trtllm_bf16_routed_moe( enable_pdl, tune_max_num_tokens, True, # norm_topk_prob: not used for pre-computed routing + routing_replay_out, ) if do_finalize: @@ -2484,7 +2491,6 @@ def trtllm_fp8_per_tensor_scale_moe( tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, - routing_replay_out: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """FP8 per tensor scale MoE operation. @@ -2919,6 +2925,7 @@ def trtllm_fp4_block_scale_moe( output, tune_max_num_tokens, norm_topk_prob, + routing_replay_out, ) @@ -3156,4 +3163,5 @@ def trtllm_mxint4_block_scale_moe( output, tune_max_num_tokens, norm_topk_prob, + routing_replay_out, ) diff --git a/flashinfer/fused_moe/fused_routing_dsv3.py b/flashinfer/fused_moe/fused_routing_dsv3.py index d80365284c..773d2b7601 100644 --- a/flashinfer/fused_moe/fused_routing_dsv3.py +++ b/flashinfer/fused_moe/fused_routing_dsv3.py @@ -47,7 +47,10 @@ def _check_dsv3_fused_routing_supported( raise ValueError( f"routing_replay_out must be int16, got {routing_replay_out.dtype}" ) - if routing_replay_out.shape[0] < num_tokens or routing_replay_out.shape[1] != topk: + if ( + routing_replay_out.shape[0] < num_tokens + or routing_replay_out.shape[1] != topk + ): raise ValueError( f"routing_replay_out shape[0] must be >= {num_tokens} and shape[1] must be {topk}, " f"got {tuple(routing_replay_out.shape)}" diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 2d2361ddd9..626299ade3 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -137,8 +137,7 @@ class Runner { batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream, batchedGemm::trtllm::gen::Dtype dtypeLogits, - bool normTopkProb = true, - int16_t* routing_replay_out = nullptr); + bool normTopkProb = true, int16_t* routing_replay_out = nullptr); private: int32_t mTileTokensDim{8}; diff --git a/tests/model_optimizations/test_dsv3_fused_routing.py b/tests/model_optimizations/test_dsv3_fused_routing.py index f42151de6c..f51a05a34d 100644 --- a/tests/model_optimizations/test_dsv3_fused_routing.py +++ b/tests/model_optimizations/test_dsv3_fused_routing.py @@ -507,16 +507,13 @@ def test_dsv3_fused_routing_op( @pytest.mark.parametrize("n_group", [1, 8]) @pytest.mark.parametrize("topk_group", [1, 4]) @pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float16]) -def test_routing_replay_out( +def test_routing_replay_out_extended( num_tokens, num_experts, topk, n_group, topk_group, data_type ): """ Test that routing_replay_out records the same expert IDs as topk_indices. - The routing replay feature writes selected expert IDs (int16) into an - optional output tensor during the fused routing kernel. This test verifies - that routing_replay_out matches topk_indices (as sets per token), and that - passing None produces identical routing results (no side effects). + Extended parametrization covering larger token counts (8, 64). """ if topk_group * n_group < topk or topk_group > n_group: pytest.skip("Invalid configuration") @@ -590,7 +587,6 @@ def test_routing_replay_out( torch.testing.assert_close(topk_indices, topk_indices_no_replay) - @pytest.mark.parametrize("num_tokens", [1, 7, 32]) @pytest.mark.parametrize("num_experts", [256]) @pytest.mark.parametrize("topk", [1, 4, 8]) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index aea5993787..1237317fd8 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -707,7 +707,6 @@ def test_trtllm_gen_fp8_mxfp8_routed_activation_parity(activation_type: int): assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%" - @pytest.mark.parametrize("num_tokens", [1, 7, 32]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [1024, 2048]) From 777a82d58ecbf8feef1cb0e1c1067e83edb3e5d7 Mon Sep 17 00:00:00 2001 From: tbarnatan Date: Sun, 12 Apr 2026 16:01:47 +0300 Subject: [PATCH 08/20] fix: clarify unordered column semantics in doc and strengthen replay test - Doc: replace "k-th ranked expert ID" with unspecified column order wording - Test: compare replay against reference routing result (sorted set equality) instead of only checking range and uniqueness --- docs/vllm_routing_replay_integration.md | 3 ++- tests/moe/test_trtllm_gen_routed_fused_moe.py | 17 +++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/docs/vllm_routing_replay_integration.md b/docs/vllm_routing_replay_integration.md index 35390470b3..5b6f0e323e 100644 --- a/docs/vllm_routing_replay_integration.md +++ b/docs/vllm_routing_replay_integration.md @@ -24,7 +24,8 @@ Available on: routing_replay_out: Optional[torch.Tensor] dtype: torch.int16 shape: (num_tokens_or_larger, top_k) - Layout: row-major. replay[t, k] = k-th ranked expert ID for token t + Layout: row-major. replay[t, k] stores one selected expert ID for token t + Column order is unspecified; compare per-token sets rather than positions When None: zero overhead, the kernel skips the write entirely When provided: the kernel writes expert IDs during routing ``` diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index 1237317fd8..462c151813 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -835,9 +835,14 @@ def test_fp8_block_scale_moe_routing_replay( f"Found expert IDs >= {num_experts} in replay" ) - # Each token should have top_k unique experts - for t in range(num_tokens): - unique_experts = routing_replay_out[t].unique() - assert unique_experts.numel() == top_k, ( - f"Token {t}: expected {top_k} unique experts, got {unique_experts.numel()}" - ) + # Compare replay against reference routing result (set equality per token) + permute_info, _ = routing_reference_renormalize( + routing_logits, top_k, num_experts, 8 + ) + expected_topk = permute_info["topKIndices"].to(torch.int16) + torch.testing.assert_close( + torch.sort(routing_replay_out, dim=1).values, + torch.sort(expected_topk, dim=1).values, + rtol=0, + atol=0, + ) From 5929e150dc22608cc7f9b7bc380e2fa8c1528f19 Mon Sep 17 00:00:00 2001 From: tbarnatan Date: Sun, 12 Apr 2026 16:11:11 +0300 Subject: [PATCH 09/20] docs: clarify API list is vLLM-specific subset, add language tag to code block --- docs/vllm_routing_replay_integration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/vllm_routing_replay_integration.md b/docs/vllm_routing_replay_integration.md index 5b6f0e323e..142afd59b1 100644 --- a/docs/vllm_routing_replay_integration.md +++ b/docs/vllm_routing_replay_integration.md @@ -13,14 +13,14 @@ selected for each token during inference and returns them in the API response. ### `routing_replay_out` Parameter -Available on: +Available on these vLLM integration path APIs (other MoE entry points also accept this parameter): - `trtllm_fp8_block_scale_moe()` - `trtllm_bf16_moe()` - `trtllm_bf16_routed_moe()` - `fused_topk_deepseek()` **Spec:** -``` +```text routing_replay_out: Optional[torch.Tensor] dtype: torch.int16 shape: (num_tokens_or_larger, top_k) From 7e076e66cb54f21010833de03f6efc5812bec815 Mon Sep 17 00:00:00 2001 From: tbarnatan Date: Sun, 12 Apr 2026 16:20:19 +0300 Subject: [PATCH 10/20] fix: add routing_replay_out to FP8 per-tensor op/fake_op, test oversized buffer contract MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - FP8 per-tensor op+fake_op missing routing_replay_out (same pattern as FP4/MXINT4) - Test: allocate oversized buffer (num_tokens+5) to validate CUDA graph pre-allocation contract — kernel writes only [0, num_tokens), tail stays sentinel --- flashinfer/fused_moe/core.py | 5 ++++- tests/moe/test_trtllm_gen_routed_fused_moe.py | 19 ++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 56bc10e876..1c9244e3ea 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1450,7 +1450,7 @@ def _fake_trtllm_bf16_moe( @register_custom_op( "flashinfer::trtllm_fp8_per_tensor_scale_moe", - mutates_args=(""), + mutates_args=("routing_replay_out",), ) def trtllm_fp8_per_tensor_scale_moe_op( routing_logits: torch.Tensor, @@ -1476,6 +1476,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) @@ -1611,7 +1612,9 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ): + _ = routing_replay_out seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index 462c151813..0f05b92b2b 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -769,8 +769,11 @@ def test_fp8_block_scale_moe_routing_replay( dtype=torch.float32, ) + # Allocate oversized buffer to validate CUDA graph pre-allocation contract: + # kernel should only write to [0, num_tokens) and leave tail rows as sentinel. + replay_capacity = num_tokens + 5 routing_replay_out = torch.full( - (num_tokens, top_k), -1, device=device, dtype=torch.int16 + (replay_capacity, top_k), -1, device=device, dtype=torch.int16 ) output_with_replay = trtllm_fp8_block_scale_moe( @@ -829,9 +832,10 @@ def test_fp8_block_scale_moe_routing_replay( atol=0, ) - # All replay IDs should be valid expert indices - assert (routing_replay_out >= 0).all(), "Found negative expert IDs in replay" - assert (routing_replay_out < num_experts).all(), ( + # All replay IDs in active rows should be valid expert indices + active_replay = routing_replay_out[:num_tokens] + assert (active_replay >= 0).all(), "Found negative expert IDs in replay" + assert (active_replay < num_experts).all(), ( f"Found expert IDs >= {num_experts} in replay" ) @@ -841,8 +845,13 @@ def test_fp8_block_scale_moe_routing_replay( ) expected_topk = permute_info["topKIndices"].to(torch.int16) torch.testing.assert_close( - torch.sort(routing_replay_out, dim=1).values, + torch.sort(active_replay, dim=1).values, torch.sort(expected_topk, dim=1).values, rtol=0, atol=0, ) + + # Tail rows beyond num_tokens should remain sentinel (-1) + assert (routing_replay_out[num_tokens:] == -1).all(), ( + "Kernel should not write beyond active token rows" + ) From ec2d137d3fbe80ae5595f6b1399bb1eb5bac6dc5 Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Thu, 9 Apr 2026 07:20:11 -0700 Subject: [PATCH 11/20] fix: add __version__ to flashinfer_cubin for nightly base compat --- flashinfer-cubin/flashinfer_cubin/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer-cubin/flashinfer_cubin/__init__.py b/flashinfer-cubin/flashinfer_cubin/__init__.py index abfe816282..c759a963d6 100644 --- a/flashinfer-cubin/flashinfer_cubin/__init__.py +++ b/flashinfer-cubin/flashinfer_cubin/__init__.py @@ -75,6 +75,6 @@ def _get_git_version(): return "unknown" -__version__ = _get_version() +__version__ = "0.6.7" __git_version__ = _get_git_version() __all__ = ["get_cubin_dir", "list_cubins", "get_cubin_path", "CUBIN_DIR"] From c0679391b432e87818f7cc47070c3bc1cc3adf1a Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Sun, 12 Apr 2026 07:02:53 -0700 Subject: [PATCH 12/20] fix: wrap DLDataType brace initializer in extra parens for TVM_FFI_ICHECK The C preprocessor interprets commas inside DLDataType{kDLInt, 16, 1} as macro argument separators, causing TVM_FFI_ICHECK to see 3 args instead of 1. Wrap the whole expression in extra parentheses. --- csrc/trtllm_fused_moe_kernel_launcher.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index c3a280367e..5eab481036 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -1788,7 +1788,7 @@ Array trtllm_bf16_moe(Optional const& routing_logits, << "routing_replay_out must be on the same device as hidden_states"; TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; - TVM_FFI_ICHECK(replay.dtype() == DLDataType{kDLInt, 16, 1}) + TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) << "routing_replay_out must be int16 dtype"; // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) } @@ -1888,7 +1888,7 @@ Array trtllm_fp8_per_tensor_scale_moe( << "routing_replay_out must be on the same device as hidden_states"; TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; - TVM_FFI_ICHECK(replay.dtype() == DLDataType{kDLInt, 16, 1}) + TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) << "routing_replay_out must be int16 dtype"; // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) } @@ -2032,7 +2032,7 @@ Array trtllm_fp8_block_scale_moe( << "routing_replay_out must be on the same device as hidden_states"; TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; - TVM_FFI_ICHECK(replay.dtype() == DLDataType{kDLInt, 16, 1}) + TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) << "routing_replay_out must be int16 dtype"; // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) } @@ -2155,7 +2155,7 @@ Array trtllm_fp4_block_scale_moe( << "routing_replay_out must be on the same device as hidden_states"; TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; - TVM_FFI_ICHECK(replay.dtype() == DLDataType{kDLInt, 16, 1}) + TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) << "routing_replay_out must be int16 dtype"; // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) } @@ -2282,7 +2282,7 @@ Array trtllm_mxint4_block_scale_moe( << "routing_replay_out must be on the same device as hidden_states"; TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; - TVM_FFI_ICHECK(replay.dtype() == DLDataType{kDLInt, 16, 1}) + TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) << "routing_replay_out must be int16 dtype"; // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) } From 72329b77e8f18b7f4b87ff49a90519cbe7fb5329 Mon Sep 17 00:00:00 2001 From: tbarnatan Date: Mon, 13 Apr 2026 10:00:57 +0300 Subject: [PATCH 13/20] =?UTF-8?q?fix:=20address=20human=20review=20?= =?UTF-8?q?=E2=80=94=20revert=20=5F=5Fversion=5F=5F,=20docstring,=20trtllm?= =?UTF-8?q?-only=20doc,=20validation=20helper,=20DeepSeek-only=20replay?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Revert hard-coded __version__ = "0.6.7" back to _get_version() - Add routing_replay_out to trtllm_fp8_block_scale_moe docstring with shape spec - Doc: clarify routing replay is trtllm-gen backend only (not Triton) - Extract validate_routing_replay_out() helper in C++ launcher (5 call sites) - Remove routing replay writes from Custom and Llama4 kernels (untested, DeepSeek only) - Remove >= 0 assertion from test (replay values can be negative) --- .../trtllm_fused_moe_routing_custom.cu | 6 -- .../trtllm_fused_moe_routing_llama4.cu | 26 ------- csrc/trtllm_fused_moe_kernel_launcher.cu | 70 ++++++------------- docs/vllm_routing_replay_integration.md | 7 +- flashinfer-cubin/flashinfer_cubin/__init__.py | 2 +- flashinfer/fused_moe/core.py | 5 ++ tests/moe/test_trtllm_gen_routed_fused_moe.py | 8 +-- 7 files changed, 31 insertions(+), 93 deletions(-) diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_custom.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_custom.cu index 6ba5d42660..2cc618ed9a 100644 --- a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_custom.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_custom.cu @@ -419,12 +419,6 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa PackedScoreIdx packedScore{static_cast(warpTopKScore[laneIdx]), static_cast(warpTopKExpertIdx[laneIdx])}; params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; - - // Routing replay: record all top-K selected expert IDs per token. - if (params.mPtrRoutingReplayOut != nullptr) { - params.mPtrRoutingReplayOut[tokenIdx * params.mTopK + laneIdx] = - static_cast(warpTopKExpertIdx[laneIdx]); - } } } diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu index 55420baccc..a7c0bed8bd 100644 --- a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu @@ -132,10 +132,6 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam if (params.mPtrTopKWeights != nullptr) { params.mPtrTopKWeights[tokenIdx] = finalScore; } - // Routing replay: record selected expert ID for this token. - if (params.mPtrRoutingReplayOut != nullptr) { - params.mPtrRoutingReplayOut[tokenIdx] = static_cast(warpMaxExpertIdx[0]); - } } } } else { @@ -166,10 +162,6 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam setBits(expertTokenCount, 1, scoreIdx.idx % ExpertsPerThread); if (threadIdx.x < params.mNumTokens) { smemExpertTokenCountFull[threadIdx.x][scoreIdx.idx / ExpertsPerThread] = expertTokenCount; - // Routing replay: record selected expert ID for this token. - if (params.mPtrRoutingReplayOut != nullptr) { - params.mPtrRoutingReplayOut[threadIdx.x] = scoreIdx.idx; - } } } @@ -366,11 +358,6 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu TypePacked packedScore{static_cast(params.mPtrTopKWeights[warpTokenIdx]), static_cast(params.mPtrTopKIds[warpTokenIdx])}; smemPackedScoreIdx[warpIdx] = packedScore; - // Routing replay: record selected expert ID for this token. - if (params.mPtrRoutingReplayOut != nullptr) { - params.mPtrRoutingReplayOut[warpTokenIdx] = - static_cast(params.mPtrTopKIds[warpTokenIdx]); - } } } else if (params.mPtrScores != nullptr) { // in this case, each warp represents a token @@ -387,19 +374,11 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; TypePacked packedScore{finalScore, static_cast(warpMaxExpertIdx[0])}; smemPackedScoreIdx[warpIdx] = packedScore; - // Routing replay: record selected expert ID for this token. - if (params.mPtrRoutingReplayOut != nullptr) { - params.mPtrRoutingReplayOut[warpTokenIdx] = static_cast(warpMaxExpertIdx[0]); - } } } } else { if (validToken) { smemPackedScoreIdx[warpIdx] = params.mPtrTopKPacked[warpTokenIdx]; - // Routing replay: record selected expert ID for this token. - if (params.mPtrRoutingReplayOut != nullptr) { - params.mPtrRoutingReplayOut[warpTokenIdx] = params.mPtrTopKPacked[warpTokenIdx].idx; - } } } @@ -485,11 +464,6 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; TypePacked packedScore{finalScore, static_cast(warpMaxExpertIdx[0])}; params.mPtrTopKPacked[tokenIdx] = packedScore; - - // Routing replay: record selected expert ID for this token. - if (params.mPtrRoutingReplayOut != nullptr) { - params.mPtrRoutingReplayOut[tokenIdx] = static_cast(warpMaxExpertIdx[0]); - } } } diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 5eab481036..efc40cedf9 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -41,6 +41,21 @@ using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; +// Validate routing_replay_out tensor properties. +// NOTE: dim0 >= num_tokens is intentionally NOT checked — with CUDA graphs the buffer +// is pre-allocated at maximum batch size and reused across steps with varying num_tokens. +static void validate_routing_replay_out(TensorView const& replay, TensorView const& hidden_states, + int64_t top_k) { + TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) + << "routing_replay_out must be a CUDA tensor"; + TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) + << "routing_replay_out must be on the same device as hidden_states"; + TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; + TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) + << "routing_replay_out must be int16 dtype"; +} + enum class Fp8QuantizationType { NoneFp8, DeepSeekFp8, @@ -1781,16 +1796,7 @@ Array trtllm_bf16_moe(Optional const& routing_logits, << "BF16 MoE: gemm2_weights must be bfloat16."; if (routing_replay_out.has_value()) { - auto replay = routing_replay_out.value(); - TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) - << "routing_replay_out must be a CUDA tensor"; - TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) - << "routing_replay_out must be on the same device as hidden_states"; - TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; - TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; - TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) - << "routing_replay_out must be int16 dtype"; - // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) + validate_routing_replay_out(routing_replay_out.value(), hidden_states, top_k); } auto const num_tokens = hidden_states.size(0); @@ -1881,16 +1887,7 @@ Array trtllm_fp8_per_tensor_scale_moe( << "FP8 MoE: output2_scales_scalar must be float32."; if (routing_replay_out.has_value()) { - auto replay = routing_replay_out.value(); - TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) - << "routing_replay_out must be a CUDA tensor"; - TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) - << "routing_replay_out must be on the same device as hidden_states"; - TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; - TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; - TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) - << "routing_replay_out must be int16 dtype"; - // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) + validate_routing_replay_out(routing_replay_out.value(), hidden_states, top_k); } auto const num_tokens = hidden_states.size(0); @@ -2025,16 +2022,7 @@ Array trtllm_fp8_block_scale_moe( } if (routing_replay_out.has_value()) { - auto replay = routing_replay_out.value(); - TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) - << "routing_replay_out must be a CUDA tensor"; - TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) - << "routing_replay_out must be on the same device as hidden_states"; - TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; - TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; - TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) - << "routing_replay_out must be int16 dtype"; - // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) + validate_routing_replay_out(routing_replay_out.value(), hidden_states, top_k); } auto const num_tokens = hidden_states.size(0); @@ -2148,16 +2136,7 @@ Array trtllm_fp4_block_scale_moe( } if (routing_replay_out.has_value()) { - auto replay = routing_replay_out.value(); - TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) - << "routing_replay_out must be a CUDA tensor"; - TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) - << "routing_replay_out must be on the same device as hidden_states"; - TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; - TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; - TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) - << "routing_replay_out must be int16 dtype"; - // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) + validate_routing_replay_out(routing_replay_out.value(), hidden_states, top_k); } // Determine activation type @@ -2275,16 +2254,7 @@ Array trtllm_mxint4_block_scale_moe( } if (routing_replay_out.has_value()) { - auto replay = routing_replay_out.value(); - TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) - << "routing_replay_out must be a CUDA tensor"; - TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) - << "routing_replay_out must be on the same device as hidden_states"; - TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; - TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; - TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) - << "routing_replay_out must be int16 dtype"; - // NO dim0 == num_tokens check: buffer may be larger (CUDA graph pre-allocation) + validate_routing_replay_out(routing_replay_out.value(), hidden_states, top_k); } // Determine activation type diff --git a/docs/vllm_routing_replay_integration.md b/docs/vllm_routing_replay_integration.md index 142afd59b1..96f4a9713d 100644 --- a/docs/vllm_routing_replay_integration.md +++ b/docs/vllm_routing_replay_integration.md @@ -2,9 +2,10 @@ ## Overview -FlashInfer supports an optional `routing_replay_out` parameter on its MoE kernel functions. -When provided, the CUDA routing kernel writes all top-K selected expert IDs per token directly -into this tensor during routing — inside the same fused kernel call that computes the MoE output. +FlashInfer supports an optional `routing_replay_out` parameter on its **trtllm-gen backend** MoE +kernel functions (not the Triton MoE path). When provided, the CUDA routing kernel writes all +top-K selected expert IDs per token directly into this tensor during routing — inside the same +fused kernel call that computes the MoE output. This enables **routing replay** for downstream RL training: vLLM captures which experts were selected for each token during inference and returns them in the API response. diff --git a/flashinfer-cubin/flashinfer_cubin/__init__.py b/flashinfer-cubin/flashinfer_cubin/__init__.py index c759a963d6..abfe816282 100644 --- a/flashinfer-cubin/flashinfer_cubin/__init__.py +++ b/flashinfer-cubin/flashinfer_cubin/__init__.py @@ -75,6 +75,6 @@ def _get_git_version(): return "unknown" -__version__ = "0.6.7" +__version__ = _get_version() __git_version__ = _get_git_version() __all__ = ["get_cubin_dir", "list_cubins", "get_cubin_path", "CUBIN_DIR"] diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 1c9244e3ea..332438273e 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -2634,6 +2634,11 @@ def trtllm_fp8_block_scale_moe( - 4: Geglu - 6: Relu2 - 7: Identity + routing_replay_out (Optional[torch.Tensor]): Optional int16 output tensor of shape + (num_tokens_or_larger, top_k) to capture selected expert IDs during routing. + Column order is unspecified. When None (default), zero overhead — the kernel + skips the write entirely. Buffer may be larger than num_tokens for CUDA graph + pre-allocation; only rows [0, num_tokens) are written. Returns: when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index 0f05b92b2b..90402524a7 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -832,14 +832,8 @@ def test_fp8_block_scale_moe_routing_replay( atol=0, ) - # All replay IDs in active rows should be valid expert indices - active_replay = routing_replay_out[:num_tokens] - assert (active_replay >= 0).all(), "Found negative expert IDs in replay" - assert (active_replay < num_experts).all(), ( - f"Found expert IDs >= {num_experts} in replay" - ) - # Compare replay against reference routing result (set equality per token) + active_replay = routing_replay_out[:num_tokens] permute_info, _ = routing_reference_renormalize( routing_logits, top_k, num_experts, 8 ) From 500e0224f904d53dd91eae7efac82fc1044674c9 Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Mon, 13 Apr 2026 00:15:31 -0700 Subject: [PATCH 14/20] fix: add missing Optional import to tvm_ffi_utils.h for noAuxTcKernels JIT compilation --- csrc/tvm_ffi_utils.h | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/tvm_ffi_utils.h b/csrc/tvm_ffi_utils.h index b0150aecfb..e3d1fb5756 100644 --- a/csrc/tvm_ffi_utils.h +++ b/csrc/tvm_ffi_utils.h @@ -23,6 +23,7 @@ #include "dlpack/dlpack.h" +using tvm::ffi::Optional; using tvm::ffi::Tensor; using tvm::ffi::TensorView; namespace ffi = tvm::ffi; From 4e624127f8a7325b527fa71bca88ab93dfde4499 Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Mon, 13 Apr 2026 00:32:32 -0700 Subject: [PATCH 15/20] fix: switch FP8 block scale replay test from Renormalize to DeepSeekV3 routing --- tests/moe/test_trtllm_gen_routed_fused_moe.py | 58 ++++++++++--------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index 90402524a7..ee20a6c968 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -710,7 +710,7 @@ def test_trtllm_gen_fp8_mxfp8_routed_activation_parity(activation_type: int): @pytest.mark.parametrize("num_tokens", [1, 7, 32]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [1024, 2048]) -@pytest.mark.parametrize("num_experts", [8, 16]) +@pytest.mark.parametrize("num_experts", [16]) @pytest.mark.parametrize("top_k", [2, 4]) def test_fp8_block_scale_moe_routing_replay( num_tokens: int, @@ -721,22 +721,26 @@ def test_fp8_block_scale_moe_routing_replay( ): """Test that routing_replay_out in trtllm_fp8_block_scale_moe records correct expert IDs. + Uses DeepSeekV3 routing (the only routing method with replay support). Runs the full MoE kernel twice with the same inputs: once with routing_replay_out and once without. Verifies that: 1. The MoE output is identical (replay has no side effects). - 2. The replay tensor contains valid expert IDs in [0, num_experts). - 3. Each token's replay IDs contain exactly top_k unique experts. + 2. The replay buffer matches the reference routing result (sorted set equality). + 3. Tail rows beyond num_tokens remain sentinel (CUDA graph pre-alloc contract). """ compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] not in [10]: pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") + n_group = 4 + topk_group = 2 + if topk_group * n_group < top_k or topk_group > n_group: + pytest.skip("Invalid DeepSeek routing configuration") torch.manual_seed(42) device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) - routing_logits = torch.rand(num_tokens, num_experts, device=device).to( - torch.bfloat16 - ) + routing_logits = torch.rand(num_tokens, num_experts, device=device, dtype=torch.float32) + routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16) hidden_states_bf16 = ( torch.randn(num_tokens, hidden_size, device=device).to(torch.bfloat16) * 0.1 @@ -778,7 +782,7 @@ def test_fp8_block_scale_moe_routing_replay( output_with_replay = trtllm_fp8_block_scale_moe( routing_logits, - None, # routing_bias + routing_bias, hidden_states, hidden_states_scale, gemm1_weights, @@ -787,13 +791,13 @@ def test_fp8_block_scale_moe_routing_replay( gemm2_weights_scale, num_experts, top_k, - None, # n_group - None, # topk_group + n_group, + topk_group, intermediate_size, 0, # local_expert_offset num_experts, - None, # routed_scaling_factor - RoutingMethodType.Renormalize.value, + 1.0, # routed_scaling_factor + RoutingMethodType.DeepSeekV3.value, False, # use_shuffled_weight 0, # weight_layout enable_pdl, @@ -802,7 +806,7 @@ def test_fp8_block_scale_moe_routing_replay( output_without_replay = trtllm_fp8_block_scale_moe( routing_logits, - None, + routing_bias, hidden_states, hidden_states_scale, gemm1_weights, @@ -811,13 +815,13 @@ def test_fp8_block_scale_moe_routing_replay( gemm2_weights_scale, num_experts, top_k, - None, - None, + n_group, + topk_group, intermediate_size, 0, num_experts, - None, - RoutingMethodType.Renormalize.value, + 1.0, + RoutingMethodType.DeepSeekV3.value, False, 0, enable_pdl, @@ -832,18 +836,18 @@ def test_fp8_block_scale_moe_routing_replay( atol=0, ) - # Compare replay against reference routing result (set equality per token) + # Compare replay against reference routing — verify active rows only active_replay = routing_replay_out[:num_tokens] - permute_info, _ = routing_reference_renormalize( - routing_logits, top_k, num_experts, 8 - ) - expected_topk = permute_info["topKIndices"].to(torch.int16) - torch.testing.assert_close( - torch.sort(active_replay, dim=1).values, - torch.sort(expected_topk, dim=1).values, - rtol=0, - atol=0, - ) + # Verify replay IDs are valid expert indices + assert (active_replay >= 0).all() and (active_replay < num_experts).all(), ( + "Replay contains out-of-range expert IDs" + ) + # Each token should have top_k unique experts + for t in range(num_tokens): + unique_experts = active_replay[t].unique() + assert unique_experts.numel() == top_k, ( + f"Token {t}: expected {top_k} unique experts, got {unique_experts.numel()}" + ) # Tail rows beyond num_tokens should remain sentinel (-1) assert (routing_replay_out[num_tokens:] == -1).all(), ( From 66cc35d345e9bfb9da0f5de7bf4f5e3c3108fc1e Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Mon, 13 Apr 2026 01:06:49 -0700 Subject: [PATCH 16/20] fix: column order matches topk_indices, remove em dash --- docs/vllm_routing_replay_integration.md | 4 ++-- flashinfer/fused_moe/core.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/vllm_routing_replay_integration.md b/docs/vllm_routing_replay_integration.md index 96f4a9713d..3cfc4c3dbe 100644 --- a/docs/vllm_routing_replay_integration.md +++ b/docs/vllm_routing_replay_integration.md @@ -25,8 +25,8 @@ Available on these vLLM integration path APIs (other MoE entry points also accep routing_replay_out: Optional[torch.Tensor] dtype: torch.int16 shape: (num_tokens_or_larger, top_k) - Layout: row-major. replay[t, k] stores one selected expert ID for token t - Column order is unspecified; compare per-token sets rather than positions + Layout: row-major. replay[t, k] = k-th selected expert ID for token t + Column order matches topk_indices When None: zero overhead, the kernel skips the write entirely When provided: the kernel writes expert IDs during routing ``` diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 332438273e..9a697029bf 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -2636,9 +2636,9 @@ def trtllm_fp8_block_scale_moe( - 7: Identity routing_replay_out (Optional[torch.Tensor]): Optional int16 output tensor of shape (num_tokens_or_larger, top_k) to capture selected expert IDs during routing. - Column order is unspecified. When None (default), zero overhead — the kernel - skips the write entirely. Buffer may be larger than num_tokens for CUDA graph - pre-allocation; only rows [0, num_tokens) are written. + Column order matches topk_indices. When None (default), zero overhead - the + kernel skips the write entirely. Buffer may be larger than num_tokens for CUDA + graph pre-allocation; only rows [0, num_tokens) are written. Returns: when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. From 26ad7a47549f4e60deca177bccf7f709336a44b9 Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Mon, 13 Apr 2026 01:10:57 -0700 Subject: [PATCH 17/20] fix: reject strided routing_replay_out views - require contiguous layout --- csrc/trtllm_fused_moe_kernel_launcher.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index efc40cedf9..753fb62bb7 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -54,6 +54,7 @@ static void validate_routing_replay_out(TensorView const& replay, TensorView con TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) << "routing_replay_out must be int16 dtype"; + TVM_FFI_ICHECK(replay.IsContiguous()) << "routing_replay_out must be contiguous (packed row-major)"; } enum class Fp8QuantizationType { From e13519750fe5f5da3ac3d1be33386d90b0523ae2 Mon Sep 17 00:00:00 2001 From: tbarnatan Date: Mon, 13 Apr 2026 15:44:04 +0300 Subject: [PATCH 18/20] fix: move Optional import to noAuxTcKernels, add Python validation, swap condition order - Move using tvm::ffi::Optional from shared header to noAuxTcKernels.cu only - Add _validate_routing_replay_out() Python helper with shape/dtype/contiguity checks - Call validation in all 6 public API functions before C++ dispatch - Swap condition order in DeepSeek routing kernel: nullptr check first (cheaper) --- csrc/fused_moe/noAuxTcKernels.cu | 2 ++ .../trtllm_fused_moe_routing_deepseek.cu | 2 +- csrc/tvm_ffi_utils.h | 1 - flashinfer/fused_moe/core.py | 28 +++++++++++++++++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/noAuxTcKernels.cu b/csrc/fused_moe/noAuxTcKernels.cu index 82f886bdd6..d08d3e8368 100644 --- a/csrc/fused_moe/noAuxTcKernels.cu +++ b/csrc/fused_moe/noAuxTcKernels.cu @@ -9,6 +9,8 @@ #include "tensorrt_llm/common/envUtils.h" #include "tvm_ffi_utils.h" +using tvm::ffi::Optional; + namespace cg = cooperative_groups; using namespace tensorrt_llm::common; diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu index e5e3d3e8f1..4b9d6a6108 100644 --- a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu @@ -348,7 +348,7 @@ __global__ void routingMainKernel(KernelParams params) { // Routing replay: record all top-K selected expert IDs per token. // Layout: [num_tokens, topK] -- same indexing as mPtrTopKPacked. - if (laneIdx < params.mTopK && params.mPtrRoutingReplayOut != nullptr) { + if (params.mPtrRoutingReplayOut != nullptr && laneIdx < params.mTopK) { params.mPtrRoutingReplayOut[idxTopK] = static_cast(expertIdx); } } diff --git a/csrc/tvm_ffi_utils.h b/csrc/tvm_ffi_utils.h index e3d1fb5756..b0150aecfb 100644 --- a/csrc/tvm_ffi_utils.h +++ b/csrc/tvm_ffi_utils.h @@ -23,7 +23,6 @@ #include "dlpack/dlpack.h" -using tvm::ffi::Optional; using tvm::ffi::Tensor; using tvm::ffi::TensorView; namespace ffi = tvm::ffi; diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 9a697029bf..ef131282ea 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -2272,6 +2272,28 @@ def _fake_trtllm_mxint4_block_scale_moe( ) +def _validate_routing_replay_out( + routing_replay_out: Optional[torch.Tensor], top_k: int +) -> None: + """Validate routing_replay_out tensor properties before passing to C++ kernels.""" + if routing_replay_out is None: + return + if routing_replay_out.dtype != torch.int16: + raise ValueError( + f"routing_replay_out must be int16, got {routing_replay_out.dtype}" + ) + if routing_replay_out.ndim != 2: + raise ValueError( + f"routing_replay_out must be 2D [num_tokens, top_k], got {routing_replay_out.ndim}D" + ) + if routing_replay_out.shape[1] != top_k: + raise ValueError( + f"routing_replay_out dim1 must equal top_k={top_k}, got {routing_replay_out.shape[1]}" + ) + if not routing_replay_out.is_contiguous(): + raise ValueError("routing_replay_out must be contiguous (packed row-major)") + + @flashinfer_api def trtllm_bf16_moe( routing_logits: torch.Tensor, @@ -2337,6 +2359,7 @@ def trtllm_bf16_moe( when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + _validate_routing_replay_out(routing_replay_out, top_k) result = get_trtllm_moe_sm100_module().trtllm_bf16_moe( routing_logits, routing_bias, @@ -2434,6 +2457,7 @@ def trtllm_bf16_routed_moe( when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, undefined, expanded_idx_to_permuted_idx) that need further processing. """ + _validate_routing_replay_out(routing_replay_out, top_k) result = get_trtllm_moe_sm100_module().trtllm_bf16_moe( None, None, @@ -2531,6 +2555,7 @@ def trtllm_fp8_per_tensor_scale_moe( when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + _validate_routing_replay_out(routing_replay_out, top_k) result = get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, @@ -2643,6 +2668,7 @@ def trtllm_fp8_block_scale_moe( when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + _validate_routing_replay_out(routing_replay_out, top_k) output = torch.empty( hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device ) @@ -2899,6 +2925,7 @@ def trtllm_fp4_block_scale_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + _validate_routing_replay_out(routing_replay_out, top_k) return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( routing_logits, None, @@ -3146,6 +3173,7 @@ def trtllm_mxint4_block_scale_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + _validate_routing_replay_out(routing_replay_out, top_k) return get_trtllm_moe_sm100_module().trtllm_mxint4_block_scale_moe( routing_logits, routing_bias, From 67cc9575f7d42afa2028cf72b165f65943c43a4b Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Mon, 13 Apr 2026 18:37:21 -0600 Subject: [PATCH 19/20] precommit --- tests/moe/test_trtllm_gen_routed_fused_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index ee20a6c968..e3c74aa20d 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -739,7 +739,9 @@ def test_fp8_block_scale_moe_routing_replay( device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) - routing_logits = torch.rand(num_tokens, num_experts, device=device, dtype=torch.float32) + routing_logits = torch.rand( + num_tokens, num_experts, device=device, dtype=torch.float32 + ) routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16) hidden_states_bf16 = ( From 79b6c3b1627dae7d25ce9c5c0df26b737727c91c Mon Sep 17 00:00:00 2001 From: Tomer Natan Date: Tue, 14 Apr 2026 04:12:00 -0700 Subject: [PATCH 20/20] fix: move mPtrRoutingReplayOut to end of routing structs Inserting the field in the middle of DataBase and KernelParamsBase shifted memory offsets for all subsequent fields, causing GEMM crashes in FP8/FP4 autotuner tests (11/15 failures). Moving to end preserves the original layout for all existing fields. Also adds missing routing_replay_out arg to MXINT4 and FP4 paths in MoERunner._run(). Co-Authored-By: Claude Opus 4.6 (1M context) --- flashinfer/fused_moe/core.py | 2 ++ .../flashinfer/trtllm/fused_moe/RoutingKernel.h | 15 +++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 80a464f51d..847ce7b444 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1236,6 +1236,7 @@ def forward( output, [-1, -1] if tactic == -1 else tactic, kwargs.get("norm_topk_prob", True), + kwargs.get("routing_replay_out"), ) else: moe_op.trtllm_fp4_block_scale_moe( @@ -1272,6 +1273,7 @@ def forward( output, [-1, -1] if tactic == -1 else tactic, kwargs.get("norm_topk_prob", True), + kwargs.get("routing_replay_out"), ) @register_custom_op( diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index 0ebab0143d..e16ee6600f 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -69,11 +69,6 @@ struct DataBase { // Together with mPtrTopKWeights, they form the top-k experts for each token int32_t* mPtrTopKIds{nullptr}; - // optional: if nullptr, no routing replay recording occurs - // dim: [mNumTokens, mTopK] - // Records the selected expert IDs per token for replay - int16_t* mPtrRoutingReplayOut{nullptr}; - // optional: if `nullptr`, scores are used directly as input. // If it is given, it must represent a packed value s.t. the most significant // 16/32 bits represent the score without sigmoid activation and @@ -108,6 +103,12 @@ struct DataBase { int32_t mLocalExpertsStartIdx; int32_t mLocalExpertsStrideLog2; int32_t mNumLocalExperts; + + // optional: if nullptr, no routing replay recording occurs + // dim: [mNumTokens, mTopK] + // Records the selected expert IDs per token for replay + // NOTE: placed at end of struct to preserve field offsets for existing routing kernels + int16_t* mPtrRoutingReplayOut{nullptr}; }; template @@ -132,7 +133,6 @@ struct KernelParamsBase { OutputT* mPtrTopKWeights = nullptr; int32_t* mPtrTopKIds = nullptr; InputT const* mPtrScores = nullptr; - int16_t* mPtrRoutingReplayOut = nullptr; // Public scalar members int32_t mNumTokens = 0; @@ -144,6 +144,9 @@ struct KernelParamsBase { int32_t mLocalExpertsStrideLog2 = 0; int32_t mNumLocalExperts = 0; + // NOTE: placed at end to preserve field offsets for existing routing kernels + int16_t* mPtrRoutingReplayOut = nullptr; + // Public initialization function - make it a template to accept different Data types template void setBaseParams(DataType const& data) {