diff --git a/csrc/fused_moe/moeTopKFuncs.cuh b/csrc/fused_moe/moeTopKFuncs.cuh new file mode 100644 index 0000000000..e34c5f2665 --- /dev/null +++ b/csrc/fused_moe/moeTopKFuncs.cuh @@ -0,0 +1,254 @@ + +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#ifndef TRTLLM_MOETOPKFUNCS_CUH_H +#define TRTLLM_MOETOPKFUNCS_CUH_H + +#include +#include + +#include + +#include "flashinfer/arch_condition.h" + +namespace tensorrt_llm::kernels { + +namespace reduce_topk { +namespace cg = cooperative_groups; +static constexpr int kWARP_SIZE = 32; +static constexpr bool kTLLM_GEN_HAS_FAST_REDUX = flashinfer::arch::is_major_v<10>; + +template +struct TopKRedType { + using T = T_; + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v, + "Top K reduction only implemented for int, float, float16 and bfloat16"); + + using TypeCmp = std::conditional_t; + using IdxT = std::conditional_t; + + static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16; + static constexpr int kMaxIdx = 65535; + TypeCmp compValIdx; + + static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) { + auto valueBits = + cub::Traits::TwiddleIn(reinterpret_cast::UnsignedBits&>(val)); + TypeCmp compactTmp = valueBits; + compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx)); + // Use 65535 minus idx to give higher priority to elements with smaller indices. 
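+    // Illustrative note (not part of the original change): the twiddled value bits are
+    // shifted left by kMoveBits and (65535 - idx) occupies the low 16 bits, e.g. idx = 3 is
+    // stored as 0xFFFC. A max-reduction over the packed keys therefore returns the largest
+    // value, and among equal values the entry with the smallest index wins.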
+ return compactTmp; + } + + static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp) { + // Since “65535-idx” is always smaller than 65536 and positive, we can directly use it as the + // lower 16 bits + index = kMaxIdx - static_cast((cmp & 0xFFFF)); + + auto compactTmp = cmp >> kMoveBits; + auto valueBits = cub::Traits::TwiddleOut( + reinterpret_cast::UnsignedBits&>(compactTmp)); + value = reinterpret_cast(valueBits); + } + + __host__ __device__ TopKRedType() = default; + + __host__ __device__ TopKRedType(T val, int32_t idx) : compValIdx(makeCmpVal(val, idx)) {} + + __host__ __device__ operator TypeCmp() const noexcept { return compValIdx; } + + __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) { + if constexpr (!kTLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8) { + return cg::reduce(warp, compValIdx, cg::greater{}); + } else { + TypeCmp result; + asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx)); + return result; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TopKIdx { + // by default, empty +}; + +template +struct TopKIdx { + static constexpr int K = K_; + int32_t val[K]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define TOPK_SWAP(I, J) \ + { \ + auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \ + auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \ + topK[I].compValIdx = pairMax; \ + topK[J].compValIdx = pairMin; \ + } + +template +struct Sort; + +template +struct Sort<1, RedType> { + static __device__ void run(RedType* topK) {} +}; + +template +struct Sort<2, RedType> { + static __device__ void run(RedType* topK) { TOPK_SWAP(0, 1); } +}; + +template +struct Sort<3, RedType> { + static __device__ void run(RedType* topK) { + TOPK_SWAP(0, 1); + TOPK_SWAP(1, 2); + TOPK_SWAP(0, 1); + } +}; + +template +struct Sort<4, RedType> { + static __device__ void run(RedType* topK) { + TOPK_SWAP(0, 2); + TOPK_SWAP(1, 3); + TOPK_SWAP(0, 1); + TOPK_SWAP(2, 3); + TOPK_SWAP(1, 2); + } +}; + +template +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, + Type (&out)[K], int32_t (&outIdx)[K], Type value, + int32_t idx, Type const minValue, int actualK = K) { + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + using RedType = TopKRedType; + RedType topK{value, idx}; + typename RedType::TypeCmp packedMax{}; +#pragma unroll + for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct + { + topK = kk > 0 && packedMax == topK.compValIdx ? 
RedType{minValue, idx} : topK; + // get the next largest value + packedMax = topK.reduce(warp); + RedType::unpack(out[kk], outIdx[kk], packedMax); + } +}; + +template +__device__ void reduceTopKFunc(cg::thread_block_tile const& warp, Type (&out)[K], + int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], + Type minValue, int actualK = K) { + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(N > 0, "Top K must have N > 0"); + static_assert(N < 5, "Only support candidates number less than or equal to 128"); + using RedType = TopKRedType; + RedType topK[N]; +#pragma unroll + for (int nn = 0; nn < N; ++nn) { + topK[nn] = RedType{value[nn], idx[nn]}; + } + + if constexpr (!IsSorted) { + Sort::run(topK); + } + typename RedType::TypeCmp packedMax{}; +#pragma unroll + for (int kk = 0; kk < actualK; ++kk) { + bool update = kk > 0 && packedMax == topK[0].compValIdx; +#pragma unroll + for (int nn = 0; nn < N; ++nn) { + topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} + : update ? topK[nn + 1] + : topK[nn]; + } + // get the next largest value + packedMax = topK[0].reduce(warp); + RedType::unpack(out[kk], outIdx[kk], packedMax); + } +}; + +template +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, + Type (&out)[K], int32_t (&outIdx)[K], Type (&value)[N], + int32_t (&idx)[N], Type const minValue, + int actualK = K) { + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(N > 0, "Top K must have N > 0"); + static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512"); + static_assert( + N <= 4 || N % 4 == 0, + "Only support candidates number is a multiple of 4*32=128 or less than or equal to 4"); + using RedType = TopKRedType; + + if constexpr (N <= 4) { + reduceTopKFunc(warp, out, outIdx, value, idx, minValue, actualK); + } else { + constexpr int numLoops = N / 4; + constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1; + + Type topKBufferValue[numResults]; + int32_t topKBufferIdx[numResults]; + int32_t laneIdx = threadIdx.x % kWARP_SIZE; + + for (int ii = 0; ii < numResults; ++ii) { + topKBufferValue[ii] = minValue; + topKBufferIdx[ii] = ii * kWARP_SIZE - 1; //@todo: check if this is correct + } + for (int loop = 0; loop < numLoops; ++loop) { + int start = loop * 4; + Type topKValue[K]; + int32_t topKIdx[K]; + Type inValue[4]; + int32_t inIdx[4]; + for (int i = 0; i < 4; ++i) { + inValue[i] = value[start + i]; + inIdx[i] = idx[start + i]; + } + reduceTopKFunc(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK); + int inOffset = laneIdx % K; + if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) { + topKBufferValue[0] = topKValue[inOffset]; + topKBufferIdx[0] = topKIdx[inOffset]; + } + if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE))) { + topKBufferValue[1] = topKValue[inOffset]; + topKBufferIdx[1] = topKIdx[inOffset]; + } + } + + reduceTopKFunc(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, + actualK); + } +}; + +#undef TOPK_SWAP + +} // namespace reduce_topk +} // namespace tensorrt_llm::kernels +#endif // TRTLLM_MOETOPKFUNCS_CUH_H diff --git a/csrc/fused_moe/noAuxTcKernels.cu b/csrc/fused_moe/noAuxTcKernels.cu new file mode 100644 index 0000000000..1f57d9b57b --- /dev/null +++ b/csrc/fused_moe/noAuxTcKernels.cu @@ -0,0 +1,450 @@ +#include +#include + +#include + +#include "flashinfer/trtllm/fused_moe/noAuxTcKernels.h" 
+#include "moeTopKFuncs.cuh" +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include "tensorrt_llm/common/envUtils.h" +#include "tvm_ffi_utils.h" + +namespace cg = cooperative_groups; +using namespace tensorrt_llm::common; + +namespace tensorrt_llm::kernels { +static constexpr int WARP_SIZE = 32; +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxNumExpertsUnit = 128; +static constexpr int NumTopGroupScores = 2; +static constexpr int MaxNumTopExperts = 8; +static constexpr int MaxNumTopGroups = 4; + +static __device__ inline float sigmoid_accurate(float x) { return 0.5f * tanhf(0.5f * x) + 0.5f; } + +template +__global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, IdxT* topkIndices, + BiasT* routingBias, int64_t const numTokens, + int64_t const numGroup, int64_t const topkGroup, + int64_t const topk, int64_t const numExperts, + int64_t const numExpertsPerGroup, + double const routedScalingFactor) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + + // declare shared memory structure + // number of experts is bounded by number of threads + __shared__ float __attribute((aligned(128))) smemScoreSigmoid[MaxNumExperts]; + __shared__ float __attribute((aligned(128))) smemScoreBias[MaxNumExperts]; + // number of expert groups is bounded by number of warps + int constexpr NumWarps = MaxNumExperts / WARP_SIZE; + __shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps]; + + // needed for warp reduce + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + // for the final reduction of weight norm, only some lanes need to participate + int32_t laneIdx = threadIdx.x % WARP_SIZE; + int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0); + + if constexpr (UseGroups) { + if (warpIdx >= numGroup) { + return; + } + } + + // note that for invalid scores, we simply use a negative value: + // they work well even with the compacted format used in topK, and + // sigmoid / bias activated scores cannot be negative + static constexpr float invalidScoreFloat = float{-INFINITY}; + const OutputT invalidScore = OutputT{invalidScoreFloat}; + + // load bias already; each warp represents one expert group + auto threadExpert = threadIdx.x; + bool expertSelected = threadExpert < numExperts; + if constexpr (UseGroups) { + threadExpert = warpIdx * numExpertsPerGroup + laneIdx; + expertSelected = laneIdx < numExpertsPerGroup; + } + + auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert; + auto biasVal = expertSelected ? static_cast(routingBias[threadExpert]) : invalidScoreFloat; + topkValues += blockIdx.x * topk; + topkIndices += blockIdx.x * topk; + + // get our assigned thread score; each warp represents one expert group + float score = expertSelected ? 
static_cast(scores[scoreIdx]) : invalidScoreFloat; + auto scoreSigmoid = sigmoid_accurate(score); + // write the sigmoid score to shared for later use + if (expertSelected) { + smemScoreSigmoid[threadExpert] = scoreSigmoid; + } + + // get the score with bias + // note that with invalid values, because sigmoid is < 1 and bias is -1, + // we must get a negative value, which is smaller than any valid value + auto scoreBias = float{scoreSigmoid + float{biasVal}}; + + if (expertSelected) { + smemScoreBias[threadExpert] = scoreBias; + } + + // registers for top group score reduction + float topExpGroupScores[NumTopGroupScores]; + [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; + float topGroups[MaxNumTopGroups]; // bound of numGroup + int32_t topGroupIdx[MaxNumTopGroups]; + float expertScoreGroup[MaxNumTopGroups]; + int32_t expertIdxGroup[MaxNumTopGroups]; + float topScores[MaxNumTopExperts]; // bound of topk + int32_t topExperts[MaxNumTopExperts]; + + if constexpr (UseGroups) { + reduce_topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, + /* minValue */ invalidScoreFloat); + + // get the final group score and write it to shared + if (laneIdx == 0) { + auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; + smemGroupScores[warpIdx] = groupScore; + } + } + + // make group scores available to all warps + __syncthreads(); + + if constexpr (UseGroups) { + if (warpIdx == 0) { + // a single warp performs the selection of top groups, and goes on to select the final experts + float groupScore = laneIdx < numGroup ? smemGroupScores[laneIdx] : invalidScoreFloat; + + reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, + /* minValue */ invalidScoreFloat); + + // final expert selection: get relevant indexes and scores from shared + +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of numGroup + // auto groupIdx = topGroupIdx[ii]; + auto groupIdx = (ii < topkGroup) ? topGroupIdx[ii] : 0; + expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx; + + expertScoreGroup[ii] = (ii < topkGroup) && expertSelected + ? smemScoreBias[expertIdxGroup[ii]] + : invalidScoreFloat; + } + + tensorrt_llm::kernels::reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, + expertIdxGroup, + /* minValue */ invalidScoreFloat, topk); + } + } else if constexpr (MaxNumExperts > MaxNumExpertsUnit) { + // without groups, and the expert number is larger than MaxNumExpertsUnit, + // we need to use multiple warps to calculate the intermediate topk results + + int constexpr NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; + int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; + __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; + __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; + if (warpIdx < NumExpertWarps) { + int offset = warpIdx * WARP_SIZE * MaxNumTopGroups; +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = + offset + expertIdx < numExperts ? 
smemScoreBias[offset + expertIdx] : invalidScoreFloat; + } + reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, topk); + + if (laneIdx < topk) { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + } + } + __syncthreads(); + if (warpIdx == 0) { + int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WARP_SIZE + 1; + float intermidiateScore[NumInterTopKPerThread]; + int32_t intermidiateExpert[NumInterTopKPerThread]; + for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) { + int ii = i / WARP_SIZE; + if (i < NumInterTopK) { + intermidiateScore[ii] = smemInterTopScores[i]; + intermidiateExpert[ii] = smemInterTopExperts[i]; + } else { + intermidiateScore[ii] = invalidScoreFloat; + intermidiateExpert[ii] = MaxNumExperts - 1; + } + } + reduce_topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert, + /* minValue */ invalidScoreFloat, topk); + } + } else { + // without groups, and the expert number is smaller than MaxNumExpertsUnit + // each thread just takes `MaxNumTopGroups` experts + if (warpIdx == 0) { +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] = + expertIdx < numExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; + } + reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, topk); + } + } + + if (warpIdx == 0) { + // determine our lane's expert index and write to output + int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; + // norm the value + float scoreNorm = laneIdx < topk ? 
smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = static_cast(scoreNorm * routedScalingFactor / (redNorm + 1e-20)); + // store the topk scores and experts to output + if (laneIdx < topk) { + topkValues[laneIdx] = static_cast(finalScore); + topkIndices[laneIdx] = expertIdx; + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices, + int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, + int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, + bool const launch_with_pdl, cudaStream_t const stream) { + // Check if we can use the optimized deepseek_v3_topk_kernel + bool const is_single_group = (n_group == 1) && (num_experts <= NumKimiK2Experts); + + int64_t const experts_per_group = num_experts / n_group; + bool const is_multi_group = (n_group != 1) && (num_experts <= NumDeepseekExperts) && + (experts_per_group <= WARP_SIZE) && + (experts_per_group * topk_group <= MaxNumExpertsUnit); + + if (is_single_group || is_multi_group) { + cudaLaunchConfig_t config; + auto* kernel_instance = + &deepseek_v3_topk_kernel; + int num_threads = NumDeepseekExperts; + if (is_single_group) { + if (num_experts > MaxNumExpertsUnit) { + kernel_instance = + &deepseek_v3_topk_kernel; + num_threads = NumKimiK2Experts; + } else { + kernel_instance = + &deepseek_v3_topk_kernel; + num_threads = MaxNumExpertsUnit; + } + } + + config.gridDim = num_tokens; + config.blockDim = num_threads; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = launch_with_pdl; + config.numAttrs = 1; + config.attrs = attrs; + + cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group, + routed_scaling_factor); + sync_check_cuda_error(stream); + } else { + // TODO: call the generic path (previous implementation) or signal unsupported config. + TLLM_CHECK_WITH_INFO(false, + "invokeNoAuxTc: unsupported configuration (n_group=%ld, num_experts=%ld, " + "topk_group=%ld). 
Please use " + "original pytorch implementation.", + n_group, num_experts, topk_group); + } +} + +#define INSTANTIATE_NOAUX_TC(InputT, BiasT, OutputT, IdxT) \ + template void invokeNoAuxTc( \ + InputT * scores, BiasT * bias, OutputT * topk_values, IdxT * topk_indices, \ + int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, \ + int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, \ + bool const launch_with_pdl, cudaStream_t const stream); + +INSTANTIATE_NOAUX_TC(float, float, float, int32_t); +INSTANTIATE_NOAUX_TC(float, half, float, int32_t); + +INSTANTIATE_NOAUX_TC(half, float, half, int32_t); +INSTANTIATE_NOAUX_TC(half, half, half, int32_t); + +#ifdef ENABLE_BF16 +INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, float, int32_t); +INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, half, int32_t); + +INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, __nv_bfloat16, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, __nv_bfloat16, int32_t); +#endif + +} // namespace tensorrt_llm::kernels + +namespace flashinfer::trtllm_dsv3_fused_routing { +// th::Tensor const& scores, th::Tensor const& bias, int64_t n_group, +// int64_t topk_group, int64_t topk, double routed_scaling_factor +// th::Tensor topk_values, th::Tensor topk_indices + +void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_group, int64_t topk, + double routed_scaling_factor, TensorView topk_values, TensorView topk_indices, + bool launch_with_pdl) { + auto data_type = scores.dtype(); + auto bias_type = bias.dtype(); + + auto input_size = scores.sizes(); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + + TVM_FFI_ICHECK(input_size.size() == 2) << "scores must be a 2D Tensor"; + TVM_FFI_ICHECK((scores.device().device_type == kDLCUDA) && (bias.device().device_type == kDLCUDA)) + << "scores and bias must be CUDA tensors"; + TVM_FFI_ICHECK(scores.device().device_id == bias.device().device_id) + << "scores and bias must be on the same device"; + TVM_FFI_ICHECK(bias.dim() == 1 && bias.numel() == num_experts) + << "bias must be 1D with length == number of experts (%ld)"; + TVM_FFI_ICHECK(num_experts % n_group == 0) << "num_experts should be divisible by n_group"; + TVM_FFI_ICHECK(n_group <= 32) + << "n_group should be smaller than or equal to 32 for now"; //@todo: remove this restriction + // later + TVM_FFI_ICHECK(topk <= 32) + << "topk should be smaller than or equal to 32 for now"; //@todo: remove this restriction + // later + TVM_FFI_ICHECK(topk_values.dim() == 2) << "topk_values must be a 2D Tensor"; + TVM_FFI_ICHECK(topk_indices.dim() == 2) << "topk_indices must be a 2D Tensor"; + TVM_FFI_ICHECK(topk_values.sizes()[0] == num_tokens) + << "topk_values must have the same number of tokens as scores"; + TVM_FFI_ICHECK(topk_indices.sizes()[0] == num_tokens) + << "topk_indices must have the same number of tokens as scores"; + TVM_FFI_ICHECK(topk_values.sizes()[1] == topk) + << "topk_values must have the same number of topk as scores"; + TVM_FFI_ICHECK(topk_indices.sizes()[1] == topk) + << "topk_indices must have the same number of topk as scores"; + TVM_FFI_ICHECK(topk_values.dtype() == data_type) + << "topk_values must have the same dtype as scores"; + TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code) + << "topk_indices must have the same dtype as scores"; + + auto stream = get_stream(scores.device()); + using namespace tensorrt_llm::kernels; + switch 
(encode_dlpack_dtype(data_type)) { + case float16_code: + // Handle Float16 + switch (encode_dlpack_dtype(bias_type)) { + case float16_code: + invokeNoAuxTc( + reinterpret_cast(scores.data_ptr()), reinterpret_cast(bias.data_ptr()), + reinterpret_cast(topk_values.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + break; + case float32_code: + invokeNoAuxTc( + reinterpret_cast(scores.data_ptr()), reinterpret_cast(bias.data_ptr()), + reinterpret_cast(topk_values.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + break; + case bfloat16_code: + invokeNoAuxTc( + reinterpret_cast(scores.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()), + reinterpret_cast(topk_values.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + break; + default: + throw std::invalid_argument( + "Invalid bias dtype, only supports float16, float32, and bfloat16"); + break; + } + break; + case float32_code: + switch (encode_dlpack_dtype(bias_type)) { + case float32_code: + invokeNoAuxTc( + reinterpret_cast(scores.data_ptr()), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(topk_values.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + break; + case float16_code: + invokeNoAuxTc( + reinterpret_cast(scores.data_ptr()), reinterpret_cast(bias.data_ptr()), + reinterpret_cast(topk_values.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + break; + case bfloat16_code: + invokeNoAuxTc( + reinterpret_cast(scores.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()), + reinterpret_cast(topk_values.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + break; + default: + throw std::invalid_argument( + "Invalid bias dtype, only supports float16, float32, and bfloat16"); + break; + } + break; + case bfloat16_code: + // Handle BFloat16 + switch (encode_dlpack_dtype(bias_type)) { + case bfloat16_code: + invokeNoAuxTc<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, int32_t>( + reinterpret_cast<__nv_bfloat16*>(scores.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + break; + case float16_code: + invokeNoAuxTc<__nv_bfloat16, half, __nv_bfloat16, int32_t>( + reinterpret_cast<__nv_bfloat16*>(scores.data_ptr()), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), num_tokens, num_experts, n_group, + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + break; + case float32_code: + invokeNoAuxTc<__nv_bfloat16, float, __nv_bfloat16, int32_t>( + reinterpret_cast<__nv_bfloat16*>(scores.data_ptr()), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), 
num_tokens, num_experts, n_group, + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + break; + default: + throw std::invalid_argument( + "Invalid bias dtype, only supports bfloat16, float16, and float32"); + break; + } + break; + default: + // Handle other data types + throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16"); + break; + } +} +TVM_FFI_DLL_EXPORT_TYPED_FUNC(NoAuxTc, flashinfer::trtllm_dsv3_fused_routing::NoAuxTc); +} // namespace flashinfer::trtllm_dsv3_fused_routing diff --git a/flashinfer/dsv3_ops/__init__.py b/flashinfer/dsv3_ops/__init__.py index 49fb43b3ec..05a7c4e657 100644 --- a/flashinfer/dsv3_ops/__init__.py +++ b/flashinfer/dsv3_ops/__init__.py @@ -1,5 +1,7 @@ from flashinfer.gemm import mm_M1_16_K7168_N256 +from flashinfer.fused_moe import NoAuxTc __all__ = [ "mm_M1_16_K7168_N256", + "NoAuxTc", ] diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index 8121c99c0a..d899fc88dc 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -32,6 +32,10 @@ trtllm_bf16_moe, ) +from .fused_routing_dsv3 import ( # noqa: F401 + NoAuxTc as NoAuxTc, +) + __all__ = [ "RoutingMethodType", "GatedActType", @@ -48,4 +52,5 @@ "trtllm_fp4_block_scale_routed_moe", "trtllm_fp8_block_scale_moe", "trtllm_fp8_per_tensor_scale_moe", + "NoAuxTc", ] diff --git a/flashinfer/fused_moe/fused_routing_dsv3.py b/flashinfer/fused_moe/fused_routing_dsv3.py new file mode 100644 index 0000000000..bb12472272 --- /dev/null +++ b/flashinfer/fused_moe/fused_routing_dsv3.py @@ -0,0 +1,194 @@ +from flashinfer.jit import gen_dsv3_fused_routing_module +import functools +from types import SimpleNamespace +import torch +from flashinfer.utils import ( + register_custom_op, + supported_compute_capability, + backend_requirement, +) + + +@supported_compute_capability([89, 90, 100, 103, 120, 121]) +def _check_dsv3_fused_routing_supported( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl, +): + """Validate configuration parameters for DSv3 fused routing kernel. 
+ + Args: + scores: Input routing scores tensor + bias: Per-expert routing bias tensor + n_group: Number of expert groups + topk_group: Number of top groups to select + topk: Number of top experts to select per token + routed_scaling_factor: Scaling factor for normalized weights + topk_values: Output tensor for normalized expert weights + topk_indices: Output tensor for selected expert indices + launch_with_pdl: Whether to use Persistent Device-side Launch + + Raises: + ValueError: If configuration is invalid or exceeds kernel limits + """ + # Extract number of experts from scores shape + num_experts = scores.shape[1] + + # Check basic configuration constraints + if topk_group * n_group < topk or topk_group > n_group: + raise ValueError( + f"Invalid configuration: topk_group * n_group ({topk_group * n_group}) must be >= topk ({topk}) " + f"and topk_group ({topk_group}) must be <= n_group ({n_group})" + ) + + # Check kernel limits based on number of groups + if n_group > 1: + experts_per_group = num_experts / n_group + max_experts_in_selected_groups = experts_per_group * topk_group + + if topk > 8: + raise ValueError( + f"Invalid configuration for n_group > 1: topk ({topk}) must be <= 8" + ) + if experts_per_group > 32: + raise ValueError( + f"Invalid configuration for n_group > 1: num_experts / n_group " + f"({experts_per_group}) must be <= 32" + ) + if max_experts_in_selected_groups > 128: + raise ValueError( + f"Invalid configuration for n_group > 1: num_experts / n_group * topk_group " + f"({max_experts_in_selected_groups}) must be <= 128" + ) + else: # n_group == 1 + if num_experts > 384: + raise ValueError( + f"Invalid configuration for n_group = 1: num_experts ({num_experts}) must be <= 384" + ) + if topk > 8: + raise ValueError( + f"Invalid configuration for n_group = 1: topk ({topk}) must be <= 8" + ) + + return True + + +@functools.cache +def get_dsv3_fused_routing_module(): + module = gen_dsv3_fused_routing_module().build_and_load() + + @register_custom_op( + "flashinfer::NoAuxTc", + mutates_args=["topk_values", "topk_indices"], + ) + def NoAuxTc( + scores: torch.Tensor, + bias: torch.Tensor, + n_group: int, + topk_group: int, + topk: int, + routed_scaling_factor: float, + topk_values: torch.Tensor, + topk_indices: torch.Tensor, + launch_with_pdl: bool = True, + ) -> None: + module.NoAuxTc( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl, + ) + + return SimpleNamespace( + NoAuxTc=NoAuxTc, + ) + + +@backend_requirement({}, common_check=_check_dsv3_fused_routing_supported) +def NoAuxTc( + scores: torch.Tensor, + bias: torch.Tensor, + n_group: int, + topk_group: int, + topk: int, + routed_scaling_factor: float, + topk_values: torch.Tensor, + topk_indices: torch.Tensor, + launch_with_pdl: bool = True, +) -> None: + """Fused expert routing with top-k selection for DeepSeek-V3. + + This function performs a highly optimized fused routing operation specifically + designed for DeepSeek-V3's Mixture of Experts (MoE) architecture with grouped + expert routing and no auxiliary loss. It combines score computation, expert + selection, and normalization into a single kernel operation. + + The routing algorithm consists of the following steps: + 1. Compute biased scores: sigmoid(scores) + bias for each expert + 2. Group experts and compute group scores (sum of top-2 experts per group) + 3. Select top-k groups based on group scores + 4. From selected groups, select top-k experts based on biased scores + 5. 
Normalize selected expert weights: sigmoid_scores / sum(sigmoid_scores) * scale + + Args: + scores (torch.Tensor): Input routing scores of shape (num_tokens, num_experts). + The logits produced by the router network before activation. Supports + bfloat16, float16, or float32. + bias (torch.Tensor): Per-expert routing bias of shape (num_experts,). Added to + sigmoid-activated scores to produce biased scores for expert selection. + Must match the dtype of scores. + n_group (int): Number of expert groups. Experts are divided into groups for + hierarchical selection. Typical value is 8 for DeepSeek-V3 with 256 experts + (32 experts per group). + topk_group (int): Number of top groups to select. Must be <= n_group. Typical + value is 4, meaning the top 4 groups are selected from 8 groups. + topk (int): Number of top experts to select per token. Must be <= num_experts. + Typical value is 8, meaning 8 experts are routed per token. + routed_scaling_factor (float): Scaling factor applied to normalized expert + weights. The final output weights are: + sigmoid_scores / sum(sigmoid_scores) * routed_scaling_factor. + topk_values (torch.Tensor): Pre-allocated output tensor of shape + (num_tokens, topk) for the normalized expert weights. Must be float32. + This tensor is mutated in-place. + topk_indices (torch.Tensor): Pre-allocated output tensor of shape + (num_tokens, topk) for the selected expert indices. Must be int32 or int64. + This tensor is mutated in-place. + launch_with_pdl (bool, optional): Whether to launch the kernel using Persistent + Device-side Launch. Defaults to True. + + Returns: + None: Results are written directly to `topk_values` and `topk_indices` tensors. + + Note: + - The kernel uses float32 internally for all computations to ensure numerical + precision, even when inputs are float16 or bfloat16. + - This implementation is optimized for Hopper (compute capability 90, 100), + Ada (compute capability 89), and Blackwell (compute capability 120, 121) + architectures. + - The "NoAux" prefix indicates this variant does not compute auxiliary losses + (e.g., load balancing loss) during routing. + - The "Tc" suffix indicates the use of Tensor Core optimizations in the + underlying CUDA kernel. 
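+
+    Example:
+        A minimal illustrative call (shapes and values are placeholders; 256 experts
+        in 8 groups, 8 experts routed per token):
+
+            scores = torch.randn(4, 256, device="cuda", dtype=torch.float32)
+            bias = torch.randn(256, device="cuda", dtype=torch.float32)
+            topk_values = torch.empty(4, 8, device="cuda", dtype=torch.float32)
+            topk_indices = torch.empty(4, 8, device="cuda", dtype=torch.int32)
+            NoAuxTc(scores, bias, n_group=8, topk_group=4, topk=8,
+                    routed_scaling_factor=2.5, topk_values=topk_values,
+                    topk_indices=topk_indices)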
+ """ + get_dsv3_fused_routing_module().NoAuxTc( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl, + ) diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index bc4132ec9c..360a12f535 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -79,6 +79,9 @@ from .dsv3_optimizations import ( gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module, ) +from .dsv3_optimizations import ( + gen_dsv3_fused_routing_module as gen_dsv3_fused_routing_module, +) cuda_lib_path = os.environ.get( diff --git a/flashinfer/jit/dsv3_optimizations.py b/flashinfer/jit/dsv3_optimizations.py index 88be890699..9aa720fa59 100644 --- a/flashinfer/jit/dsv3_optimizations.py +++ b/flashinfer/jit/dsv3_optimizations.py @@ -9,3 +9,37 @@ def gen_dsv3_router_gemm_module() -> JitSpec: jit_env.FLASHINFER_CSRC_DIR / "dsv3_router_gemm.cu", ], ) + + +def gen_dsv3_fused_routing_module() -> JitSpec: + return gen_jit_spec( + "dsv3_fused_routing", + [ + jit_env.FLASHINFER_CSRC_DIR / "fused_moe/noAuxTcKernels.cu", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/memoryUtils.cu", + ], + extra_include_paths=[ + jit_env.FLASHINFER_CSRC_DIR / "nv_internal", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "cutlass_extensions" + / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "kernels" + / "cutlass_kernels" + / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "kernels" + / "cutlass_kernels", + ], + ) diff --git a/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h b/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h new file mode 100644 index 0000000000..5af8fe39db --- /dev/null +++ b/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +#include "tensorrt_llm/common/cudaUtils.h" + +namespace tensorrt_llm::kernels { + +template +void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices, + int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, + int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, + cudaStream_t const stream = 0); + +} // namespace tensorrt_llm::kernels diff --git a/tests/model_optimizations/test_dsv3_fused_routing.py b/tests/model_optimizations/test_dsv3_fused_routing.py new file mode 100644 index 0000000000..1749e94f46 --- /dev/null +++ b/tests/model_optimizations/test_dsv3_fused_routing.py @@ -0,0 +1,501 @@ +""" +Test for NoAuxTc (DSv3 Fused Routing) Kernel + +This test validates the NoAuxTc kernel against a reference implementation, +accounting for numerical precision and tie-breaking differences. + +================================================================================ +DSv3 ROUTING ALGORITHM +================================================================================ + +1. Compute: sigmoid(scores) + bias for each expert (biased scores) +2. Group experts and compute group scores (sum of top-2 experts per group) +3. Select top-k groups based on group scores +4. From selected groups, select top-k experts based on biased scores +5. Normalize selected experts: sigmoid_scores / sum(sigmoid_scores) * scale + +================================================================================ +VALIDATION LOGIC FLOW +================================================================================ + +The test performs TWO stages of validation for each token: + +STAGE 1: EXPERT SELECTION VALIDATION +------------------------------------- +Checks if the kernel selected the correct (or acceptably tied) experts. + +1. Are kernel_experts == ref_experts (same set)? + YES → ✅ VALID (status: "exact") + Continue to Stage 2 to validate output values + NO → Continue to step 2 + +2. Are kernel_groups == ref_groups (same groups selected)? + YES → Continue to step 3 (same groups, different experts) + NO → Continue to step 4 (different groups) + +3. SAME GROUPS, DIFFERENT EXPERTS + Check if the differing experts have tied scores: + - Compute score_diff = max(diff_expert_scores) - min(diff_expert_scores) + - If score_diff < expert_tie_threshold: + → ✅ VALID (status: "tied_experts") + - Else: + → ❌ INVALID (status: "score_mismatch") + +4. DIFFERENT GROUPS + a) Are the groups tied? + - Compute all group scores (sum of top-2 experts per group) + - Check if differing groups have similar scores + - If group_score_diff < group_tie_threshold: + → Groups are tied, continue to step 4b + - Else: + → ❌ INVALID (status: "different_groups") + + b) Are the experts correct within kernel's groups? 
+ - Compute expected_experts = top-k experts from kernel's selected groups + - If kernel_experts == expected_experts: + → ✅ VALID (status: "tied_groups") + - Else, check if differing experts have tied scores: + - Compute score_diff for differing experts + - If score_diff < expert_tie_threshold: + → ✅ VALID (status: "tied_groups") + - Else: + → ❌ INVALID (status: "tied_groups_but_wrong_experts") + +STAGE 2: OUTPUT VALUE VALIDATION +--------------------------------- +For tokens where the SAME experts were selected (status: "exact"): +- Compare kernel output values vs reference output values +- Both are normalized scores: sigmoid_scores / sum(sigmoid_scores) * scale +- Check: abs(kernel_values - ref_values) within tolerance + - If within tolerance → ✅ VALID + - Else → ❌ INVALID (value mismatch) + +For tokens where DIFFERENT experts were selected (even if acceptably): +- SKIP value validation +- Reason: Different experts → different normalization sum → different values +- The expert selection validation already confirmed correctness + +Tolerance (data-type dependent): +- bfloat16: rtol=0.1, atol=0.1 +- float16: rtol=0.05, atol=0.05 +- float32: rtol=0.01, atol=0.01 + +================================================================================ +KEY CONCEPTS +================================================================================ + +1. **Group Ties**: When two groups have similar group scores (within threshold), + selecting either group is valid. The kernel may pick a different group than + the reference due to tie-breaking. + +2. **Expert Ties**: When experts have similar biased scores (within threshold), + selecting any of them is valid. The kernel may pick different experts due + to tie-breaking. + +3. **Tied Groups → Verify Experts**: When different groups are selected due to + ties, we must still verify that the kernel selected the correct top-k experts + WITHIN its chosen groups (not compare across different groups). + +4. **Float32 Internal Computation**: The kernel computes internally in float32 + even when inputs are float16/bfloat16. The reference must match this to + ensure consistent group/expert selection. + +================================================================================ +THRESHOLDS (Data-Type Dependent) +================================================================================ + + Expert Tie Group Tie + Threshold Threshold + bfloat16: 1.0 0.05 + float16: 0.5 0.02 + float32: 0.2 0.01 + +Group thresholds are higher because group scores are sums of 2 values, +accumulating more numerical error. + +================================================================================ +""" + +import torch +import pytest +from flashinfer.dsv3_ops import NoAuxTc +# from flashinfer.utils import get_compute_capability + + +class DSv3RoutingGroundTruth: + """ + Computes and stores all ground truth data for DSv3 routing. + Performs all computations in float32 to match kernel behavior. 
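+
+    Illustrative usage (mirrors the test flow below; argument values are placeholders):
+
+        gt = DSv3RoutingGroundTruth(scores, bias, n_group=8, topk_group=4, topk=8,
+                                    routed_scaling_factor=1.0, data_type=torch.bfloat16)
+        ok, reason = gt.is_valid_expert_selection(0, kernel_topk_indices[0].tolist())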
+ """ + + def __init__( + self, scores, bias, n_group, topk_group, topk, routed_scaling_factor, data_type + ): + self.num_tokens = scores.shape[0] + self.num_experts = scores.shape[1] + self.n_group = n_group + self.topk_group = topk_group + self.topk = topk + self.routed_scaling_factor = routed_scaling_factor + self.experts_per_group = self.num_experts // n_group + self.device = scores.device + + # Set thresholds based on data type + if data_type == torch.bfloat16: + self.expert_tie_threshold = 1.0 + self.group_tie_threshold = 0.05 + elif data_type == torch.float16: + self.expert_tie_threshold = 0.5 + self.group_tie_threshold = 0.02 + else: # float32 + self.expert_tie_threshold = 0.2 + self.group_tie_threshold = 0.01 + + # Convert to float32 to match kernel's internal computation + scores_f32 = scores.to(torch.float32) + bias_f32 = bias.to(torch.float32) + + # Compute sigmoid and biased scores + self.sigmoid_scores = torch.sigmoid(scores_f32) + self.biased_scores = self.sigmoid_scores + bias_f32 + + # Reshape for group-wise operations + scores_reshaped = self.biased_scores.view( + self.num_tokens, n_group, self.experts_per_group + ) + + # Compute group scores (sum of top-2 experts per group) + top2_per_group = torch.topk( + scores_reshaped, k=2, dim=-1, largest=True, sorted=True + )[0] + self.group_scores = torch.sum(top2_per_group, dim=-1) + + # Reference group selection + _, self.ref_group_indices = torch.topk( + self.group_scores, k=topk_group, dim=-1, largest=True, sorted=True + ) + + # Identify tied groups for each token + self.tied_group_sets = [] + for token_idx in range(self.num_tokens): + tied_groups = set() + group_scores_token = self.group_scores[token_idx] + + for g1 in range(n_group): + for g2 in range(g1 + 1, n_group): + score_diff = abs(group_scores_token[g1] - group_scores_token[g2]) + if score_diff < self.group_tie_threshold: + tied_groups.add(g1) + tied_groups.add(g2) + + self.tied_group_sets.append(tied_groups) + + # Compute reference expert selection and normalization + self.ref_expert_indices = torch.zeros( + self.num_tokens, topk, dtype=torch.long, device=self.device + ) + self.ref_expert_values = torch.zeros( + self.num_tokens, topk, dtype=torch.float32, device=self.device + ) + + for token_idx in range(self.num_tokens): + # Create mask for selected groups + group_mask = torch.zeros(n_group, dtype=torch.float32, device=self.device) + group_mask[self.ref_group_indices[token_idx]] = 1.0 + expert_mask = group_mask.repeat_interleave(self.experts_per_group) + + # Mask and select top-k experts + masked_biased_scores = self.biased_scores[token_idx] * expert_mask + _, topk_idx = torch.topk( + masked_biased_scores, k=topk, dim=-1, largest=True, sorted=True + ) + + # Normalize selected experts + selected_sigmoid_scores = self.sigmoid_scores[token_idx][topk_idx] + score_sum = selected_sigmoid_scores.sum() + 1e-20 + normalized_scores = ( + selected_sigmoid_scores / score_sum * routed_scaling_factor + ) + + # Sort by normalized scores + sorted_vals, sorted_idx = torch.sort(normalized_scores, descending=True) + self.ref_expert_values[token_idx] = sorted_vals + self.ref_expert_indices[token_idx] = topk_idx[sorted_idx] + + def get_expert_group(self, expert_id): + """Return which group an expert belongs to.""" + return expert_id // self.experts_per_group + + def is_valid_group_selection(self, token_idx, selected_groups): + """Check if a set of selected groups is valid (exact match or tied).""" + ref_groups = set(self.ref_group_indices[token_idx].tolist()) + selected_groups_set = 
set(selected_groups) + + if selected_groups_set == ref_groups: + return True, "exact" + + if self.n_group > 1: + diff_groups = selected_groups_set.symmetric_difference(ref_groups) + tied_groups = self.tied_group_sets[token_idx] + + if diff_groups and diff_groups.issubset(tied_groups): + return True, "tied_groups" + + return False, "different_groups" + + def is_valid_expert_selection(self, token_idx, selected_experts): + """Check if a set of selected experts is valid (exact match or tied).""" + ref_experts = set(self.ref_expert_indices[token_idx].tolist()) + selected_experts_set = set(selected_experts) + + if selected_experts_set == ref_experts: + return True, "exact" + + # Check group-level validity + selected_groups = set(self.get_expert_group(e) for e in selected_experts) + ref_groups = set(self.ref_group_indices[token_idx].tolist()) + + # If different groups selected + if selected_groups != ref_groups: + is_valid_groups, group_reason = self.is_valid_group_selection( + token_idx, list(selected_groups) + ) + if not is_valid_groups: + # Groups are different and not tied - invalid + return False, group_reason + + # Groups are tied - now check if kernel selected correct top-k within its groups + expected_experts_in_kernel_groups = self._get_topk_experts_from_groups( + token_idx, list(selected_groups) + ) + + # Check if kernel's selection matches expected experts (exact or tied) + if selected_experts_set != expected_experts_in_kernel_groups: + # Different experts - check if they have tied scores + diff_experts = selected_experts_set.symmetric_difference( + expected_experts_in_kernel_groups + ) + biased_scores_token = self.biased_scores[token_idx] + diff_expert_scores = torch.tensor( + [biased_scores_token[e].item() for e in diff_experts] + ) + score_range = diff_expert_scores.max() - diff_expert_scores.min() + + if score_range >= self.expert_tie_threshold: + # Experts are wrong (not tied) - invalid even though groups are tied + return ( + False, + f"tied_groups_but_wrong_experts_score_diff={score_range:.6f}", + ) + + # Groups are tied and experts are correct (or acceptably tied) + return True, "tied_groups" + + # Same groups but different experts - check expert-level ties + diff_experts = selected_experts_set.symmetric_difference(ref_experts) + if diff_experts: + biased_scores_token = self.biased_scores[token_idx] + diff_expert_scores = torch.tensor( + [biased_scores_token[e].item() for e in diff_experts] + ) + score_range = diff_expert_scores.max() - diff_expert_scores.min() + + if score_range < self.expert_tie_threshold: + return True, "tied_experts" + else: + return ( + False, + f"score_diff={score_range:.6f}_threshold={self.expert_tie_threshold:.6f}", + ) + + return True, "exact" + + def _get_topk_experts_from_groups(self, token_idx, groups): + """ + Get the expected top-k experts from specified groups. + This computes what experts SHOULD be selected if these groups were chosen. 
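+
+        For example, with 32 experts per group and groups=[1, 3], only experts
+        32-63 and 96-127 keep their biased scores; every other expert is masked
+        to zero before the top-k selection.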
+ """ + # Create mask for specified groups + group_mask = torch.zeros(self.n_group, dtype=torch.float32, device=self.device) + for g in groups: + group_mask[g] = 1.0 + expert_mask = group_mask.repeat_interleave(self.experts_per_group) + + # Mask and select top-k experts + masked_biased_scores = self.biased_scores[token_idx] * expert_mask + _, topk_idx = torch.topk( + masked_biased_scores, k=self.topk, dim=-1, largest=True, sorted=True + ) + + return set(topk_idx.tolist()) + + +def validate_expert_selection(ground_truth, topk_indices_kernel, topk_values_kernel): + """Validate kernel outputs and provide detailed debug info for failures.""" + num_tokens = topk_indices_kernel.shape[0] + tokens_with_different_experts = set() + + for token_idx in range(num_tokens): + kernel_experts = topk_indices_kernel[token_idx].tolist() + ref_experts = ground_truth.ref_expert_indices[token_idx].tolist() + + # Same experts - valid + if set(kernel_experts) == set(ref_experts): + continue + + # Different experts - mark for value comparison skip + tokens_with_different_experts.add(token_idx) + + # Validate the selection + is_valid, reason = ground_truth.is_valid_expert_selection( + token_idx, kernel_experts + ) + + if not is_valid: + return False, tokens_with_different_experts + + return True, tokens_with_different_experts + + +def validate_values(ground_truth, topk_values_kernel, tokens_to_skip, data_type): + """Validate that output values match reference within tolerance.""" + # Set tolerance based on data type + if data_type == torch.bfloat16: + rtol, atol = 0.1, 0.1 + elif data_type == torch.float16: + rtol, atol = 0.05, 0.05 + else: # float32 + rtol, atol = 0.01, 0.01 + + num_tokens = topk_values_kernel.shape[0] + + # Create mask for tokens to check + tokens_to_check = torch.ones(num_tokens, dtype=torch.bool) + for token_idx in tokens_to_skip: + tokens_to_check[token_idx] = False + + if not tokens_to_check.any(): + return + + # Compare values + ref_values = ground_truth.ref_expert_values[tokens_to_check].float() + kernel_values = topk_values_kernel[tokens_to_check].float() + + try: + torch.testing.assert_close( + ref_values, + kernel_values, + rtol=rtol, + atol=atol, + ) + except AssertionError: + # Find and report first mismatch + for token_idx in range(num_tokens): + if not tokens_to_check[token_idx]: + continue + + ref_vals = ground_truth.ref_expert_values[token_idx].float() + kernel_vals = topk_values_kernel[token_idx].float() + + if not torch.allclose(ref_vals, kernel_vals, rtol=rtol, atol=atol): + diff = (kernel_vals - ref_vals).abs() + max_diff = diff.max().item() + max_diff_idx = diff.argmax().item() + + print(f"\n{'=' * 80}") + print(f"VALUE MISMATCH - Token {token_idx}") + print(f"{'=' * 80}") + print(f"Tolerance: rtol={rtol}, atol={atol}") + print(f"Max difference: {max_diff:.6f} at position {max_diff_idx}") + print(f"\nReference values: {ref_vals.tolist()}") + print(f"Kernel values: {kernel_vals.tolist()}") + print(f"Absolute diff: {diff.tolist()}") + print( + f"Expert indices: {ground_truth.ref_expert_indices[token_idx].tolist()}" + ) + break + + raise + + +@pytest.mark.parametrize("num_tokens", [1, 8, 16, 64]) +@pytest.mark.parametrize("num_experts", [256, 384]) +@pytest.mark.parametrize("topk", [1, 2, 4, 8]) +@pytest.mark.parametrize("n_group", [1, 2, 4, 8]) +@pytest.mark.parametrize("topk_group", [1, 2, 4, 8]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("bias_type", [torch.float32, torch.float16, torch.bfloat16]) +def 
test_dsv3_fused_routing_op( + num_tokens, num_experts, topk, n_group, topk_group, data_type, bias_type +): + """ + Test NoAuxTc kernel against reference implementation. + + Validates: + 1. Expert selection equivalence (allowing for ties) + 2. Value correctness within numerical precision tolerance + """ + + # Skip invalid configurations + if topk_group * n_group < topk or topk_group > n_group: + pytest.skip( + "Invalid configuration: topk_group * n_group < topk or topk_group > n_group" + ) + if n_group > 1: + if ( + topk > 8 + or num_experts / n_group > 32 + or num_experts / n_group * topk_group > 128 + ): + pytest.skip("Invalid configuration: exceeds kernel limits for n_group > 1") + else: + if num_experts > 384 or topk > 8: + pytest.skip("Invalid configuration: exceeds kernel limits for n_group = 1") + + # Generate random inputs + torch.manual_seed(42) + scores = torch.randn(num_tokens, num_experts, device="cuda", dtype=data_type) + bias = torch.randn(num_experts, device="cuda", dtype=bias_type) + routed_scaling_factor = 1.0 + + # Compute ground truth + ground_truth = DSv3RoutingGroundTruth( + scores.clone(), + bias.clone(), + n_group, + topk_group, + topk, + routed_scaling_factor, + data_type, + ) + + # Run kernel + topk_values = torch.empty(num_tokens, topk, device="cuda", dtype=data_type) + topk_indices = torch.zeros(num_tokens, topk, device="cuda", dtype=torch.int32) + + NoAuxTc( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl=True, + ) + + # Sort kernel outputs for stable comparison + sorted_vals, sorted_idx = torch.sort(topk_values, dim=-1, descending=True) + topk_indices = topk_indices.gather(1, sorted_idx) + + # Validate expert selection + all_valid, tokens_with_different_experts = validate_expert_selection( + ground_truth, topk_indices, sorted_vals + ) + + if not all_valid: + pytest.fail("Expert selection mismatch not due to acceptable ties") + + # Validate values + validate_values(ground_truth, sorted_vals, tokens_with_different_experts, data_type)