Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
10d80fc
feat: add routing_replay_out to MoE kernel launchers and routing kernels
Apr 9, 2026
b39aa06
feat: add routing_replay_out to noAuxTc DSV3 routing kernel
Apr 9, 2026
fd6ca4b
feat: add routing_replay_out to Python MoE API
Apr 9, 2026
f6bd8cb
test: add routing replay tests for FP8 and DSV3 MoE kernels
Apr 9, 2026
8d0d0f0
docs: add vLLM routing replay integration guide
Apr 9, 2026
88e18b4
fix: add missing int16 dtype validation and FP4/MXINT4 public API params
Apr 9, 2026
1fb8454
fix: address code review feedback — missing routing_replay_out params…
TomerBN-Nvidia Apr 12, 2026
777a82d
fix: clarify unordered column semantics in doc and strengthen replay …
TomerBN-Nvidia Apr 12, 2026
5929e15
docs: clarify API list is vLLM-specific subset, add language tag to c…
TomerBN-Nvidia Apr 12, 2026
7e076e6
fix: add routing_replay_out to FP8 per-tensor op/fake_op, test oversi…
TomerBN-Nvidia Apr 12, 2026
ec2d137
fix: add __version__ to flashinfer_cubin for nightly base compat
Apr 9, 2026
c067939
fix: wrap DLDataType brace initializer in extra parens for TVM_FFI_IC…
Apr 12, 2026
72329b7
fix: address human review — revert __version__, docstring, trtllm-onl…
TomerBN-Nvidia Apr 13, 2026
500e022
fix: add missing Optional import to tvm_ffi_utils.h for noAuxTcKernel…
Apr 13, 2026
4e62412
fix: switch FP8 block scale replay test from Renormalize to DeepSeekV…
Apr 13, 2026
66cc35d
fix: column order matches topk_indices, remove em dash
Apr 13, 2026
26ad7a4
fix: reject strided routing_replay_out views - require contiguous layout
Apr 13, 2026
e135197
fix: move Optional import to noAuxTcKernels, add Python validation, s…
TomerBN-Nvidia Apr 13, 2026
06cefff
Merge branch 'main' into upstream-routing-replay
aleozlx Apr 13, 2026
697afe4
Merge branch 'main' of https://github.com/flashinfer-ai/flashinfer in…
aleozlx Apr 14, 2026
67cc957
precommit
aleozlx Apr 14, 2026
625442f
Merge branch 'main' into upstream-routing-replay
TomerBN-Nvidia Apr 14, 2026
79b6c3b
fix: move mPtrRoutingReplayOut to end of routing structs
Apr 14, 2026
db1855e
Merge pull request #4 from TomerBN-Nvidia/fix-routing-struct-layout
TomerBN-Nvidia Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 42 additions & 14 deletions csrc/fused_moe/noAuxTcKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "tensorrt_llm/common/envUtils.h"
#include "tvm_ffi_utils.h"

using tvm::ffi::Optional;

namespace cg = cooperative_groups;
using namespace tensorrt_llm::common;

Expand All @@ -30,7 +32,8 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
int64_t const numGroup, int64_t const topkGroup,
int64_t const topk, int64_t const numExperts,
int64_t const numExpertsPerGroup,
double const routedScalingFactor) {
double const routedScalingFactor,
int16_t* routingReplayOut) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
Expand Down Expand Up @@ -212,6 +215,10 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
if (laneIdx < topk) {
topkValues[laneIdx] = static_cast<OutputT>(finalScore);
topkIndices[laneIdx] = expertIdx;
// Routing replay: record selected expert IDs per token
if (routingReplayOut != nullptr) {
routingReplayOut[blockIdx.x * topk + laneIdx] = static_cast<int16_t>(expertIdx);
}
}
}

Expand All @@ -224,7 +231,8 @@ template <typename InputT, typename BiasT, typename OutputT, typename IdxT>
void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices,
int64_t const num_tokens, int64_t const num_experts, int64_t const n_group,
int64_t const topk_group, int64_t const topk, double const routed_scaling_factor,
bool const launch_with_pdl, cudaStream_t const stream) {
bool const launch_with_pdl, cudaStream_t const stream,
int16_t* routing_replay_out) {
// Check if we can use the optimized deepseek_v3_topk_kernel
bool const is_single_group = (n_group == 1) && (num_experts <= NumKimiK2Experts);

Expand Down Expand Up @@ -262,7 +270,7 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk

cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
routed_scaling_factor);
routed_scaling_factor, routing_replay_out);
sync_check_cuda_error(stream);
} else {
// TODO: call the generic path (previous implementation) or signal unsupported config.
Expand All @@ -279,7 +287,7 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk
InputT * scores, BiasT * bias, OutputT * topk_values, IdxT * topk_indices, \
int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, \
int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, \
bool const launch_with_pdl, cudaStream_t const stream);
bool const launch_with_pdl, cudaStream_t const stream, int16_t* routing_replay_out);

