diff --git a/csrc/fused_moe/noAuxTcKernels.cu b/csrc/fused_moe/noAuxTcKernels.cu index 1f57d9b57b..d08d3e8368 100644 --- a/csrc/fused_moe/noAuxTcKernels.cu +++ b/csrc/fused_moe/noAuxTcKernels.cu @@ -9,6 +9,8 @@ #include "tensorrt_llm/common/envUtils.h" #include "tvm_ffi_utils.h" +using tvm::ffi::Optional; + namespace cg = cooperative_groups; using namespace tensorrt_llm::common; @@ -30,7 +32,8 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx int64_t const numGroup, int64_t const topkGroup, int64_t const topk, int64_t const numExperts, int64_t const numExpertsPerGroup, - double const routedScalingFactor) { + double const routedScalingFactor, + int16_t* routingReplayOut) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif @@ -212,6 +215,10 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx if (laneIdx < topk) { topkValues[laneIdx] = static_cast(finalScore); topkIndices[laneIdx] = expertIdx; + // Routing replay: record selected expert IDs per token + if (routingReplayOut != nullptr) { + routingReplayOut[blockIdx.x * topk + laneIdx] = static_cast(expertIdx); + } } } @@ -224,7 +231,8 @@ template void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices, int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, - bool const launch_with_pdl, cudaStream_t const stream) { + bool const launch_with_pdl, cudaStream_t const stream, + int16_t* routing_replay_out) { // Check if we can use the optimized deepseek_v3_topk_kernel bool const is_single_group = (n_group == 1) && (num_experts <= NumKimiK2Experts); @@ -262,7 +270,7 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, topk_indices, bias, num_tokens, n_group, topk_group, topk, 
num_experts, num_experts / n_group, - routed_scaling_factor); + routed_scaling_factor, routing_replay_out); sync_check_cuda_error(stream); } else { // TODO: call the generic path (previous implementation) or signal unsupported config. @@ -279,7 +287,7 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk InputT * scores, BiasT * bias, OutputT * topk_values, IdxT * topk_indices, \ int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, \ int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, \ - bool const launch_with_pdl, cudaStream_t const stream); + bool const launch_with_pdl, cudaStream_t const stream, int16_t* routing_replay_out); INSTANTIATE_NOAUX_TC(float, float, float, int32_t); INSTANTIATE_NOAUX_TC(float, half, float, int32_t); @@ -305,7 +313,7 @@ namespace flashinfer::trtllm_dsv3_fused_routing { void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_group, int64_t topk, double routed_scaling_factor, TensorView topk_values, TensorView topk_indices, - bool launch_with_pdl) { + bool launch_with_pdl, Optional routing_replay_out) { auto data_type = scores.dtype(); auto bias_type = bias.dtype(); @@ -342,6 +350,26 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code) << "topk_indices must have the same dtype as scores"; + // Validate and extract routing_replay_out + // NOTE: dim0 >= num_tokens is intentionally NOT checked — with CUDA graphs the buffer + // is pre-allocated at maximum batch size and reused across steps with varying num_tokens. + // The kernel only writes to indices [0, num_tokens), so a larger buffer is safe. 
+ constexpr int64_t int16_code_val = encode_dlpack_dtype(DLDataType{kDLInt, 16, 1}); + int16_t* replay_ptr = nullptr; + if (routing_replay_out.has_value()) { + auto replay = routing_replay_out.value(); + TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) + << "routing_replay_out must be a CUDA tensor"; + TVM_FFI_ICHECK(replay.device().device_id == scores.device().device_id) + << "routing_replay_out must be on the same device as scores"; + TVM_FFI_ICHECK(replay.ndim() == 2) + << "routing_replay_out must be a 2D Tensor [num_tokens, topk]"; + TVM_FFI_ICHECK(replay.sizes()[1] == topk) << "routing_replay_out dim1 must equal topk"; + TVM_FFI_ICHECK(encode_dlpack_dtype(replay.dtype()) == int16_code_val) + << "routing_replay_out must be int16 dtype"; + replay_ptr = reinterpret_cast(replay.data_ptr()); + } + auto stream = get_stream(scores.device()); using namespace tensorrt_llm::kernels; switch (encode_dlpack_dtype(data_type)) { @@ -353,14 +381,14 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast(scores.data_ptr()), reinterpret_cast(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case float32_code: invokeNoAuxTc( reinterpret_cast(scores.data_ptr()), reinterpret_cast(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case bfloat16_code: invokeNoAuxTc( @@ -368,7 +396,7 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()), 
reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; default: throw std::invalid_argument( @@ -384,14 +412,14 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case float16_code: invokeNoAuxTc( reinterpret_cast(scores.data_ptr()), reinterpret_cast(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case bfloat16_code: invokeNoAuxTc( @@ -399,7 +427,7 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()), reinterpret_cast(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; default: throw std::invalid_argument( @@ -416,7 +444,7 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()), reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, 
routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case float16_code: invokeNoAuxTc<__nv_bfloat16, half, __nv_bfloat16, int32_t>( @@ -424,7 +452,7 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast(bias.data_ptr()), reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; case float32_code: invokeNoAuxTc<__nv_bfloat16, float, __nv_bfloat16, int32_t>( @@ -432,7 +460,7 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g reinterpret_cast(bias.data_ptr()), reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()), reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, - topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr); break; default: throw std::invalid_argument( diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu index 6174804d53..316c819bcf 100644 --- a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu @@ -345,6 +345,12 @@ __global__ void routingMainKernel(KernelParams params) { params.mPtrTopKIds == nullptr) { params.mPtrTopKWeights[idxTopK] = finalScore; } + + // Routing replay: record all top-K selected expert IDs per token. + // Layout: [num_tokens, topK] -- same indexing as mPtrTopKPacked. 
+ if (params.mPtrRoutingReplayOut != nullptr && laneIdx < params.mTopK) { + params.mPtrRoutingReplayOut[idxTopK] = static_cast(expertIdx); + } } } diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index c396855138..afb74c1263 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -41,6 +41,23 @@ using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; +// Validate routing_replay_out tensor properties. +// NOTE: dim0 >= num_tokens is intentionally NOT checked — with CUDA graphs the buffer +// is pre-allocated at maximum batch size and reused across steps with varying num_tokens. +static void validate_routing_replay_out(TensorView const& replay, TensorView const& hidden_states, + int64_t top_k) { + TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) + << "routing_replay_out must be a CUDA tensor"; + TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) + << "routing_replay_out must be on the same device as hidden_states"; + TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; + TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; + TVM_FFI_ICHECK((replay.dtype() == DLDataType{kDLInt, 16, 1})) + << "routing_replay_out must be int16 dtype"; + TVM_FFI_ICHECK(replay.IsContiguous()) + << "routing_replay_out must be contiguous (packed row-major)"; +} + enum class Fp8QuantizationType { NoneFp8, DeepSeekFp8, @@ -186,6 +203,9 @@ class FusedMoeLauncher { ActivationType activation_type{ActivationType::Swiglu}; btg::Dtype mDtypeScore{btg::Dtype::Bfloat16}; + // Optional routing replay output: [num_tokens, top_k] int16 tensor + Optional routing_replay_out; + int64_t intermediate_size_factor{2}; public: @@ -214,6 +234,11 @@ class FusedMoeLauncher { activation_type{ActivationType::Swiglu}, intermediate_size_factor{2} {} + 
public: + void set_routing_replay_out(const Optional& replay_out) { + routing_replay_out = replay_out; + } + protected: // Initialize common data necessary for later. // May throw exception from TVM_FFI_ICHECK. @@ -452,6 +477,11 @@ class FusedMoeLauncher { tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); cudaStream_t routing_stream = get_stream(hidden_states.device()); + int16_t* replay_ptr = nullptr; + if (routing_replay_out.has_value()) { + replay_ptr = reinterpret_cast(routing_replay_out.value().data_ptr()); + } + routing_runner.run( args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, @@ -467,7 +497,7 @@ class FusedMoeLauncher { static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream, mRoutingLogitsDtype, - norm_topk_prob); + norm_topk_prob, replay_ptr); check_moe(); prepare_moe(moe_tactic); @@ -1183,6 +1213,11 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { bool use_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; // When using pre-computed routing, pass nullptr as routing_logits to tell the // routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes + int16_t* replay_ptr = nullptr; + if (routing_replay_out.has_value()) { + replay_ptr = reinterpret_cast(routing_replay_out.value().data_ptr()); + } + routing_runner.run( use_precomputed ? 
nullptr : args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, @@ -1198,7 +1233,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream, mRoutingLogitsDtype, - norm_topk_prob); + norm_topk_prob, replay_ptr); check_moe(); prepare_moe(moe_tactic); @@ -1665,6 +1700,11 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); cudaStream_t routing_stream = get_stream(hidden_states.device()); + int16_t* replay_ptr = nullptr; + if (routing_replay_out.has_value()) { + replay_ptr = reinterpret_cast(routing_replay_out.value().data_ptr()); + } + routing_runner.run( args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, @@ -1680,7 +1720,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream, mRoutingLogitsDtype, - norm_topk_prob); + norm_topk_prob, replay_ptr); check_moe(); prepare_moe(moe_tactic); @@ -1725,15 +1765,18 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { } }; -Array trtllm_bf16_moe( - Optional const& routing_logits, Optional const& routing_bias, - TensorView const& expert_indices, TensorView const& expert_weights, - TensorView const& hidden_states, TensorView const& gemm1_weights, - TensorView const& gemm2_weights, TensorView output, int64_t num_experts, int64_t top_k, - Optional n_group, Optional topk_group, int64_t intermediate_size, - int64_t local_expert_offset, int64_t 
local_num_experts, Optional routed_scaling_factor, - int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, - bool enable_pdl, Array moe_tactic, int64_t activation_type, bool norm_topk_prob) { +Array trtllm_bf16_moe(Optional const& routing_logits, + Optional const& routing_bias, + TensorView const& expert_indices, TensorView const& expert_weights, + TensorView const& hidden_states, TensorView const& gemm1_weights, + TensorView const& gemm2_weights, TensorView output, + int64_t num_experts, int64_t top_k, Optional n_group, + Optional topk_group, int64_t intermediate_size, + int64_t local_expert_offset, int64_t local_num_experts, + Optional routed_scaling_factor, int64_t routing_method_type, + bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, + bool enable_pdl, Array moe_tactic, int64_t activation_type, + bool norm_topk_prob, Optional routing_replay_out) { // Just some basic type validation first and leave more checks to the launcher if (routing_logits.has_value()) { TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || @@ -1747,6 +1790,10 @@ Array trtllm_bf16_moe( TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_bfloat16) << "BF16 MoE: gemm2_weights must be bfloat16."; + if (routing_replay_out.has_value()) { + validate_routing_replay_out(routing_replay_out.value(), hidden_states, top_k); + } + auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); auto const activation = validateAndCastActivationType(activation_type); @@ -1785,6 +1832,7 @@ Array trtllm_bf16_moe( gemm2_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, weight_layout, activation, norm_topk_prob); + launcher->set_routing_replay_out(routing_replay_out); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1810,7 +1858,8 @@ Array trtllm_fp8_per_tensor_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, 
int64_t local_num_experts, Optional routed_scaling_factor, bool use_routing_scales_on_input, int64_t routing_method_type, bool do_finalize, - bool enable_pdl, Array config_index, int64_t activation_type, bool norm_topk_prob) { + bool enable_pdl, Array config_index, int64_t activation_type, bool norm_topk_prob, + Optional routing_replay_out) { // Basic type validation auto dtype = hidden_states.dtype(); auto activation = validateAndCastActivationType(activation_type); @@ -1828,6 +1877,10 @@ Array trtllm_fp8_per_tensor_scale_moe( TVM_FFI_ICHECK_EQ(output2_scales_scalar.dtype(), dl_float32) << "FP8 MoE: output2_scales_scalar must be float32."; + if (routing_replay_out.has_value()) { + validate_routing_replay_out(routing_replay_out.value(), hidden_states, top_k); + } + auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); @@ -1867,6 +1920,7 @@ Array trtllm_fp8_per_tensor_scale_moe( output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, weight_layout, use_routing_scales_on_input, activation, norm_topk_prob); + launcher->set_routing_replay_out(routing_replay_out); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1893,7 +1947,7 @@ Array trtllm_fp8_block_scale_moe( int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, bool enable_pdl, Array config_index, Fp8QuantizationType quantization_type, - int64_t act_type, bool norm_topk_prob) { + int64_t act_type, bool norm_topk_prob, Optional routing_replay_out) { auto activation_type = validateAndCastActivationType(act_type); // DeepSeekFp8 currently uses a TRTLLM runner that hardwires Swiglu activation semantics. // Fail for any other activation to avoid silently running incorrect activation behavior. 
@@ -1953,6 +2007,10 @@ Array trtllm_fp8_block_scale_moe( TVM_FFI_ICHECK(weight_layout == 0) << "weight_layout must be 0 for MxFp8."; } + if (routing_replay_out.has_value()) { + validate_routing_replay_out(routing_replay_out.value(), hidden_states, top_k); + } + auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); @@ -1987,6 +2045,7 @@ Array trtllm_fp8_block_scale_moe( quantization_type); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, weight_layout, activation_type, norm_topk_prob); + launcher->set_routing_replay_out(routing_replay_out); launchers_map[curr_tile_N] = std::move(launcher); } @@ -2019,7 +2078,8 @@ Array trtllm_fp4_block_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t act_type, - TensorView output, Array config_index, bool norm_topk_prob) { + TensorView output, Array config_index, bool norm_topk_prob, + Optional routing_replay_out) { // Determine data types based on input format int const num_tokens = hidden_states.size(0); int hidden_size = hidden_states.size(1); @@ -2061,6 +2121,10 @@ Array trtllm_fp4_block_scale_moe( << "routing_bias has incorrect shape."; } + if (routing_replay_out.has_value()) { + validate_routing_replay_out(routing_replay_out.value(), hidden_states, top_k); + } + // Determine activation type TVM_FFI_ICHECK(gemm1_weights.dtype() == dl_uint8 && gemm2_weights.dtype() == dl_uint8) << "weights must be fp4 packed in uint8."; @@ -2127,6 +2191,7 @@ Array trtllm_fp4_block_scale_moe( launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, /*weight_layout=*/0, static_cast(act_type), mDtypeAct, mDtypeWeights, norm_topk_prob); + launcher->set_routing_replay_out(routing_replay_out); launchers_map[curr_tile_N] = std::move(launcher); } @@ 
-2152,7 +2217,7 @@ Array trtllm_mxint4_block_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool do_finalize, bool enable_pdl, TensorView output, - Array config_index, bool norm_topk_prob) { + Array config_index, bool norm_topk_prob, Optional routing_replay_out) { // Determine data types based on input format int const num_tokens = hidden_states.size(0); int hidden_size = hidden_states.size(1); @@ -2174,6 +2239,10 @@ Array trtllm_mxint4_block_scale_moe( << "routing_bias has incorrect shape."; } + if (routing_replay_out.has_value()) { + validate_routing_replay_out(routing_replay_out.value(), hidden_states, top_k); + } + // Determine activation type TVM_FFI_ICHECK(gemm1_weights.dtype() == dl_uint8 && gemm2_weights.dtype() == dl_uint8) << "weights must be int4 packed in uint8."; @@ -2211,6 +2280,7 @@ Array trtllm_mxint4_block_scale_moe( routing_logits, routing_bias, hidden_states, gemm1_weights, gemm1_weights_scale, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale); launcher->init(std::move(args), curr_tile_N, routing_method_type, norm_topk_prob); + launcher->set_routing_replay_out(routing_replay_out); launchers_map[curr_tile_N] = std::move(launcher); } diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index e08e78fe10..cf0fe97994 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -59,7 +59,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream, btg::Dtype dtypeLogits, - bool normTopkProb) { + bool normTopkProb, int16_t* routing_replay_out) { if (routingMethodType == RoutingMethodType::DeepSeekV3 && nGroup <= 1) { // DeepSeek 
no-groups case: use routingCustom with SigmoidBias preprocess // and ScaledSumNormalize postprocess. This is more efficient than the full DeepSeek @@ -96,6 +96,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; + routingData.mPtrRoutingReplayOut = routing_replay_out; moe::dev::routing::routingCustom::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::MiniMax2) { @@ -135,6 +136,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; + routingData.mPtrRoutingReplayOut = routing_replay_out; moe::dev::routing::routingCustom::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::DeepSeekV3) { @@ -175,6 +177,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumLocalExperts = localNumExperts; routingData.mRouteScale = routedScalingFactor; routingData.mUseRoutingSoftmax = false; + routingData.mPtrRoutingReplayOut = routing_replay_out; moe::dev::routing::routingDeepSeek::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::Llama4) { FLASHINFER_CHECK(topK == 1, "For Llama routing method, must have topK == 1"); @@ -210,6 +213,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; + routingData.mPtrRoutingReplayOut = routing_replay_out; moe::dev::routing::routingLlama4::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::Default /* Softmax -> TopK */ || routingMethodType == RoutingMethodType::Renormalize /* TopK -> 
Softmax */ @@ -285,6 +289,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; + routingData.mPtrRoutingReplayOut = routing_replay_out; routingCustom::run(routingData, stream); } else { diff --git a/docs/vllm_routing_replay_integration.md b/docs/vllm_routing_replay_integration.md new file mode 100644 index 0000000000..3cfc4c3dbe --- /dev/null +++ b/docs/vllm_routing_replay_integration.md @@ -0,0 +1,83 @@ +# vLLM Routing Replay Integration Guide + +## Overview + +FlashInfer supports an optional `routing_replay_out` parameter on its **trtllm-gen backend** MoE +kernel functions (not the Triton MoE path). When provided, the CUDA routing kernel writes all +top-K selected expert IDs per token directly into this tensor during routing — inside the same +fused kernel call that computes the MoE output. + +This enables **routing replay** for downstream RL training: vLLM captures which experts were +selected for each token during inference and returns them in the API response. + +## API + +### `routing_replay_out` Parameter + +Available on these vLLM integration path APIs (other MoE entry points also accept this parameter): +- `trtllm_fp8_block_scale_moe()` +- `trtllm_bf16_moe()` +- `trtllm_bf16_routed_moe()` +- `fused_topk_deepseek()` + +**Spec:** +```text +routing_replay_out: Optional[torch.Tensor] + dtype: torch.int16 + shape: (num_tokens_or_larger, top_k) + Layout: row-major. replay[t, k] = k-th selected expert ID for token t + Column order matches topk_indices + When None: zero overhead, the kernel skips the write entirely + When provided: the kernel writes expert IDs during routing +``` + +### CUDA Graph Compatibility + +The buffer may be **larger** than `num_tokens`. This is intentional: vLLM pre-allocates +the buffer at `max_num_batched_tokens` and reuses it across CUDA graph replays. 
The kernel +determines write extent from `routing_logits.shape[0]`, not from `routing_replay_out.shape[0]`. + +There is no `dim0` validation at all — the launcher checks only `dim1 == top_k`, the int16 dtype, +the device, and contiguity; callers must ensure the buffer has at least `num_tokens` rows. + +### Memory Layout for vLLM Integration + +vLLM uses a device buffer with shape `(num_layers, max_num_batched_tokens, top_k)`: +- `buffer[layer_id]` gives a contiguous `(max_num_batched_tokens, top_k)` view +- This view is passed as `routing_replay_out` to the FlashInfer kernel +- The `(L, N, K)` layout ensures zero-copy per-layer slicing + +### Integration Pattern + +```python +# Pre-allocate once (during model initialization) +device_buffer = torch.zeros( + (num_layers, max_num_batched_tokens, top_k), + dtype=torch.int16, + device="cuda", +) + +# Per-layer forward pass +for layer_id, moe_layer in enumerate(moe_layers): + replay_slice = device_buffer[layer_id] # contiguous (N, K) view + output = trtllm_fp8_block_scale_moe( + ..., + routing_replay_out=replay_slice, + ) +``` + +### Validation + +```python +import torch + +# Allocate replay buffer +replay = torch.full((num_tokens, top_k), -1, device="cuda", dtype=torch.int16) + +# Run MoE +output = trtllm_fp8_block_scale_moe(..., routing_replay_out=replay) + +# Verify non-zero (not all -1 sentinel) +assert (replay != -1).any(), "Routing replay data is all sentinel values" +assert (replay >= 0).all() and (replay < num_experts).all(), "Invalid expert IDs" +``` diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index b164ce0ae7..847ce7b444 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1116,6 +1116,7 @@ def forward( [-1, -1] if tactic == -1 else tactic, self.activation_type, kwargs.get("norm_topk_prob", True), + kwargs.get("routing_replay_out"), ) elif ( self.dtype_act == DtypeTrtllmGen.E4m3 @@ -1176,6 +1177,7 @@ def forward( self.fp8_quantization_type, self.activation_type, kwargs.get("norm_topk_prob", True), + kwargs.get("routing_replay_out"), ) else: # 
FP8 per tensor scale @@ -1234,6 +1236,7 @@ def forward( output, [-1, -1] if tactic == -1 else tactic, kwargs.get("norm_topk_prob", True), + kwargs.get("routing_replay_out"), ) else: moe_op.trtllm_fp4_block_scale_moe( @@ -1270,11 +1273,12 @@ def forward( output, [-1, -1] if tactic == -1 else tactic, kwargs.get("norm_topk_prob", True), + kwargs.get("routing_replay_out"), ) @register_custom_op( "flashinfer::trtllm_bf16_moe", - mutates_args=(""), + mutates_args=("routing_replay_out",), ) def trtllm_bf16_moe_op( routing_logits: Optional[torch.Tensor], @@ -1300,6 +1304,7 @@ def trtllm_bf16_moe_op( tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: assert routing_logits is not None or topk_ids is not None, ( "either routing_logits or topk_ids must be provided" @@ -1413,6 +1418,7 @@ def trtllm_bf16_moe_op( [-1, -1] if tactic == -1 else tactic, activation_type, norm_topk_prob, + routing_replay_out, ) if do_finalize: return [output] @@ -1448,7 +1454,9 @@ def _fake_trtllm_bf16_moe( tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: + _ = routing_replay_out seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -1456,7 +1464,7 @@ def _fake_trtllm_bf16_moe( @register_custom_op( "flashinfer::trtllm_fp8_per_tensor_scale_moe", - mutates_args=(""), + mutates_args=("routing_replay_out",), ) def trtllm_fp8_per_tensor_scale_moe_op( routing_logits: torch.Tensor, @@ -1482,6 +1490,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) @@ 
-1617,7 +1626,9 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ): + _ = routing_replay_out seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -1625,7 +1636,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( @register_custom_op( "flashinfer::trtllm_fp8_block_scale_moe", - mutates_args=(""), + mutates_args=("routing_replay_out",), ) def trtllm_fp8_block_scale_moe_op( routing_logits: Optional[torch.Tensor], @@ -1656,6 +1667,7 @@ def trtllm_fp8_block_scale_moe_op( fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: # Determine routing mode: compute from logits or use pre-computed if routing_logits is None: @@ -1800,6 +1812,7 @@ def trtllm_fp8_block_scale_moe_op( fp8_quantization_type, activation_type, norm_topk_prob, + routing_replay_out, ) if do_finalize: @@ -1845,7 +1858,9 @@ def _fake_trtllm_fp8_block_scale_moe( fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: + _ = routing_replay_out seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -1853,7 +1868,7 @@ def _fake_trtllm_fp8_block_scale_moe( @register_custom_op( "flashinfer::trtllm_fp4_block_scale_moe", - mutates_args=(""), + mutates_args=("routing_replay_out",), ) def trtllm_fp4_block_scale_moe_op( routing_logits: Optional[torch.Tensor], @@ -1889,6 +1904,7 @@ def trtllm_fp4_block_scale_moe_op( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) 
-> List[torch.Tensor]: if routing_logits is None: assert topk_ids is not None, ( @@ -2028,6 +2044,7 @@ def trtllm_fp4_block_scale_moe_op( output, [-1, -1] if tactic == -1 else tactic, norm_topk_prob, + routing_replay_out, ) if do_finalize: return [output] @@ -2073,7 +2090,9 @@ def _fake_trtllm_fp4_block_scale_moe( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ): + _ = routing_replay_out seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] if output is None else output.shape[1] @@ -2081,7 +2100,7 @@ def _fake_trtllm_fp4_block_scale_moe( @register_custom_op( "flashinfer::trtllm_mxint4_block_scale_moe", - mutates_args=(""), + mutates_args=("routing_replay_out",), ) def trtllm_mxint4_block_scale_moe_op( routing_logits: torch.Tensor, @@ -2108,6 +2127,7 @@ def trtllm_mxint4_block_scale_moe_op( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: routing_dtype = routing_logits.dtype hidden_size = hidden_states.shape[-1] @@ -2212,6 +2232,7 @@ def trtllm_mxint4_block_scale_moe_op( output, [-1, -1] if tactic == -1 else tactic, norm_topk_prob, + routing_replay_out, ) if do_finalize: return [output] @@ -2248,7 +2269,9 @@ def _fake_trtllm_mxint4_block_scale_moe( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ): + _ = routing_replay_out seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -2263,6 +2286,28 @@ def _fake_trtllm_mxint4_block_scale_moe( ) +def _validate_routing_replay_out( + routing_replay_out: Optional[torch.Tensor], top_k: int +) -> None: + """Validate routing_replay_out tensor properties before passing to C++ kernels.""" + if routing_replay_out is None: + return + if routing_replay_out.dtype 
!= torch.int16: + raise ValueError( + f"routing_replay_out must be int16, got {routing_replay_out.dtype}" + ) + if routing_replay_out.ndim != 2: + raise ValueError( + f"routing_replay_out must be 2D [num_tokens, top_k], got {routing_replay_out.ndim}D" + ) + if routing_replay_out.shape[1] != top_k: + raise ValueError( + f"routing_replay_out dim1 must equal top_k={top_k}, got {routing_replay_out.shape[1]}" + ) + if not routing_replay_out.is_contiguous(): + raise ValueError("routing_replay_out must be contiguous (packed row-major)") + + @flashinfer_api def trtllm_bf16_moe( routing_logits: torch.Tensor, @@ -2286,6 +2331,7 @@ def trtllm_bf16_moe( tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """BF16 MoE operation with autotuning support. @@ -2333,6 +2379,7 @@ def trtllm_bf16_moe( when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + _validate_routing_replay_out(routing_replay_out, top_k) result = get_trtllm_moe_sm100_module().trtllm_bf16_moe( routing_logits, routing_bias, @@ -2357,6 +2404,7 @@ def trtllm_bf16_moe( tune_max_num_tokens, activation_type, norm_topk_prob, + routing_replay_out, ) if do_finalize: @@ -2389,6 +2437,7 @@ def trtllm_bf16_routed_moe( enable_pdl: bool = True, tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """BF16 MoE operation with autotuning support. @@ -2435,6 +2484,7 @@ def trtllm_bf16_routed_moe( when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, undefined, expanded_idx_to_permuted_idx) that need further processing. 
""" + _validate_routing_replay_out(routing_replay_out, top_k) result = get_trtllm_moe_sm100_module().trtllm_bf16_moe( None, None, @@ -2459,6 +2509,7 @@ def trtllm_bf16_routed_moe( tune_max_num_tokens, activation_type, True, # norm_topk_prob: not used for pre-computed routing + routing_replay_out, ) if do_finalize: @@ -2495,6 +2546,7 @@ def trtllm_fp8_per_tensor_scale_moe( tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """FP8 per tensor scale MoE operation. @@ -2533,6 +2585,7 @@ def trtllm_fp8_per_tensor_scale_moe( when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + _validate_routing_replay_out(routing_replay_out, top_k) result = get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, @@ -2557,6 +2610,7 @@ def trtllm_fp8_per_tensor_scale_moe( tune_max_num_tokens, activation_type, norm_topk_prob, + routing_replay_out, ) if do_finalize: @@ -2595,6 +2649,7 @@ def trtllm_fp8_block_scale_moe( fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, activation_type: int = ActivationType.Swiglu.value, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """FP8 block scale MoE operation. @@ -2634,10 +2689,16 @@ def trtllm_fp8_block_scale_moe( - 4: Geglu - 6: Relu2 - 7: Identity + routing_replay_out (Optional[torch.Tensor]): Optional int16 output tensor of shape + (num_tokens_or_larger, top_k) to capture selected expert IDs during routing. + Column order matches topk_indices. When None (default), zero overhead - the + kernel skips the write entirely. 
Buffer may be larger than num_tokens for CUDA + graph pre-allocation; only rows [0, num_tokens) are written. Returns: when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + _validate_routing_replay_out(routing_replay_out, top_k) output = torch.empty( hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device ) @@ -2670,6 +2731,7 @@ def trtllm_fp8_block_scale_moe( fp8_quantization_type, activation_type, norm_topk_prob, + routing_replay_out, ) if do_finalize: @@ -2828,6 +2890,7 @@ def trtllm_fp4_block_scale_moe( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """FP4 block scale MoE operation. @@ -2892,6 +2955,7 @@ def trtllm_fp4_block_scale_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + _validate_routing_replay_out(routing_replay_out, top_k) return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( routing_logits, None, @@ -2926,6 +2990,7 @@ def trtllm_fp4_block_scale_moe( output, tune_max_num_tokens, norm_topk_prob, + routing_replay_out, ) @@ -3090,6 +3155,7 @@ def trtllm_mxint4_block_scale_moe( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """MxInt4 block scale MoE operation. @@ -3137,6 +3203,7 @@ def trtllm_mxint4_block_scale_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. 
""" + _validate_routing_replay_out(routing_replay_out, top_k) return get_trtllm_moe_sm100_module().trtllm_mxint4_block_scale_moe( routing_logits, routing_bias, @@ -3162,4 +3229,5 @@ def trtllm_mxint4_block_scale_moe( output, tune_max_num_tokens, norm_topk_prob, + routing_replay_out, ) diff --git a/flashinfer/fused_moe/fused_routing_dsv3.py b/flashinfer/fused_moe/fused_routing_dsv3.py index 5e26ca30cf..773d2b7601 100644 --- a/flashinfer/fused_moe/fused_routing_dsv3.py +++ b/flashinfer/fused_moe/fused_routing_dsv3.py @@ -1,3 +1,5 @@ +from typing import Optional + from flashinfer.api_logging import flashinfer_api from flashinfer.jit import gen_dsv3_fused_routing_module import functools @@ -21,6 +23,7 @@ def _check_dsv3_fused_routing_supported( topk_values, topk_indices, launch_with_pdl, + routing_replay_out=None, ): """Validate configuration parameters for DSv3 fused routing kernel. @@ -38,6 +41,21 @@ def _check_dsv3_fused_routing_supported( Raises: ValueError: If configuration is invalid or exceeds kernel limits """ + if routing_replay_out is not None: + num_tokens = scores.shape[0] + if routing_replay_out.dtype != torch.int16: + raise ValueError( + f"routing_replay_out must be int16, got {routing_replay_out.dtype}" + ) + if ( + routing_replay_out.shape[0] < num_tokens + or routing_replay_out.shape[1] != topk + ): + raise ValueError( + f"routing_replay_out shape[0] must be >= {num_tokens} and shape[1] must be {topk}, " + f"got {tuple(routing_replay_out.shape)}" + ) + # Extract number of experts from scores shape num_experts = scores.shape[1] @@ -86,7 +104,7 @@ def get_dsv3_fused_routing_module(): @register_custom_op( "flashinfer::NoAuxTc", - mutates_args=["topk_values", "topk_indices"], + mutates_args=["topk_values", "topk_indices", "routing_replay_out"], ) def NoAuxTc( scores: torch.Tensor, @@ -98,6 +116,7 @@ def NoAuxTc( topk_values: torch.Tensor, topk_indices: torch.Tensor, launch_with_pdl: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> 
None: module.NoAuxTc( scores, @@ -109,6 +128,7 @@ def NoAuxTc( topk_values, topk_indices, launch_with_pdl, + routing_replay_out, ) return SimpleNamespace( @@ -128,6 +148,7 @@ def fused_topk_deepseek( topk_values: torch.Tensor, topk_indices: torch.Tensor, launch_with_pdl: bool = True, + routing_replay_out: Optional[torch.Tensor] = None, ) -> None: """Fused expert routing with top-k selection for DeepSeek-V3. @@ -168,6 +189,10 @@ def fused_topk_deepseek( This tensor is mutated in-place. launch_with_pdl (bool, optional): Whether to launch the kernel using Persistent Device-side Launch. Defaults to True. + routing_replay_out (torch.Tensor, optional): Pre-allocated output tensor of shape + (num_tokens, topk) with dtype int16 for recording the selected expert IDs per + token. If None, no routing replay recording occurs (zero overhead). This tensor + is mutated in-place. Returns: None: Results are written directly to `topk_values` and `topk_indices` tensors. @@ -193,4 +218,5 @@ def fused_topk_deepseek( topk_values, topk_indices, launch_with_pdl, + routing_replay_out, ) diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index 86e9064a83..e16ee6600f 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -103,6 +103,12 @@ struct DataBase { int32_t mLocalExpertsStartIdx; int32_t mLocalExpertsStrideLog2; int32_t mNumLocalExperts; + + // optional: if nullptr, no routing replay recording occurs + // dim: [mNumTokens, mTopK] + // Records the selected expert IDs per token for replay + // NOTE: placed at end of struct to preserve field offsets for existing routing kernels + int16_t* mPtrRoutingReplayOut{nullptr}; }; template @@ -138,6 +144,9 @@ struct KernelParamsBase { int32_t mLocalExpertsStrideLog2 = 0; int32_t mNumLocalExperts = 0; + // NOTE: placed at end to preserve field offsets for existing routing kernels + int16_t* mPtrRoutingReplayOut 
= nullptr; + // Public initialization function - make it a template to accept different Data types template void setBaseParams(DataType const& data) { @@ -154,6 +163,7 @@ struct KernelParamsBase { mPtrTopKWeights = static_cast(data.mPtrTopKWeights); mPtrTopKIds = static_cast(data.mPtrTopKIds); mPtrScores = (InputT const*)data.mPtrScores; + mPtrRoutingReplayOut = data.mPtrRoutingReplayOut; mNumTokens = data.mNumTokens; mNumExperts = data.mNumExperts; diff --git a/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h b/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h index 5af8fe39db..fa7f7f8339 100644 --- a/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h +++ b/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h @@ -28,6 +28,7 @@ template void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices, int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, - cudaStream_t const stream = 0); + bool const launch_with_pdl, cudaStream_t const stream, + int16_t* routing_replay_out = nullptr); } // namespace tensorrt_llm::kernels diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 2a6092e9b3..a3abfcaed8 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -137,7 +137,7 @@ class Runner { batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream, batchedGemm::trtllm::gen::Dtype dtypeLogits, - bool normTopkProb = true); + bool normTopkProb = true, int16_t* routing_replay_out = nullptr); private: int32_t mTileTokensDim{8}; diff --git a/tests/model_optimizations/test_dsv3_fused_routing.py b/tests/model_optimizations/test_dsv3_fused_routing.py index e84c9ca884..f51a05a34d 100644 --- 
a/tests/model_optimizations/test_dsv3_fused_routing.py +++ b/tests/model_optimizations/test_dsv3_fused_routing.py @@ -499,3 +499,178 @@ def test_dsv3_fused_routing_op( # Validate values validate_values(ground_truth, sorted_vals, tokens_with_different_experts, data_type) + + +@pytest.mark.parametrize("num_tokens", [1, 8, 64]) +@pytest.mark.parametrize("num_experts", [256]) +@pytest.mark.parametrize("topk", [1, 4, 8]) +@pytest.mark.parametrize("n_group", [1, 8]) +@pytest.mark.parametrize("topk_group", [1, 4]) +@pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float16]) +def test_routing_replay_out_extended( + num_tokens, num_experts, topk, n_group, topk_group, data_type +): + """ + Test that routing_replay_out records the same expert IDs as topk_indices. + + Extended parametrization covering larger token counts (8, 64). + """ + if topk_group * n_group < topk or topk_group > n_group: + pytest.skip("Invalid configuration") + if n_group > 1: + if ( + topk > 8 + or num_experts / n_group > 32 + or num_experts / n_group * topk_group > 128 + ): + pytest.skip("Exceeds kernel limits for n_group > 1") + else: + if num_experts > 384 or topk > 8: + pytest.skip("Exceeds kernel limits for n_group = 1") + + torch.manual_seed(42) + device = "cuda" + scores = torch.randn(num_tokens, num_experts, device=device, dtype=data_type) + bias = torch.randn(num_experts, device=device, dtype=data_type) + routed_scaling_factor = 1.0 + + topk_values = torch.empty(num_tokens, topk, device=device, dtype=data_type) + topk_indices = torch.zeros(num_tokens, topk, device=device, dtype=torch.int32) + routing_replay_out = torch.full( + (num_tokens, topk), -1, device=device, dtype=torch.int16 + ) + + fused_topk_deepseek( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl=True, + routing_replay_out=routing_replay_out, + ) + + # routing_replay_out should contain the same expert IDs as topk_indices (per token) + for t in 
range(num_tokens): + replay_set = set(routing_replay_out[t].tolist()) + indices_set = set(topk_indices[t].tolist()) + assert replay_set == indices_set, ( + f"Token {t}: routing_replay_out experts {replay_set} " + f"!= topk_indices experts {indices_set}" + ) + + # Verify None produces identical results (no side effects from replay) + topk_values_no_replay = torch.empty( + num_tokens, topk, device=device, dtype=data_type + ) + topk_indices_no_replay = torch.zeros( + num_tokens, topk, device=device, dtype=torch.int32 + ) + + fused_topk_deepseek( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values_no_replay, + topk_indices_no_replay, + launch_with_pdl=True, + routing_replay_out=None, + ) + + torch.testing.assert_close(topk_values, topk_values_no_replay) + torch.testing.assert_close(topk_indices, topk_indices_no_replay) + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32]) +@pytest.mark.parametrize("num_experts", [256]) +@pytest.mark.parametrize("topk", [1, 4, 8]) +@pytest.mark.parametrize("n_group", [1, 8]) +@pytest.mark.parametrize("topk_group", [1, 4]) +@pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float16]) +def test_routing_replay_out( + num_tokens, num_experts, topk, n_group, topk_group, data_type +): + """ + Test that routing_replay_out records the same expert IDs as topk_indices. + + The routing replay feature writes selected expert IDs (int16) into an + optional output tensor during the fused routing kernel. This test verifies + that routing_replay_out matches topk_indices (as sets per token), and that + passing None produces identical routing results (no side effects). 
+ """ + if topk_group * n_group < topk or topk_group > n_group: + pytest.skip("Invalid configuration") + if n_group > 1: + if ( + topk > 8 + or num_experts / n_group > 32 + or num_experts / n_group * topk_group > 128 + ): + pytest.skip("Exceeds kernel limits for n_group > 1") + else: + if num_experts > 384 or topk > 8: + pytest.skip("Exceeds kernel limits for n_group = 1") + + torch.manual_seed(42) + device = "cuda" + scores = torch.randn(num_tokens, num_experts, device=device, dtype=data_type) + bias = torch.randn(num_experts, device=device, dtype=data_type) + routed_scaling_factor = 1.0 + + topk_values = torch.empty(num_tokens, topk, device=device, dtype=data_type) + topk_indices = torch.zeros(num_tokens, topk, device=device, dtype=torch.int32) + routing_replay_out = torch.full( + (num_tokens, topk), -1, device=device, dtype=torch.int16 + ) + + fused_topk_deepseek( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl=True, + routing_replay_out=routing_replay_out, + ) + + # routing_replay_out should contain the same expert IDs as topk_indices (per token) + for t in range(num_tokens): + replay_set = set(routing_replay_out[t].tolist()) + indices_set = set(topk_indices[t].tolist()) + assert replay_set == indices_set, ( + f"Token {t}: routing_replay_out experts {replay_set} " + f"!= topk_indices experts {indices_set}" + ) + + # Verify None produces identical results (no side effects from replay) + topk_values_no_replay = torch.empty( + num_tokens, topk, device=device, dtype=data_type + ) + topk_indices_no_replay = torch.zeros( + num_tokens, topk, device=device, dtype=torch.int32 + ) + + fused_topk_deepseek( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values_no_replay, + topk_indices_no_replay, + launch_with_pdl=True, + routing_replay_out=None, + ) + + torch.testing.assert_close(topk_values, topk_values_no_replay) + torch.testing.assert_close(topk_indices, 
topk_indices_no_replay) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index f5fd5cc263..e3c74aa20d 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -705,3 +705,153 @@ def test_trtllm_gen_fp8_mxfp8_routed_activation_parity(activation_type: int): close = torch.isclose(output_ref, output_routed, atol=1e-2, rtol=1e-2) mismatch_pct = (~close).float().mean().item() * 100 assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%" + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 2048]) +@pytest.mark.parametrize("num_experts", [16]) +@pytest.mark.parametrize("top_k", [2, 4]) +def test_fp8_block_scale_moe_routing_replay( + num_tokens: int, + hidden_size: int, + intermediate_size: int, + top_k: int, + num_experts: int, +): + """Test that routing_replay_out in trtllm_fp8_block_scale_moe records correct expert IDs. + + Uses DeepSeekV3 routing (the only routing method with replay support). + Runs the full MoE kernel twice with the same inputs: once with routing_replay_out + and once without. Verifies that: + 1. The MoE output is identical (replay has no side effects). + 2. The replay buffer holds valid expert IDs with top_k unique experts per active token. + 3. Tail rows beyond num_tokens remain sentinel (CUDA graph pre-alloc contract). 
+ """ + compute_capability = get_compute_capability(torch.device(device="cuda")) + if compute_capability[0] not in [10]: + pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") + n_group = 4 + topk_group = 2 + if topk_group * n_group < top_k or topk_group > n_group: + pytest.skip("Invalid DeepSeek routing configuration") + torch.manual_seed(42) + device = torch.device("cuda:0") + enable_pdl = device_support_pdl(device) + + routing_logits = torch.rand( + num_tokens, num_experts, device=device, dtype=torch.float32 + ) + routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16) + + hidden_states_bf16 = ( + torch.randn(num_tokens, hidden_size, device=device).to(torch.bfloat16) * 0.1 + ) + hidden_states = hidden_states_bf16.to(torch.float8_e4m3fn) + + hidden_states_scale = torch.ones( + hidden_size // 128, num_tokens, device=device, dtype=torch.float32 + ) + + gemm1_weights = torch.randn( + num_experts, 2 * intermediate_size, hidden_size, device=device + ).to(torch.float8_e4m3fn) + gemm2_weights = torch.randn( + num_experts, hidden_size, intermediate_size, device=device + ).to(torch.float8_e4m3fn) + + gemm1_weights_scale = torch.ones( + num_experts, + 2 * intermediate_size // 128, + hidden_size // 128, + device=device, + dtype=torch.float32, + ) + gemm2_weights_scale = torch.ones( + num_experts, + hidden_size // 128, + intermediate_size // 128, + device=device, + dtype=torch.float32, + ) + + # Allocate oversized buffer to validate CUDA graph pre-allocation contract: + # kernel should only write to [0, num_tokens) and leave tail rows as sentinel. 
+ replay_capacity = num_tokens + 5 + routing_replay_out = torch.full( + (replay_capacity, top_k), -1, device=device, dtype=torch.int16 + ) + + output_with_replay = trtllm_fp8_block_scale_moe( + routing_logits, + routing_bias, + hidden_states, + hidden_states_scale, + gemm1_weights, + gemm1_weights_scale, + gemm2_weights, + gemm2_weights_scale, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + 0, # local_expert_offset + num_experts, + 1.0, # routed_scaling_factor + RoutingMethodType.DeepSeekV3.value, + False, # use_shuffled_weight + 0, # weight_layout + enable_pdl, + routing_replay_out=routing_replay_out, + ) + + output_without_replay = trtllm_fp8_block_scale_moe( + routing_logits, + routing_bias, + hidden_states, + hidden_states_scale, + gemm1_weights, + gemm1_weights_scale, + gemm2_weights, + gemm2_weights_scale, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + 0, + num_experts, + 1.0, + RoutingMethodType.DeepSeekV3.value, + False, + 0, + enable_pdl, + routing_replay_out=None, + ) + + # MoE output should be identical regardless of replay + torch.testing.assert_close( + output_with_replay.to(torch.float), + output_without_replay.to(torch.float), + rtol=0, + atol=0, + ) + + # Validate replay contents on active rows only: IDs in range, top_k unique per token + active_replay = routing_replay_out[:num_tokens] + # Verify replay IDs are valid expert indices + assert (active_replay >= 0).all() and (active_replay < num_experts).all(), ( + "Replay contains out-of-range expert IDs" + ) + # Each token should have top_k unique experts + for t in range(num_tokens): + unique_experts = active_replay[t].unique() + assert unique_experts.numel() == top_k, ( + f"Token {t}: expected {top_k} unique experts, got {unique_experts.numel()}" + ) + + # Tail rows beyond num_tokens should remain sentinel (-1) + assert (routing_replay_out[num_tokens:] == -1).all(), ( + "Kernel should not write beyond active token rows" + )