-
Notifications
You must be signed in to change notification settings - Fork 899
fix: support fp32 logits for fp8_per_tensor and fp8_block #2534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
da52d40
0c876d4
5647136
b9d4245
a934912
af075f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -131,6 +131,7 @@ class FusedMoeLauncher { | |||||||||||||||
| btg::Dtype mRoutingBiasDtype{ | ||||||||||||||||
| btg::Dtype::Bfloat16}; // Dtype for expert weights in routing, based on routing bias | ||||||||||||||||
| ActivationType activation_type{ActivationType::Swiglu}; | ||||||||||||||||
| btg::Dtype mDtypeScore{btg::Dtype::Bfloat16}; | ||||||||||||||||
|
|
||||||||||||||||
| int64_t intermediate_size_factor{2}; | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -168,13 +169,19 @@ class FusedMoeLauncher { | |||||||||||||||
| int64_t weight_layout, ActivationType activation_type); | ||||||||||||||||
|
|
||||||||||||||||
| // Routing logits [num_tokens, num_experts] | ||||||||||||||||
| void check_routing_logits_shape() const { | ||||||||||||||||
| void check_routing_logits() const { | ||||||||||||||||
| if (routing_logits.has_value()) { | ||||||||||||||||
| // Check shape | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.value().size(0), hidden_states.size(0)) | ||||||||||||||||
| << "routing_logits and hidden_states must have the same number of tokens."; | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), args->num_experts) | ||||||||||||||||
| << "routing_logits dim1 must match num_experts."; | ||||||||||||||||
|
|
||||||||||||||||
| // Check dtype | ||||||||||||||||
| TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || | ||||||||||||||||
| routing_logits.value().dtype() == dl_bfloat16) | ||||||||||||||||
| << "routing_logits must be float or bfloat16."; | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -236,7 +243,7 @@ class FusedMoeLauncher { | |||||||||||||||
| args->local_expert_offset + args->local_num_experts <= args->num_experts) | ||||||||||||||||
| << "expert offset and count must be within valid range"; | ||||||||||||||||
|
|
||||||||||||||||
| check_routing_logits_shape(); | ||||||||||||||||
| check_routing_logits(); | ||||||||||||||||
|
|
||||||||||||||||
| if (routing_bias.has_value()) { | ||||||||||||||||
| check_routing_bias_shape(); | ||||||||||||||||
|
|
@@ -302,6 +309,19 @@ class FusedMoeLauncher { | |||||||||||||||
| workspace.cta_idx_xy_to_batch_idx = static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()); | ||||||||||||||||
| workspace.cta_idx_xy_to_mn_limit = static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()); | ||||||||||||||||
| workspace.num_non_exiting_ctas = static_cast<int*>(num_non_exiting_ctas.data_ptr()); | ||||||||||||||||
|
|
||||||||||||||||
| // Set dtype of score based on actual routing_logits dtype | ||||||||||||||||
| if (routing_logits.has_value()) { | ||||||||||||||||
| if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) { | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) | ||||||||||||||||
| << "routing_logits must be float."; | ||||||||||||||||
| mDtypeScore = btg::Dtype::Fp32; | ||||||||||||||||
| } else if (routing_logits.value().dtype() == dl_float32) { | ||||||||||||||||
| mDtypeScore = btg::Dtype::Fp32; | ||||||||||||||||
| } else { | ||||||||||||||||
| mDtypeScore = btg::Dtype::Bfloat16; | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| void check_moe_common() const { | ||||||||||||||||
|
|
@@ -387,8 +407,8 @@ class FusedMoeLauncher { | |||||||||||||||
| static_cast<int*>(num_tokens_per_expert.data_ptr()), | ||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()), | ||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, | ||||||||||||||||
| use_routing_scales_on_input, use_deep_seek_fp8, | ||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt, | ||||||||||||||||
| mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, | ||||||||||||||||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream); | ||||||||||||||||
|
|
||||||||||||||||
| check_moe(); | ||||||||||||||||
|
|
@@ -498,8 +518,9 @@ class Bf16MoeLauncher : public FusedMoeLauncher { | |||||||||||||||
| if (has_precomputed_weights) { | ||||||||||||||||
| workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr()); | ||||||||||||||||
| } else { | ||||||||||||||||
| auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16; | ||||||||||||||||
| FusedMoeLauncher::expert_weights = | ||||||||||||||||
| alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); | ||||||||||||||||
| alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device()); | ||||||||||||||||
|
Comment on lines
+521
to
+523
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. π§© Analysis chainπ Script executed: #!/bin/bash
set -euo pipefail
echo "== DeepSeek expW dtype in routing runner =="
rg -n -C3 'RoutingMethodType::DeepSeekV3|mDtypeExpW' csrc/trtllm_fused_moe_runner.cu
echo
echo "== expert_weights dtype allocation sites in launcher =="
rg -n -C2 'ew_dtype|expert_weights_dtype|alloc_tensor\(\{args->num_tokens, args->top_k\}' csrc/trtllm_fused_moe_kernel_launcher.cuRepository: flashinfer-ai/flashinfer Length of output: 3594 π Script executed: #!/bin/bash
set -euo pipefail
echo "== mDtypeScore initialization in launcher =="
rg -n 'mDtypeScore\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20
echo
echo "== Where expert_weights dtype flows into runner args =="
rg -n 'args->mDtypeExpW|mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20
echo
echo "== Check if Llama and Renormalize/TopK also have the issue =="
rg -n 'class.*Llama|class.*Renormalize|class.*Topk' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Allocation sites - check their context for routing method =="
rg -n -B10 'auto ew_dtype = mDtypeScore' csrc/trtllm_fused_moe_kernel_launcher.cu | head -40Repository: flashinfer-ai/flashinfer Length of output: 670 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Context: Which routing methods trigger Fp32 vs Bfloat16 mDtypeScore =="
sed -n '310,330p' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Class definitions to identify which classes set mDtypeScore =="
rg -n 'class.*FusedMoe|struct.*FusedMoe' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20
echo
echo "== How is workspace.expert_weights used in runner =="
rg -n -A3 'workspace\.expert_weights' csrc/trtllm_fused_moe_runner.cu | head -30Repository: flashinfer-ai/flashinfer Length of output: 1692 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Where is args.mDtypeExpW set =="
rg -n 'args.*mDtypeExpW|mDtypeExpW.*args' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Full context of finalizeData usage in runner =="
sed -n '490,510p' csrc/trtllm_fused_moe_runner.cu
echo
echo "== What does finalizeData do with expertWeightsPtr and mDtypeExpW =="
rg -n 'finalizeData\.' csrc/trtllm_fused_moe_runner.cu | head -20
echo
echo "== Check if expert_weights buffer is actually read/interpreted as a specific dtype =="
rg -n 'expert_weights|expertWeights' csrc/trtllm_fused_moe_runner.cu | head -20Repository: flashinfer-ai/flashinfer Length of output: 102 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Runner.run signature - understand args =="
rg -n 'void run\(|struct.*Args|args->num_tokens' csrc/trtllm_fused_moe_runner.cu | head -20
echo
echo "== Check if finalizeData operations use mDtypeExpW =="
rg -n 'mDtypeExpW' csrc/trtllm_fused_moe_runner.cu
echo
echo "== Look for kernel calls that use expertWeightsPtr and check if dtype matters =="
rg -n 'expertWeights|expW' csrc/ -t cu | grep -E '(kernel|<<<|>>>)' | head -20Repository: flashinfer-ai/flashinfer Length of output: 112 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Routing runner file structure - find main entry point =="
head -100 csrc/trtllm_fused_moe_runner.cu
echo
echo "== Search more broadly for mDtypeExpW usage =="
rg -n 'mDtypeExpW' csrc/
echo
echo "== Check finalize kernel or operation that uses expertWeightsPtr =="
rg -n 'Finalize|finalize' csrc/trtllm_fused_moe_runner.cu | head -10Repository: flashinfer-ai/flashinfer Length of output: 123 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Routing runner file structure - find main entry point =="
head -100 csrc/trtllm_fused_moe_runner.cu
echo
echo "== Search more broadly for mDtypeExpW usage =="
rg -n 'mDtypeExpW' csrc/
echo
echo "== Check finalize kernel or operation that uses expertWeightsPtr =="
rg -n 'Finalize|finalize' csrc/trtllm_fused_moe_runner.cu | head -10Repository: flashinfer-ai/flashinfer Length of output: 5503 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Look for where args.mDtypeExpW is initialized =="
rg -n 'struct.*Args|args\.mDtypeExpW|mDtypeExpW' csrc/trtllm_fused_moe_runner.cu | head -30
echo
echo "== Search in launcher for any Args struct definition =="
rg -n 'struct.*Args|\.mDtypeExpW' csrc/trtllm_fused_moe_kernel_launcher.cu | head -30
echo
echo "== Check header files for Args struct =="
fd 'runner\.h|fused_moe.*\.h' csrc/ --exec head -150 {} \; 2>/dev/null | grep -A20 'struct Args'Repository: flashinfer-ai/flashinfer Length of output: 369 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Finalize function signature =="
rg -n 'void.*finalize|finalize.*Data' csrc/trtllm_fused_moe_runner.cu -A2 | head -30
echo
echo "== Check line 467 context for finalize function =="
sed -n '460,510p' csrc/trtllm_fused_moe_runner.cu
echo
echo "== Look for expert_weights dtype handling in finalize kernel =="
fd '\.h$' csrc/ --exec rg -l 'finalize' {} \; 2>/dev/null | head -5 | xargs rg -n 'struct.*finalize|mDtypeExpW' 2>/dev/null | head -20Repository: flashinfer-ai/flashinfer Length of output: 4018 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Find MoERunnerArgs struct definition =="
fd '\.h$' csrc/ | xargs rg -l 'MoERunnerArgs' | head -3
echo
echo "== Search for mDtypeExpW in those headers =="
fd '\.h$' csrc/ | xargs rg -n 'mDtypeExpW|struct MoERunnerArgs' 2>/dev/null | head -40
echo
echo "== Check if launcher sets args.mDtypeExpW anywhere =="
rg -n 'args\.mDtypeExpW|mDtypeExpW.*=' csrc/trtllm_fused_moe_kernel_launcher.cuRepository: flashinfer-ai/flashinfer Length of output: 112 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Search for MoERunnerArgs definition =="
rg -n 'struct MoERunnerArgs|mDtypeExpW' csrc/ -t h 2>/dev/null | head -50
echo
echo "== Look at Bf16MoeLauncher class (lines 458+) =="
sed -n '458,550p' csrc/trtllm_fused_moe_kernel_launcher.cuRepository: flashinfer-ai/flashinfer Length of output: 109 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Look at actual class that creates workspace and calls runner =="
sed -n '458,550p' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Search for where Runner::run is called from launcher =="
rg -n 'Runner::run|\.run\(' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20
echo
echo "== Check the workspace initialization to understand args flow =="
rg -n 'workspace\.|MoEWorkspace' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20Repository: flashinfer-ai/flashinfer Length of output: 6457 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Check the routing_runner.run call signature and what args it receives =="
sed -n '398,410p' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Look for args.mDtypeExpW initialization anywhere in launcher =="
rg -n 'mDtypeExpW' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Check what happens with expert_weights in routing kernel (DeepSeekV3) =="
rg -n 'mPtrTopKWeights|mDtypeExpW' csrc/ -t cu 2>/dev/null | grep -i deepseek
echo
echo "== Verify: is args.mDtypeExpW ever set in launcher classes? =="
rg -n 'args->mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cuRepository: flashinfer-ai/flashinfer Length of output: 1120 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Full routing_runner.run call =="
sed -n '398,415p' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Search for where args is created/initialized =="
rg -n 'args.*=' csrc/trtllm_fused_moe_kernel_launcher.cu | grep -E 'mDtypeExpW|make_unique|new' | head -20
echo
echo "== Check if expert_weights buffer is actually used by routing kernel =="
rg -n 'mPtrTopKWeights' csrc/ -A5 -B5 2>/dev/null | head -40Repository: flashinfer-ai/flashinfer Length of output: 4895 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Check if mDtypeExpW is actually used in the routing kernel =="
fd '\.h$' csrc/ | xargs rg -l 'mDtypeExpW|routingDeepSeek' 2>/dev/null | head -5
echo
echo "== Check DeepSeekV3 routing kernel implementation =="
rg -n 'routingDeepSeek|mDtypeExpW|mPtrTopKWeights' csrc/ -t h -t cu 2>/dev/null | grep -A5 -B5 'mPtrTopKWeights' | head -50
echo
echo "== Verify: args.mDtypeExpW default value - check if it's initialized =="
rg -n 'MoERunnerArgs' csrc/ -t h --max-count=5 2>/dev/null | head -10Repository: flashinfer-ai/flashinfer Length of output: 134 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Look for finalize kernel that uses mDtypeExpW and expertWeightsPtr =="
find csrc -name "*.h" -o -name "*.cu" | xargs rg -l 'finalize.*Data|expertWeightsPtr' 2>/dev/null | head -5
echo
echo "== Check how finalize uses expert_weights dtype =="
rg -n 'expertWeightsPtr|mDtypeExpW' csrc/ 2>/dev/null | grep -v 'Binary' | head -30
echo
echo "== Check the actual MoERunnerArgs initialization - look at constructor/default values =="
find csrc -name "*.h" | xargs grep -A30 'struct MoERunnerArgs\|class MoERunnerArgs' 2>/dev/null | head -60Repository: flashinfer-ai/flashinfer Length of output: 1675 π Script executed: #!/bin/bash
set -euo pipefail
echo "== Check finalize kernel - how TypeExpW is defined and used =="
sed -n '660,680p' csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu
echo
echo "== Check if TypeExpW is template parameter or determined by mDtypeExpW =="
rg -n 'template.*TypeExpW|TypeExpW|mDtypeExpW' csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu | head -30
echo
echo "== Find where args.mDtypeExpW is actually set/initialized =="
rg -n 'mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cuRepository: flashinfer-ai/flashinfer Length of output: 1771 Expert_weights dtype mismatch between launcher allocation and kernel interpretation affects all routing methods. The launcher allocates This issue is broader than DeepSeekV3 alone: it affects any configuration where π§ Suggested fix (centralize expW dtype policy) class FusedMoeLauncher {
protected:
+ DLDataType get_expert_weights_dtype() const {
+ if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+ // Runner DeepSeek path currently expects expW as BF16.
+ return dl_bfloat16;
+ }
+ return mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
+ }- auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
+ auto ew_dtype = get_expert_weights_dtype();Apply the replacement at all four allocation sites (lines 521β523, 662β664, 938β940, 1213β1215). π Committable suggestion
Suggested change
π€ Prompt for AI Agents |
||||||||||||||||
| workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr(); | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
@@ -638,8 +659,9 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { | |||||||||||||||
| routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; | ||||||||||||||||
| mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; | ||||||||||||||||
|
|
||||||||||||||||
| auto expert_weights_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16; | ||||||||||||||||
| expert_weights = | ||||||||||||||||
| alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); | ||||||||||||||||
| alloc_tensor({args->num_tokens, args->top_k}, expert_weights_dtype, hidden_states.device()); | ||||||||||||||||
|
|
||||||||||||||||
| workspace.expert_weights = expert_weights.data_ptr(); | ||||||||||||||||
| if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) { | ||||||||||||||||
|
|
@@ -913,9 +935,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||
| // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr | ||||||||||||||||
| bool has_precomputed_weights = expert_weights.ndim() == 2 && expert_weights.size(0) > 0; | ||||||||||||||||
| if (!has_precomputed_weights) { | ||||||||||||||||
| // Allocate expert_weights buffer for routing output | ||||||||||||||||
| auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16; | ||||||||||||||||
| FusedMoeLauncher::expert_weights = | ||||||||||||||||
| alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); | ||||||||||||||||
| alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device()); | ||||||||||||||||
| workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr(); | ||||||||||||||||
| } else { | ||||||||||||||||
| workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr()); | ||||||||||||||||
|
|
@@ -1092,8 +1114,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||
| static_cast<int*>(num_tokens_per_expert.data_ptr()), | ||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()), | ||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, | ||||||||||||||||
| use_routing_scales_on_input, use_deep_seek_fp8, | ||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt, | ||||||||||||||||
| mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, | ||||||||||||||||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream); | ||||||||||||||||
|
|
||||||||||||||||
| check_moe(); | ||||||||||||||||
|
|
@@ -1188,8 +1210,9 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||
| routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; | ||||||||||||||||
| mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; | ||||||||||||||||
|
|
||||||||||||||||
| auto expert_weights_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16; | ||||||||||||||||
| expert_weights = | ||||||||||||||||
| alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); | ||||||||||||||||
| alloc_tensor({args->num_tokens, args->top_k}, expert_weights_dtype, hidden_states.device()); | ||||||||||||||||
|
|
||||||||||||||||
| workspace.expert_weights = expert_weights.data_ptr(); | ||||||||||||||||
| } | ||||||||||||||||
|
|
@@ -1555,8 +1578,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||
| static_cast<int*>(num_tokens_per_expert.data_ptr()), | ||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()), | ||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, | ||||||||||||||||
| use_routing_scales_on_input, use_deep_seek_fp8, | ||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt, | ||||||||||||||||
| mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, | ||||||||||||||||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream); | ||||||||||||||||
|
|
||||||||||||||||
| check_moe(); | ||||||||||||||||
|
|
@@ -1694,13 +1717,12 @@ Array<Tensor> trtllm_fp8_per_tensor_scale_moe( | |||||||||||||||
| // Basic type validation | ||||||||||||||||
| auto dtype = hidden_states.dtype(); | ||||||||||||||||
| auto activation = static_cast<ActivationType>(activation_type); | ||||||||||||||||
| if (use_routing_scales_on_input) { | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; | ||||||||||||||||
| } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) { | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; | ||||||||||||||||
| } else { | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; | ||||||||||||||||
|
|
||||||||||||||||
| if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) { | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) | ||||||||||||||||
| << "routing_logits must be float for DeepSeekV3."; | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16) | ||||||||||||||||
| << "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16."; | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) | ||||||||||||||||
|
|
@@ -1799,9 +1821,6 @@ Array<Tensor> trtllm_fp8_block_scale_moe( | |||||||||||||||
| if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) { | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) | ||||||||||||||||
| << "routing_logits must be float."; | ||||||||||||||||
| } else { | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16) | ||||||||||||||||
| << "routing_logits must be bfloat16."; | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
| TVM_FFI_ICHECK(dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) | ||||||||||||||||
|
|
@@ -1930,18 +1949,6 @@ Array<Tensor> trtllm_fp4_block_scale_moe( | |||||||||||||||
| << "unsupported weight_scale_vec_size."; | ||||||||||||||||
| auto mDtypeWeights = weight_scale_vec_size == 16 ? btg::Dtype::E2m1 : btg::Dtype::MxE2m1; | ||||||||||||||||
|
|
||||||||||||||||
| if (routing_logits.has_value()) { | ||||||||||||||||
| TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || | ||||||||||||||||
| routing_logits.value().dtype() == dl_bfloat16) | ||||||||||||||||
| << "routing_logits must be float or bfloat16."; | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts) | ||||||||||||||||
| << "routing_logits has incorrect shape."; | ||||||||||||||||
| if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) { | ||||||||||||||||
| TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) | ||||||||||||||||
| << "routing_logits must be float."; | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
| if (routing_bias.has_value()) { | ||||||||||||||||
| TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || | ||||||||||||||||
| routing_bias.value().dtype() == dl_float32) | ||||||||||||||||
|
|
||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.