INSTANTIATE_NOAUX_TC(float, float, float, int32_t);
INSTANTIATE_NOAUX_TC(float, half, float, int32_t);
Expand All @@ -305,7 +313,7 @@ namespace flashinfer::trtllm_dsv3_fused_routing {

void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_group, int64_t topk,
double routed_scaling_factor, TensorView topk_values, TensorView topk_indices,
bool launch_with_pdl) {
bool launch_with_pdl, Optional<TensorView> routing_replay_out) {
auto data_type = scores.dtype();
auto bias_type = bias.dtype();

Expand Down Expand Up @@ -342,6 +350,26 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g
TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code)
<< "topk_indices must have the same dtype as scores";

// Validate and extract routing_replay_out
// NOTE: dim0 >= num_tokens is intentionally NOT checked — with CUDA graphs the buffer
// is pre-allocated at maximum batch size and reused across steps with varying num_tokens.
// The kernel only writes to indices [0, num_tokens), so a larger buffer is safe.
constexpr int64_t int16_code_val = encode_dlpack_dtype(DLDataType{kDLInt, 16, 1});
int16_t* replay_ptr = nullptr;
if (routing_replay_out.has_value()) {
auto replay = routing_replay_out.value();
TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA)
<< "routing_replay_out must be a CUDA tensor";
TVM_FFI_ICHECK(replay.device().device_id == scores.device().device_id)
<< "routing_replay_out must be on the same device as scores";
TVM_FFI_ICHECK(replay.ndim() == 2)
<< "routing_replay_out must be a 2D Tensor [num_tokens, topk]";
TVM_FFI_ICHECK(replay.sizes()[1] == topk) << "routing_replay_out dim1 must equal topk";
TVM_FFI_ICHECK(encode_dlpack_dtype(replay.dtype()) == int16_code_val)
<< "routing_replay_out must be int16 dtype";
replay_ptr = reinterpret_cast<int16_t*>(replay.data_ptr());
}
Comment thread
TomerBN-Nvidia marked this conversation as resolved.

auto stream = get_stream(scores.device());
using namespace tensorrt_llm::kernels;
switch (encode_dlpack_dtype(data_type)) {
Expand All @@ -353,22 +381,22 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g
reinterpret_cast<half*>(scores.data_ptr()), reinterpret_cast<half*>(bias.data_ptr()),
reinterpret_cast<half*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case float32_code:
invokeNoAuxTc<half, float, half, int32_t>(
reinterpret_cast<half*>(scores.data_ptr()), reinterpret_cast<float*>(bias.data_ptr()),
reinterpret_cast<half*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case bfloat16_code:
invokeNoAuxTc<half, __nv_bfloat16, half, int32_t>(
reinterpret_cast<half*>(scores.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()),
reinterpret_cast<half*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
default:
throw std::invalid_argument(
Expand All @@ -384,22 +412,22 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g
reinterpret_cast<float*>(bias.data_ptr()),
reinterpret_cast<float*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case float16_code:
invokeNoAuxTc<float, half, float, int32_t>(
reinterpret_cast<float*>(scores.data_ptr()), reinterpret_cast<half*>(bias.data_ptr()),
reinterpret_cast<float*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case bfloat16_code:
invokeNoAuxTc<float, __nv_bfloat16, float, int32_t>(
reinterpret_cast<float*>(scores.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()),
reinterpret_cast<float*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
default:
throw std::invalid_argument(
Expand All @@ -416,23 +444,23 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g
reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case float16_code:
invokeNoAuxTc<__nv_bfloat16, half, __nv_bfloat16, int32_t>(
reinterpret_cast<__nv_bfloat16*>(scores.data_ptr()),
reinterpret_cast<half*>(bias.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case float32_code:
invokeNoAuxTc<__nv_bfloat16, float, __nv_bfloat16, int32_t>(
reinterpret_cast<__nv_bfloat16*>(scores.data_ptr()),
reinterpret_cast<float*>(bias.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
default:
throw std::invalid_argument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,12 @@ __global__ void routingMainKernel(KernelParams params) {
params.mPtrTopKIds == nullptr) {
params.mPtrTopKWeights[idxTopK] = finalScore;
}

// Routing replay: record all top-K selected expert IDs per token.
// Layout: [num_tokens, topK] -- same indexing as mPtrTopKPacked.
if (params.mPtrRoutingReplayOut != nullptr && laneIdx < params.mTopK) {
params.mPtrRoutingReplayOut[idxTopK] = static_cast<int16_t>(expertIdx);
}
}
}

Expand Down
Loading
Loading