diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index a52f5e0dce..e8d92fb213 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -1986,15 +1986,16 @@ Array trtllm_fp4_block_scale_moe( int hidden_size = hidden_states.size(1); if (hidden_states.dtype() == dl_uint8) hidden_size *= 2; - int hidden_states_scale_vec_size = -1; + int64_t hidden_states_scale_vec_size = -1; if (hidden_states_scale.has_value()) { - hidden_states_scale_vec_size = (num_tokens * hidden_size) / hidden_states_scale.value().numel(); + hidden_states_scale_vec_size = + (static_cast(num_tokens) * hidden_size) / hidden_states_scale.value().numel(); } int64_t intermediate_size_factor = isGatedActivation(static_cast(act_type)) ? 2 : 1; - int weight_scale_vec_size = - (local_num_experts * intermediate_size * intermediate_size_factor * hidden_size) / - gemm1_weights_scale.numel(); + int64_t weight_scale_vec_size = (static_cast(local_num_experts) * intermediate_size * + intermediate_size_factor * hidden_size) / + gemm1_weights_scale.numel(); TVM_FFI_ICHECK(weight_scale_vec_size == 16 || weight_scale_vec_size == 32) << "unsupported weight_scale_vec_size.";