@@ -72,6 +72,8 @@ class DtypeUtils {
7272 default :
7373 TVM_FFI_ICHECK (false ) << " unsupported data type" ;
7474 }
75+
76+ return nvinfer1::DataType::kFLOAT ; // suppress compiler warning
7577 }
7678
7779 private:
@@ -111,6 +113,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
111113 TVM_FFI_ICHECK (false ) << " Invalid output type " << DLDataTypeToString (output_type)
112114 << " specified for " << DLDataTypeToString (mActivationDtype );
113115 }
116+
117+ return nullptr ; // suppress compiler warning
114118 };
115119
116120 FusedMoeRunner (DLDataType activation_dtype, DLDataType weight_dtype, DLDataType output_dtype,
@@ -219,7 +223,13 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
219223 }
220224
221225 mProfiler = std::make_shared<kernels::GemmProfilerBackend>();
222- mAllProfiles = mKernelRunner ->getTactics ();
226+ // Get tactics for both GEMM1 and GEMM2, combine them
227+ auto gemm1_tactics = mKernelRunner ->getTactics (kernels::MoeGemmId::GEMM_1);
228+ auto gemm2_tactics = mKernelRunner ->getTactics (kernels::MoeGemmId::GEMM_2);
229+ mGemm1TacticCount = static_cast <int64_t >(gemm1_tactics.size ());
230+ mGemm2TacticCount = static_cast <int64_t >(gemm2_tactics.size ());
231+ mAllProfiles = gemm1_tactics;
232+ mAllProfiles .insert (mAllProfiles .end (), gemm2_tactics.begin (), gemm2_tactics.end ());
223233 TVM_FFI_ICHECK (!mAllProfiles .empty ())
224234 << " No valid tactics available for fused moe op with the requested input combination "
225235 " Activation: "
@@ -367,27 +377,31 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
367377
368378 // TODO: support lora in the future
369379 ::tensorrt_llm::kernels::LoraParams lora_params{};
380+ // HACK Define default values for parameters we don't have good values for
381+ bool const swizzled_input_sf = true ; // Assume input_sf is swizzled by default
382+ int64_t const unpadded_hidden_size = hidden_size; // Assume no padding by default
383+ bool const use_lora = false ; // No lora support yet
370384#ifdef USING_OSS_CUTLASS_MOE_GEMM
371385 mKernelRunner ->runMoe (
372386 input.data_ptr (), input_sf.has_value () ? input_sf.value ().data_ptr () : nullptr ,
373- reinterpret_cast <int const *>(token_selected_experts.data_ptr ()),
387+ swizzled_input_sf, reinterpret_cast <int const *>(token_selected_experts.data_ptr ()),
374388 token_final_scales.has_value ()
375389 ? reinterpret_cast <float const *>(token_final_scales.value ().data_ptr ())
376390 : nullptr ,
377391 fc1_expert_weights.data_ptr (),
378392 fc1_expert_biases.has_value () ? fc1_expert_biases.value ().data_ptr () : nullptr ,
379393 activation_params, fc2_expert_weights.data_ptr (),
380394 fc2_expert_biases.has_value () ? fc2_expert_biases.value ().data_ptr () : nullptr ,
381- quant_params, num_rows, hidden_size, inter_size, num_experts_total,
395+ quant_params, num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total,
382396 static_cast <int >(experts_per_token),
383397 static_cast <char *>(workspace_info.workspace .data_ptr ()), output.data_ptr (),
384398 static_cast <int *>(workspace_info.src_to_dest_map ), parallelism_config, enable_alltoall,
385- false , lora_params, mUseDeepSeekFP8BlockScaling , min_latency_mode, min_latency_params,
399+ use_lora , lora_params, mUseDeepSeekFP8BlockScaling , min_latency_mode, min_latency_params,
386400 enable_pdl, stream);
387401#else
388402 mKernelRunner ->runMoe (
389403 input.data_ptr (), input_sf.has_value () ? input_sf.value ().data_ptr () : nullptr ,
390- reinterpret_cast <int const *>(token_selected_experts.data_ptr ()),
404+ swizzled_input_sf, reinterpret_cast <int const *>(token_selected_experts.data_ptr ()),
391405 token_final_scales.has_value ()
392406 ? reinterpret_cast <float const *>(token_final_scales.value ().data_ptr ())
393407 : nullptr ,
@@ -396,10 +410,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
396410 activation_params, fc2_expert_weights.data_ptr (),
397411 fc2_expert_biases.has_value () ? fc2_expert_biases.value ().data_ptr () : nullptr ,
398412 quant_params, num_rows, hidden_size, inter_size, num_experts_total,
399- static_cast <int >(experts_per_token), static_cast < char *>(workspace_info. workspace ),
400- output. data_ptr (), static_cast <int *>(workspace_info.src_to_dest_map ), parallelism_config ,
401- false , lora_params, mUseDeepSeekFP8BlockScaling , min_latency_mode, min_latency_params ,
402- enable_pdl, stream);
413+ static_cast <int >(experts_per_token),
414+ static_cast <char *>(workspace_info.workspace . data_ptr ()), output. data_ptr () ,
415+ static_cast < int *>(workspace_info. src_to_dest_map ), parallelism_config, false , lora_params ,
416+ mUseDeepSeekFP8BlockScaling , min_latency_mode, min_latency_params, enable_pdl, stream);
403417#endif
404418 }
405419
@@ -547,39 +561,44 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
547561
548562 // TODO: support lora in the future
549563 ::tensorrt_llm::kernels::LoraParams lora_params{};
564+ // HACK Define default values for parameters we don't have good values for
565+ bool const swizzled_input_sf_ml = true ; // Assume input_sf is swizzled by default
566+ int64_t const unpadded_hidden_size_ml = hidden_size; // Assume no padding by default
567+ bool const use_lora_ml = false ; // No lora support yet
550568#ifdef USING_OSS_CUTLASS_MOE_GEMM
551569 mKernelRunner ->runMoe (
552570 input.data_ptr (), input_sf.has_value () ? input_sf.value ().data_ptr () : nullptr ,
553- reinterpret_cast <int const *>(token_selected_experts.data_ptr ()),
571+ swizzled_input_sf_ml, reinterpret_cast <int const *>(token_selected_experts.data_ptr ()),
554572 token_final_scales.has_value ()
555573 ? reinterpret_cast <float const *>(token_final_scales.value ().data_ptr ())
556574 : nullptr ,
557575 fc1_expert_weights.data_ptr (),
558576 fc1_expert_biases.has_value () ? fc1_expert_biases.value ().data_ptr () : nullptr ,
559577 activation_params, fc2_expert_weights.data_ptr (),
560578 fc2_expert_biases.has_value () ? fc2_expert_biases.value ().data_ptr () : nullptr ,
561- quant_params, num_rows, hidden_size, inter_size, num_experts_total,
579+ quant_params, num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total,
562580 static_cast <int >(experts_per_token),
563581 static_cast <char *>(workspace_info.workspace .data_ptr ()), output.data_ptr (),
564582 static_cast <int *>(workspace_info.src_to_dest_map ), parallelism_config, enable_alltoall,
565- false , lora_params, mUseDeepSeekFP8BlockScaling , min_latency_mode, min_latency_params,
583+ use_lora_ml , lora_params, mUseDeepSeekFP8BlockScaling , min_latency_mode, min_latency_params,
566584 enable_pdl, stream);
567585#else
568586 mKernelRunner ->runMoe (
569587 input.data_ptr (), input_sf.has_value () ? input_sf.value ().data_ptr () : nullptr ,
570- reinterpret_cast <int const *>(token_selected_experts.data_ptr ()),
588+ swizzled_input_sf_ml, reinterpret_cast <int const *>(token_selected_experts.data_ptr ()),
571589 token_final_scales.has_value ()
572590 ? reinterpret_cast <float const *>(token_final_scales.value ().data_ptr ())
573591 : nullptr ,
574592 fc1_expert_weights.data_ptr (),
575593 fc1_expert_biases.has_value () ? fc1_expert_biases.value ().data_ptr () : nullptr ,
576594 activation_params, fc2_expert_weights.data_ptr (),
577595 fc2_expert_biases.has_value () ? fc2_expert_biases.value ().data_ptr () : nullptr ,
578- quant_params, num_rows, hidden_size, inter_size, num_experts_total,
579- static_cast <int >(experts_per_token), static_cast <char *>(workspace_info.workspace ),
580- output.data_ptr (), static_cast <int *>(workspace_info.src_to_dest_map ), parallelism_config,
581- false , lora_params, mUseDeepSeekFP8BlockScaling , min_latency_mode, min_latency_params,
582- enable_pdl, stream);
596+ quant_params, num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total,
597+ static_cast <int >(experts_per_token),
598+ static_cast <char *>(workspace_info.workspace .data_ptr ()), output.data_ptr (),
599+ static_cast <int *>(workspace_info.src_to_dest_map ), parallelism_config, false , use_lora_ml,
600+ lora_params, mUseDeepSeekFP8BlockScaling , min_latency_mode, min_latency_params, enable_pdl,
601+ stream);
583602#endif
584603 }
585604
@@ -641,19 +660,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
641660 auto activation_dtype =
642661 (mUseW4GroupScaling && !isWFP4A16Quant ()) ? dl_float8_e4m3fn : mActivationDtype ;
643662 activation_dtype = isNvfp4Quant () ? dl_int64 : activation_dtype;
663+ int64_t const unpadded_hidden_size_profiler = hidden_size; // HACK no padding by default
644664#ifdef USING_OSS_CUTLASS_MOE_GEMM
645665 mProfiler ->init (*mKernelRunner .get (), mProfiler ->mGemmToProfile ,
646666 DtypeUtils::dataType (activation_dtype), DtypeUtils::dataType (mWeightDtype ),
647667 DtypeUtils::dataType (mOutputDtype ), num_experts, static_cast <int >(top_k),
648- hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA ,
649- min_latency_mode,
668+ hidden_size, unpadded_hidden_size_profiler, inter_size, group_size ,
669+ activation_type, USE_BIAS, USE_LORA, min_latency_mode,
650670 /* need_weights*/ false , parallelism_config, enable_alltoall);
651671#else
652672 mProfiler ->init (*mKernelRunner .get (), mProfiler ->mGemmToProfile ,
653673 DtypeUtils::dataType (activation_dtype), DtypeUtils::dataType (mWeightDtype ),
654674 DtypeUtils::dataType (mOutputDtype ), num_experts, static_cast <int >(top_k),
655- hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA ,
656- min_latency_mode,
675+ hidden_size, unpadded_hidden_size_profiler, inter_size, group_size ,
676+ activation_type, USE_BIAS, USE_LORA, min_latency_mode,
657677 /* need_weights*/ false , parallelism_config);
658678#endif
659679
@@ -691,6 +711,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
691711 });
692712 } else if (name == " get_tactic_num" ) {
693713 return Function::FromTyped ([this ]() -> int64_t { return getTacticNum (); });
714+ } else if (name == " get_gemm1_tactic_count" ) {
715+ return Function::FromTyped ([this ]() -> int64_t { return mGemm1TacticCount ; });
716+ } else if (name == " get_gemm2_tactic_count" ) {
717+ return Function::FromTyped ([this ]() -> int64_t { return mGemm2TacticCount ; });
694718 } else if (name == " run_moe" ) {
695719 return Function::FromTyped (
696720 [this ](TensorView output, TensorView input, TensorView token_selected_experts,
@@ -758,6 +782,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
758782
759783 using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
760784 std::vector<Profile> mAllProfiles ;
785+ int64_t mGemm1TacticCount {0 };
786+ int64_t mGemm2TacticCount {0 };
761787
762788 void setRunnerProfiles (Optional<Array<int64_t >> profile_ids) {
763789 if (mUseDeepSeekFP8BlockScaling ) {
@@ -771,13 +797,34 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
771797 }
772798
773799 auto best_gemm1_profile = mAllProfiles .front ();
774- auto best_gemm2_profile = mAllProfiles .front ();
800+ // Default GEMM2 profile should come from the GEMM2 subrange if present
801+ auto best_gemm2_profile =
802+ (mGemm2TacticCount > 0 && mAllProfiles .size () > static_cast <size_t >(mGemm1TacticCount ))
803+ ? mAllProfiles .at (mGemm1TacticCount )
804+ : mAllProfiles .front ();
775805 if (profile_ids.has_value ()) {
776806 TVM_FFI_ICHECK_EQ (profile_ids.value ().size (), 2 ) << " Expecting 2 profile ids" ;
777- best_gemm1_profile = profile_ids.value ()[0 ] == -1 ? best_gemm1_profile
778- : mAllProfiles .at (profile_ids.value ()[0 ]);
779- best_gemm2_profile = profile_ids.value ()[1 ] == -1 ? best_gemm2_profile
780- : mAllProfiles .at (profile_ids.value ()[1 ]);
807+ // GEMM1 index: accept absolute index; otherwise if clearly out of combined range, keep
808+ // default
809+ auto id1 = profile_ids.value ()[0 ];
810+ if (id1 != -1 ) {
811+ TVM_FFI_ICHECK (id1 >= 0 && id1 < mGemm1TacticCount ) << " Invalid gemm1 profile id: " << id1;
812+ best_gemm1_profile = mAllProfiles .at (id1);
813+ }
814+
815+ // GEMM2 index: support both absolute (combined) and relative (within GEMM2 subrange) ids
816+ auto id2 = profile_ids.value ()[1 ];
817+ if (id2 != -1 ) {
818+ int64_t absolute_id2 = id2;
819+ // If id2 appears relative to GEMM2 subrange, offset it
820+ if (id2 >= 0 && id2 < mGemm2TacticCount ) {
821+ absolute_id2 = mGemm1TacticCount + id2;
822+ }
823+ TVM_FFI_ICHECK (absolute_id2 >= 0 &&
824+ absolute_id2 < static_cast <int64_t >(mAllProfiles .size ()))
825+ << " Invalid gemm2 profile id: " << id2;
826+ best_gemm2_profile = mAllProfiles .at (absolute_id2);
827+ }
781828 }
782829 mKernelRunner ->setTactic (best_gemm1_profile, best_gemm2_profile);
783830 }
0 commit comments