diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 80a464f51d..847ce7b444 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1236,6 +1236,7 @@ def forward( output, [-1, -1] if tactic == -1 else tactic, kwargs.get("norm_topk_prob", True), + kwargs.get("routing_replay_out"), ) else: moe_op.trtllm_fp4_block_scale_moe( @@ -1272,6 +1273,7 @@ def forward( output, [-1, -1] if tactic == -1 else tactic, kwargs.get("norm_topk_prob", True), + kwargs.get("routing_replay_out"), ) @register_custom_op( diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index 0ebab0143d..e16ee6600f 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -69,11 +69,6 @@ struct DataBase { // Together with mPtrTopKWeights, they form the top-k experts for each token int32_t* mPtrTopKIds{nullptr}; - // optional: if nullptr, no routing replay recording occurs - // dim: [mNumTokens, mTopK] - // Records the selected expert IDs per token for replay - int16_t* mPtrRoutingReplayOut{nullptr}; - // optional: if `nullptr`, scores are used directly as input. // If it is given, it must represent a packed value s.t. the most significant // 16/32 bits represent the score without sigmoid activation and @@ -108,6 +103,12 @@ struct DataBase { int32_t mLocalExpertsStartIdx; int32_t mLocalExpertsStrideLog2; int32_t mNumLocalExperts; + + // optional: if nullptr, no routing replay recording occurs + // dim: [mNumTokens, mTopK] + // Records the selected expert IDs per token for replay + // NOTE: placed at end of struct to preserve field offsets for existing routing kernels + int16_t* mPtrRoutingReplayOut{nullptr}; }; template @@ -132,7 +133,6 @@ struct KernelParamsBase { OutputT* mPtrTopKWeights = nullptr; int32_t* mPtrTopKIds = nullptr; InputT const* mPtrScores = nullptr; - int16_t* mPtrRoutingReplayOut = nullptr; // Public scalar members int32_t mNumTokens = 0; @@ -144,6 +144,9 @@ struct KernelParamsBase { int32_t mLocalExpertsStrideLog2 = 0; int32_t mNumLocalExperts = 0; + // NOTE: placed at end to preserve field offsets for existing routing kernels + int16_t* mPtrRoutingReplayOut = nullptr; + // Public initialization function - make it a template to accept different Data types template void setBaseParams(DataType const& data) {