Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 41 additions & 34 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class FusedMoeLauncher {
btg::Dtype mRoutingBiasDtype{
btg::Dtype::Bfloat16}; // Dtype for expert weights in routing, based on routing bias
ActivationType activation_type{ActivationType::Swiglu};
btg::Dtype mDtypeScore{btg::Dtype::Bfloat16};

int64_t intermediate_size_factor{2};

Expand Down Expand Up @@ -168,13 +169,19 @@ class FusedMoeLauncher {
int64_t weight_layout, ActivationType activation_type);

// Routing logits [num_tokens, num_experts]
void check_routing_logits_shape() const {
void check_routing_logits() const {
if (routing_logits.has_value()) {
// Check shape
TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D.";
TVM_FFI_ICHECK_EQ(routing_logits.value().size(0), hidden_states.size(0))
<< "routing_logits and hidden_states must have the same number of tokens.";
TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), args->num_experts)
<< "routing_logits dim1 must match num_experts.";

// Check dtype
TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||
routing_logits.value().dtype() == dl_bfloat16)
<< "routing_logits must be float or bfloat16.";
}
}

Expand Down Expand Up @@ -236,7 +243,7 @@ class FusedMoeLauncher {
args->local_expert_offset + args->local_num_experts <= args->num_experts)
<< "expert offset and count must be within valid range";

check_routing_logits_shape();
check_routing_logits();

if (routing_bias.has_value()) {
check_routing_bias_shape();
Expand Down Expand Up @@ -302,6 +309,19 @@ class FusedMoeLauncher {
workspace.cta_idx_xy_to_batch_idx = static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr());
workspace.cta_idx_xy_to_mn_limit = static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr());
workspace.num_non_exiting_ctas = static_cast<int*>(num_non_exiting_ctas.data_ptr());

