@@ -4563,10 +4563,14 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr
45634563 return ;
45644564 }
45654565
4566+ bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4 );
4567+ bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16 )
4568+ && mWType == nvinfer1::DataType::kUINT8 );
4569+ bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
45664570 bool const use_finalize_fusion = fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
4567- if (use_finalize_fusion
4568- && (! mInterface -> use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise
4569- || mGemmToProfile != GemmToProfile::GEMM_2) )
4571+ bool const finalize_fusion_not_supported = ! mInterface -> use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise
4572+ || mGemmToProfile != GemmToProfile::GEMM_2;
4573+ if (use_finalize_fusion && finalize_fusion_not_supported )
45704574 {
45714575 return ;
45724576 }
@@ -4624,11 +4628,6 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr
46244628 /* GEMM1 */
46254629 gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
46264630 gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
4627-
4628- bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4 );
4629- bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16 )
4630- && mWType == nvinfer1::DataType::kUINT8 );
4631- bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
46324631 if (use_finalize_fusion)
46334632 {
46344633 assert (!mMinLatencyMode );
@@ -4740,6 +4739,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac
47404739 {
47414740 tma_ws_input_template = mTmaInputCache [mSampleIndex ][tactic.epilogue_fusion_type
47424741 == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE];
4742+ TLLM_CHECK_WITH_INFO (tma_ws_input_template.isValid (), " TMA WS input template is not initialized" );
47434743 }
47444744
47454745 mInterface ->is_profiler = true ;
0 commit comments