Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,7 @@ def forward(
output,
[-1, -1] if tactic == -1 else tactic,
kwargs.get("norm_topk_prob", True),
kwargs.get("routing_replay_out"),
)
else:
moe_op.trtllm_fp4_block_scale_moe(
Expand Down Expand Up @@ -1272,6 +1273,7 @@ def forward(
output,
[-1, -1] if tactic == -1 else tactic,
kwargs.get("norm_topk_prob", True),
kwargs.get("routing_replay_out"),
)

@register_custom_op(
Expand Down
15 changes: 9 additions & 6 deletions include/flashinfer/trtllm/fused_moe/RoutingKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,6 @@ struct DataBase {
// Together with mPtrTopKWeights, they form the top-k experts for each token
int32_t* mPtrTopKIds{nullptr};

// optional: if nullptr, no routing replay recording occurs
// dim: [mNumTokens, mTopK]
// Records the selected expert IDs per token for replay
int16_t* mPtrRoutingReplayOut{nullptr};

// optional: if `nullptr`, scores are used directly as input.
// If it is given, it must represent a packed value s.t. the most significant
// 16/32 bits represent the score without sigmoid activation and
Expand Down Expand Up @@ -108,6 +103,12 @@ struct DataBase {
int32_t mLocalExpertsStartIdx;
int32_t mLocalExpertsStrideLog2;
int32_t mNumLocalExperts;

// optional: if nullptr, no routing replay recording occurs
// dim: [mNumTokens, mTopK]
// Records the selected expert IDs per token for replay
// NOTE: placed at end of struct to preserve field offsets for existing routing kernels
int16_t* mPtrRoutingReplayOut{nullptr};
};

template <typename InputT_, typename OutputT_, int MaxNumExperts_, int MaxNumTopExperts_>
Expand All @@ -132,7 +133,6 @@ struct KernelParamsBase {
OutputT* mPtrTopKWeights = nullptr;
int32_t* mPtrTopKIds = nullptr;
InputT const* mPtrScores = nullptr;
int16_t* mPtrRoutingReplayOut = nullptr;

// Public scalar members
int32_t mNumTokens = 0;
Expand All @@ -144,6 +144,9 @@ struct KernelParamsBase {
int32_t mLocalExpertsStrideLog2 = 0;
int32_t mNumLocalExperts = 0;

// NOTE: placed at the end of the data members to preserve field offsets for existing routing kernels
int16_t* mPtrRoutingReplayOut = nullptr;

// Public initialization function - make it a template to accept different Data types
template <typename DataType>
void setBaseParams(DataType const& data) {
Expand Down
Loading