// Set dtype of score based on actual routing_logits dtype
if (routing_logits.has_value()) {
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
<< "routing_logits must be float.";
mDtypeScore = btg::Dtype::Fp32;
} else if (routing_logits.value().dtype() == dl_float32) {
mDtypeScore = btg::Dtype::Fp32;
} else {
mDtypeScore = btg::Dtype::Bfloat16;
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

void check_moe_common() const {
Expand Down Expand Up @@ -387,8 +407,8 @@ class FusedMoeLauncher {
static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<int*>(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt,
mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);

check_moe();
Expand Down Expand Up @@ -498,8 +518,9 @@ class Bf16MoeLauncher : public FusedMoeLauncher {
if (has_precomputed_weights) {
workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
} else {
auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
FusedMoeLauncher::expert_weights =
alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device());
Comment on lines +521 to +523
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | πŸ”΄ Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== DeepSeek expW dtype in routing runner =="
rg -n -C3 'RoutingMethodType::DeepSeekV3|mDtypeExpW' csrc/trtllm_fused_moe_runner.cu

echo
echo "== expert_weights dtype allocation sites in launcher =="
rg -n -C2 'ew_dtype|expert_weights_dtype|alloc_tensor\(\{args->num_tokens, args->top_k\}' csrc/trtllm_fused_moe_kernel_launcher.cu

Repository: flashinfer-ai/flashinfer

Length of output: 3594


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== mDtypeScore initialization in launcher =="
rg -n 'mDtypeScore\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20

echo
echo "== Where expert_weights dtype flows into runner args =="
rg -n 'args->mDtypeExpW|mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20

echo
echo "== Check if Llama and Renormalize/TopK also have the issue =="
rg -n 'class.*Llama|class.*Renormalize|class.*Topk' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Allocation sites - check their context for routing method =="
rg -n -B10 'auto ew_dtype = mDtypeScore' csrc/trtllm_fused_moe_kernel_launcher.cu | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 670


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Context: Which routing methods trigger Fp32 vs Bfloat16 mDtypeScore =="
sed -n '310,330p' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Class definitions to identify which classes set mDtypeScore =="
rg -n 'class.*FusedMoe|struct.*FusedMoe' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20

echo
echo "== How is workspace.expert_weights used in runner =="
rg -n -A3 'workspace\.expert_weights' csrc/trtllm_fused_moe_runner.cu | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 1692


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Where is args.mDtypeExpW set =="
rg -n 'args.*mDtypeExpW|mDtypeExpW.*args' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Full context of finalizeData usage in runner =="
sed -n '490,510p' csrc/trtllm_fused_moe_runner.cu

echo
echo "== What does finalizeData do with expertWeightsPtr and mDtypeExpW =="
rg -n 'finalizeData\.' csrc/trtllm_fused_moe_runner.cu | head -20

echo
echo "== Check if expert_weights buffer is actually read/interpreted as a specific dtype =="
rg -n 'expert_weights|expertWeights' csrc/trtllm_fused_moe_runner.cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 102


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Runner.run signature - understand args =="
rg -n 'void run\(|struct.*Args|args->num_tokens' csrc/trtllm_fused_moe_runner.cu | head -20

echo
echo "== Check if finalizeData operations use mDtypeExpW =="
rg -n 'mDtypeExpW' csrc/trtllm_fused_moe_runner.cu

echo
echo "== Look for kernel calls that use expertWeightsPtr and check if dtype matters =="
rg -n 'expertWeights|expW' csrc/ -t cu | grep -E '(kernel|<<<|>>>)' | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 112


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Routing runner file structure - find main entry point =="
head -100 csrc/trtllm_fused_moe_runner.cu

echo
echo "== Search more broadly for mDtypeExpW usage =="
rg -n 'mDtypeExpW' csrc/

echo
echo "== Check finalize kernel or operation that uses expertWeightsPtr =="
rg -n 'Finalize|finalize' csrc/trtllm_fused_moe_runner.cu | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 123


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Routing runner file structure - find main entry point =="
head -100 csrc/trtllm_fused_moe_runner.cu

echo
echo "== Search more broadly for mDtypeExpW usage =="
rg -n 'mDtypeExpW' csrc/

echo
echo "== Check finalize kernel or operation that uses expertWeightsPtr =="
rg -n 'Finalize|finalize' csrc/trtllm_fused_moe_runner.cu | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 5503


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Look for where args.mDtypeExpW is initialized =="
rg -n 'struct.*Args|args\.mDtypeExpW|mDtypeExpW' csrc/trtllm_fused_moe_runner.cu | head -30

echo
echo "== Search in launcher for any Args struct definition =="
rg -n 'struct.*Args|\.mDtypeExpW' csrc/trtllm_fused_moe_kernel_launcher.cu | head -30

echo
echo "== Check header files for Args struct =="
fd 'runner\.h|fused_moe.*\.h' csrc/ --exec head -150 {} \; 2>/dev/null | grep -A20 'struct Args'

Repository: flashinfer-ai/flashinfer

Length of output: 369


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Finalize function signature =="
rg -n 'void.*finalize|finalize.*Data' csrc/trtllm_fused_moe_runner.cu -A2 | head -30

echo
echo "== Check line 467 context for finalize function =="
sed -n '460,510p' csrc/trtllm_fused_moe_runner.cu

echo
echo "== Look for expert_weights dtype handling in finalize kernel =="
fd '\.h$' csrc/ --exec rg -l 'finalize' {} \; 2>/dev/null | head -5 | xargs rg -n 'struct.*finalize|mDtypeExpW' 2>/dev/null | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 4018


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Find MoERunnerArgs struct definition =="
fd '\.h$' csrc/ | xargs rg -l 'MoERunnerArgs' | head -3

echo
echo "== Search for mDtypeExpW in those headers =="
fd '\.h$' csrc/ | xargs rg -n 'mDtypeExpW|struct MoERunnerArgs' 2>/dev/null | head -40

echo
echo "== Check if launcher sets args.mDtypeExpW anywhere =="
rg -n 'args\.mDtypeExpW|mDtypeExpW.*=' csrc/trtllm_fused_moe_kernel_launcher.cu

Repository: flashinfer-ai/flashinfer

Length of output: 112


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Search for MoERunnerArgs definition =="
rg -n 'struct MoERunnerArgs|mDtypeExpW' csrc/ -t h 2>/dev/null | head -50

echo
echo "== Look at Bf16MoeLauncher class (lines 458+) =="
sed -n '458,550p' csrc/trtllm_fused_moe_kernel_launcher.cu

Repository: flashinfer-ai/flashinfer

Length of output: 109


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Look at actual class that creates workspace and calls runner =="
sed -n '458,550p' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Search for where Runner::run is called from launcher =="
rg -n 'Runner::run|\.run\(' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20

echo
echo "== Check the workspace initialization to understand args flow =="
rg -n 'workspace\.|MoEWorkspace' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 6457


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Check the routing_runner.run call signature and what args it receives =="
sed -n '398,410p' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Look for args.mDtypeExpW initialization anywhere in launcher =="
rg -n 'mDtypeExpW' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Check what happens with expert_weights in routing kernel (DeepSeekV3) =="
rg -n 'mPtrTopKWeights|mDtypeExpW' csrc/ -t cu 2>/dev/null | grep -i deepseek

echo
echo "== Verify: is args.mDtypeExpW ever set in launcher classes? =="
rg -n 'args->mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu

Repository: flashinfer-ai/flashinfer

Length of output: 1120


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Full routing_runner.run call =="
sed -n '398,415p' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Search for where args is created/initialized =="
rg -n 'args.*=' csrc/trtllm_fused_moe_kernel_launcher.cu | grep -E 'mDtypeExpW|make_unique|new' | head -20

echo
echo "== Check if expert_weights buffer is actually used by routing kernel =="
rg -n 'mPtrTopKWeights' csrc/ -A5 -B5 2>/dev/null | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 4895


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Check if mDtypeExpW is actually used in the routing kernel =="
fd '\.h$' csrc/ | xargs rg -l 'mDtypeExpW|routingDeepSeek' 2>/dev/null | head -5

echo
echo "== Check DeepSeekV3 routing kernel implementation =="
rg -n 'routingDeepSeek|mDtypeExpW|mPtrTopKWeights' csrc/ -t h -t cu 2>/dev/null | grep -A5 -B5 'mPtrTopKWeights' | head -50

echo
echo "== Verify: args.mDtypeExpW default value - check if it's initialized =="
rg -n 'MoERunnerArgs' csrc/ -t h --max-count=5 2>/dev/null | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 134


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Look for finalize kernel that uses mDtypeExpW and expertWeightsPtr =="
find csrc -name "*.h" -o -name "*.cu" | xargs rg -l 'finalize.*Data|expertWeightsPtr' 2>/dev/null | head -5

echo
echo "== Check how finalize uses expert_weights dtype =="
rg -n 'expertWeightsPtr|mDtypeExpW' csrc/ 2>/dev/null | grep -v 'Binary' | head -30

echo
echo "== Check the actual MoERunnerArgs initialization - look at constructor/default values =="
find csrc -name "*.h" | xargs grep -A30 'struct MoERunnerArgs\|class MoERunnerArgs' 2>/dev/null | head -60

Repository: flashinfer-ai/flashinfer

Length of output: 1675


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Check finalize kernel - how TypeExpW is defined and used =="
sed -n '660,680p' csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu

echo
echo "== Check if TypeExpW is template parameter or determined by mDtypeExpW =="
rg -n 'template.*TypeExpW|TypeExpW|mDtypeExpW' csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu | head -30

echo
echo "== Find where args.mDtypeExpW is actually set/initialized =="
rg -n 'mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu

Repository: flashinfer-ai/flashinfer

Length of output: 1771


Expert_weights dtype mismatch between launcher allocation and kernel interpretation affects all routing methods.

The launcher allocates expert_weights with dtype determined by mDtypeScore (which is Fp32 for DeepSeekV3 and conditionally Fp32 for other methods). However, all routing pathsβ€”DeepSeekV3, Llama4, and TopKβ€”hardcode mDtypeExpW = Bfloat16 independent of the launcher's allocation. This buffer is then passed to the finalize kernel, which reads it using TypeExpW template instantiated from mDtypeExpW. When expert_weights is allocated as Fp32 but read as Bfloat16, bytes are misinterpreted, causing data corruption.

This issue is broader than DeepSeekV3 alone: it affects any configuration where mDtypeScore is Fp32 across all routing methods. The launcher never communicates the actual expert_weights dtype back to the runner or finalize kernel.

πŸ”§ Suggested fix (centralize expW dtype policy)
 class FusedMoeLauncher {
  protected:
+  DLDataType get_expert_weights_dtype() const {
+    if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+      // Runner DeepSeek path currently expects expW as BF16.
+      return dl_bfloat16;
+    }
+    return mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
+  }
-      auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
+      auto ew_dtype = get_expert_weights_dtype();

Apply the replacement at all four allocation sites (lines 521–523, 662–664, 938–940, 1213–1215).

πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
FusedMoeLauncher::expert_weights =
alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device());
auto ew_dtype = get_expert_weights_dtype();
FusedMoeLauncher::expert_weights =
alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device());
πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 521 - 523, The
expert_weights buffer is allocated using mDtypeScore but always read by the
finalize kernel as TypeExpW instantiated from mDtypeExpW, causing mismatched
interpretation; fix by centralizing the expert-weights dtype policy: derive a
single expW_dtype (based on mDtypeExpW) and use that when calling alloc_tensor
to set FusedMoeLauncher::expert_weights at all allocation sites (the ones
allocating expert_weights), and ensure the same expW_dtype is passed/visible to
the runner/finalize kernel invocation so the template TypeExpW and the allocated
buffer use the same dtype.

workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr();
}
}
Expand Down Expand Up @@ -638,8 +659,9 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher {
routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16;
mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;

auto expert_weights_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
expert_weights =
alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
alloc_tensor({args->num_tokens, args->top_k}, expert_weights_dtype, hidden_states.device());

workspace.expert_weights = expert_weights.data_ptr();
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
Expand Down Expand Up @@ -913,9 +935,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
// Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
bool has_precomputed_weights = expert_weights.ndim() == 2 && expert_weights.size(0) > 0;
if (!has_precomputed_weights) {
// Allocate expert_weights buffer for routing output
auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
FusedMoeLauncher::expert_weights =
alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device());
workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr();
} else {
workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
Expand Down Expand Up @@ -1092,8 +1114,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<int*>(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt,
mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);

check_moe();
Expand Down Expand Up @@ -1188,8 +1210,9 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher {
routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16;
mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;

auto expert_weights_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
expert_weights =
alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
alloc_tensor({args->num_tokens, args->top_k}, expert_weights_dtype, hidden_states.device());

workspace.expert_weights = expert_weights.data_ptr();
}
Expand Down Expand Up @@ -1555,8 +1578,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<int*>(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt,
mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);

check_moe();
Expand Down Expand Up @@ -1694,13 +1717,12 @@ Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
// Basic type validation
auto dtype = hidden_states.dtype();
auto activation = static_cast<ActivationType>(activation_type);
if (use_routing_scales_on_input) {
TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
} else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float.";
} else {
TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";

if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32)
<< "routing_logits must be float for DeepSeekV3.";
}

TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16)
<< "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16.";
TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn)
Expand Down Expand Up @@ -1799,9 +1821,6 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
<< "routing_logits must be float.";
} else {
TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16)
<< "routing_logits must be bfloat16.";
}
}
TVM_FFI_ICHECK(dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn)
Expand Down Expand Up @@ -1930,18 +1949,6 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
<< "unsupported weight_scale_vec_size.";
auto mDtypeWeights = weight_scale_vec_size == 16 ? btg::Dtype::E2m1 : btg::Dtype::MxE2m1;

if (routing_logits.has_value()) {
TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||
routing_logits.value().dtype() == dl_bfloat16)
<< "routing_logits must be float or bfloat16.";
TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D.";
TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts)
<< "routing_logits has incorrect shape.";
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
<< "routing_logits must be float.";
}
}
if (routing_bias.has_value()) {
TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 ||
routing_bias.value().dtype() == dl_float32)
Expand Down
6 changes: 4 additions & 2 deletions csrc/trtllm_fused_moe_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias,
bool useRoutingScalesOnInput, bool useDeepSeekFp8,
int32_t* numNonExitingCtas, btg::Dtype dtypeScore, btg::Dtype dtypeElt,
btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8,
RoutingMethodType routingMethodType, cudaStream_t stream) {
if (routingMethodType == RoutingMethodType::DeepSeekV3) {
FLASHINFER_CHECK(topK <= 22, "For DeepSeek routing method, must have topK <= 22");
Expand Down Expand Up @@ -140,6 +140,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
//

routingData.mDtypeExpW = btg::Dtype::Bfloat16;
routingData.mDtypeScore = dtypeScore;

// routingData.mDtypeElt = dtypeElt; // no-op for now as hidden_state is not input
routingData.mUsePdl = true;
routingData.mDoSoftmaxBeforeTopK = routingMethodType == RoutingMethodType::RenormalizeNaive;
Expand Down
33 changes: 28 additions & 5 deletions include/flashinfer/trtllm/fused_moe/DevKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,22 +178,45 @@ namespace moe::dev {

#define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
stream, extraFlag1, numExperts, numTopExperts) \
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \
if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), \
kernel, numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Fp32) { \
} else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeExpW == tg::Dtype::Fp32 && \
!extraFlag1) { \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), \
kernel, numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) { \
} else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeExpW == tg::Dtype::Bfloat16 && \
extraFlag1) { \
LAUNCH_TILEN(data, coopLaunch, \
LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeExpW == tg::Dtype::Bfloat16 && \
!extraFlag1) { \
LAUNCH_TILEN(data, coopLaunch, \
LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32 && \
extraFlag1) { \
LAUNCH_TILEN(data, coopLaunch, \
LAUNCH_ESC(__nv_bfloat16, float, numExperts, numTopExperts, true), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32 && \
!extraFlag1) { \
LAUNCH_TILEN(data, coopLaunch, \
LAUNCH_ESC(__nv_bfloat16, float, numExperts, numTopExperts, false), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16 && \
extraFlag1) { \
LAUNCH_TILEN(data, coopLaunch, \
LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \
kernel, numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \
} else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16 && \
!extraFlag1) { \
LAUNCH_TILEN(data, coopLaunch, \
LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \
kernel, numBlocks, numThreads, smemSize, stream); \
} else { \
FLASHINFER_WARN("Unsupported dtypeExpW"); \
FLASHINFER_WARN("Unsupported combination of mDtypeScore and mDtypeExpW"); \
}

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
1 change: 1 addition & 0 deletions include/flashinfer/trtllm/fused_moe/RoutingKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ namespace routingRenormalize {

struct Data : public DataBase {
tg::Dtype mDtypeExpW{tg::Dtype::Fp32};
tg::Dtype mDtypeScore{tg::Dtype::Fp32};
tg::Dtype mDtypeElt{tg::Dtype::Bfloat16};

bool mDoSoftmaxBeforeTopK{false};
Expand Down
6 changes: 3 additions & 3 deletions include/flashinfer/trtllm/fused_moe/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ class Runner {
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas,
batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias,
bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType,
cudaStream_t stream);
batchedGemm::trtllm::gen::Dtype dtypeScore, batchedGemm::trtllm::gen::Dtype dtypeElt,
batchedGemm::trtllm::gen::Dtype dtypeBias, bool useRoutingScalesOnInput,
bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream);

private:
int32_t mTileTokensDim{8};
Expand Down
1 change: 1 addition & 0 deletions tests/moe/test_dpsk_fused_moe_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ def __init__(self):
activation_type=ActivationType.Swiglu,
num_tokens=seq_len,
hidden_size=7168, # DeepSeek-V3 hidden size
logits_dtype=torch.float32,
intermediate_size=intermediate_size,
)

Expand Down
Loading
Loading