From 358c62ed15c22e5a350e99debb3675f6ea31eaea Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:17:20 -0800 Subject: [PATCH 1/2] Add relu2 to kernel and python api Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../epilogue/thread/fused_activations.h | 24 ++++++++++++++++ .../kernels/cutlass_kernels/include/common.h | 1 + .../moe_gemm/moe_gemm_template_dispatch.h | 1 + .../cutlass_kernels/moe_gemm/moe_kernels.cu | 2 ++ .../_torch/custom_ops/torch_custom_ops.py | 28 +++++++++++++++++-- .../_torch/modules/fused_moe/routing.py | 12 ++++++++ 6 files changed, 65 insertions(+), 3 deletions(-) diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h index 5ce2f4e1daf..795de9a599a 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h @@ -59,6 +59,30 @@ __forceinline__ __device__ float tanh_opt(float x) #endif } +template <typename T> +struct Relu2 +{ + static bool const kIsHeavy = false; + + CUTLASS_HOST_DEVICE + T operator()(T threshold, T value) const + { + ReLu<T> relu_op; + multiplies<T> mul; + T val = relu_op(threshold, value); + return mul(val, val); + } + + CUTLASS_HOST_DEVICE + T operator()(T value) const + { + ReLu<T> relu_op; + multiplies<T> mul; + T val = relu_op(value); + return mul(val, val); + } +}; + } // namespace thread } // namespace epilogue } // namespace cutlass diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h index 646be2575ca..55226c68960 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h @@ -28,6 +28,7 @@ enum class ActivationType Swiglu, Geglu, SwigluBias, + Relu2, Identity, InvalidType }; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 2c0d1a94a53..477634cd3c0 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -954,6 +954,7 @@ void MoeGemmRunner::moeGemmBiasAct( case ActivationType::Identity: runGemm(inputs, hopper_inputs); break; case ActivationType::Swiglu: runGemm(inputs, hopper_inputs); break; case ActivationType::Geglu: runGemm(inputs, hopper_inputs); break; + case ActivationType::Relu2: TLLM_THROW("Relu2 is not supported."); break; case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; default: TLLM_THROW("Invalid activation type."); break; } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 0fb56f3893d..383ad87e988 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -2307,6 +2307,8 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 decltype(block_scaling_type)::value>, // Geglu &doActivationKernel, // SwigluBias + &doActivationKernel, + decltype(block_scaling_type)::value>, // Relu2
&doActivationKernel, decltype(block_scaling_type)::value> // Identity diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 6c1a13bec8e..abaf9c967af 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -24,6 +24,22 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None: torch.bmm(a, b, out=out) +from enum import IntEnum + + +# Copied from csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h +class ActivationType(IntEnum): + Gelu = 0 + Relu = 1 + Silu = 2 + Swiglu = 3 + Geglu = 4 + SwigluBias = 5 + Relu2 = 6 + Identity = 7 + InvalidType = 8 + + class MoERunner(TunableRunner): # avoid overhead of creating a new runner in forward pass runner_dict = dict() @@ -52,6 +68,7 @@ def __init__( use_mxfp8_act_scaling: bool, min_latency_mode: bool, use_fused_finalize: bool, + activation_type: ActivationType, unpadded_hidden_size: Optional[int] = None, ): self.x_dtype = x_dtype @@ -72,6 +89,7 @@ def __init__( self.use_mxfp8_act_scaling = use_mxfp8_act_scaling self.min_latency_mode = min_latency_mode self.use_fused_finalize = use_fused_finalize + self.activation_type = activation_type self.unpadded_hidden_size = unpadded_hidden_size if unpadded_hidden_size is not None else 0 instance_key = (x_dtype, weight_dtype, output_dtype, @@ -84,7 +102,7 @@ def __init__( x_dtype, weight_dtype, output_dtype, use_deepseek_fp8_block_scale, use_w4_group_scaling, use_int8_woq_per_channel, use_mxfp8_act_scaling, - use_fused_finalize) + use_fused_finalize) # , activation_type) self.fused_moe_runner = MoERunner.runner_dict[instance_key] def get_valid_tactics(self, inputs: List[torch.Tensor], @@ -117,6 +135,7 @@ def forward( gemm_idx, tactic, do_preparation, + self.activation_type, self.unpadded_hidden_size, ) @@ -153,6 +172,7 @@ def fused_moe( tune_max_num_tokens: int = 8192, tuner_num_tokens: Optional[int] = None, tuner_top_k: Optional[int] = None, + activation_type: ActivationType = ActivationType.Swiglu, unpadded_hidden_size: Optional[int] = None, out_tensor: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: @@ -189,6 +209,7 @@ def fused_moe( use_mxfp8_act_scaling=use_mxfp8_act_scaling, min_latency_mode=min_latency_mode, use_fused_finalize=use_fused_finalize, + activation_type=activation_type, unpadded_hidden_size=unpadded_hidden_size, ) @@ -223,8 +244,8 @@ def fused_moe( swizzled_input_sf, swiglu_alpha, swiglu_beta, swiglu_limit, tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank, enable_alltoall, min_latency_mode, - [gemm_tactic_1, gemm_tactic_2], unpadded_hidden_size, - tuner_num_tokens, out_tensor) + [gemm_tactic_1, gemm_tactic_2], activation_type, + unpadded_hidden_size, tuner_num_tokens, out_tensor) return output if min_latency_mode else [output] @@ -260,6 +281,7 @@ def _(input: torch.Tensor, tune_max_num_tokens: int = 8192, tuner_num_tokens: Optional[int] = None, tuner_top_k: Optional[int] = None, + activation_type: ActivationType = ActivationType.Swiglu, unpadded_hidden_size: Optional[int] = None, out_tensor: Optional[torch.Tensor] = None): seq_len = input.shape[0] diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index 886d3a97b7d..4a134cd4e36 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -155,6 +155,18 @@ class
RoutingMethodType(IntEnum): Unspecified = 5. +class ActivationType(IntEnum): + Gelu = 0 + Relu = 1 + Silu = 2 + Swiglu = 3 + Geglu = 4 + SwigluBias = 5 + Relu2 = 6 + Identity = 7 + InvalidType = 8 + + class BaseMoeRoutingMethod(nn.Module): def apply(self, _router_logits) -> (torch.Tensor, torch.Tensor): From 130dbd13a1c21a04aa42644408ed60dd6cffa3ea Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:54:19 -0800 Subject: [PATCH 2/2] Fixes and UT Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- cpp/tensorrt_llm/thop/moeOp.cpp | 32 +- .../custom_ops/fused_moe/trtllm_moe.py | 191 ++++++++++- .../_torch/custom_ops/torch_custom_ops.py | 4 +- .../singlegpu/custom_ops/test_trtllm_moe.py | 301 ++++++++++++++++++ 4 files changed, 507 insertions(+), 21 deletions(-) create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index fbed602d464..3cfaaff69a5 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -259,8 +259,8 @@ class FusedMoeRunner : public torch::CustomClassHolder torch::optional const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode, torch::optional> const& profile_ids, - torch::optional const& unpadded_hidden_size, torch::optional const& num_valid_tokens, - torch::optional const& out_tensor) + torch::optional const& activation_type, torch::optional const& unpadded_hidden_size, + torch::optional const& num_valid_tokens, torch::optional const& out_tensor) { std::lock_guard lock(mMutex); // Free the profile workspace to save memory @@ -328,6 +328,9 @@ class FusedMoeRunner : public torch::CustomClassHolder TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0], "fc1_expert_weights and fc2_expert_weights must have the same number of experts."); + ActivationType base_activation_type = activation_type.has_value() + ? 
static_cast(activation_type.value()) + : ActivationType::Swiglu; if (mUseINT8WoqPerChannel) { // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights: @@ -337,8 +340,19 @@ class FusedMoeRunner : public torch::CustomClassHolder } else { - TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, - "fc1_expert_weights inter size must be fc2_expert_weights inter size."); + // TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, + // "fc1_expert_weights inter size must be fc2_expert_weights inter size."); + + if (isGatedActivation(base_activation_type)) + { + TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, + "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + } + else + { + TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier, + "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size."); + } } int experts_per_token = token_selected_experts.sizes()[1]; @@ -375,7 +389,7 @@ class FusedMoeRunner : public torch::CustomClassHolder int const num_experts_on_rank = fc2_expert_weights.sizes()[0]; auto const num_experts_total = static_cast(num_experts_on_rank * ep_size); auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank); - ActivationType base_activation_type = ActivationType::Swiglu; + if (swiglu_alpha.has_value()) { CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float); @@ -474,8 +488,8 @@ class FusedMoeRunner : public torch::CustomClassHolder torch::optional const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode, torch::optional> const& profile_ids, - torch::optional const& unpadded_hidden_size, torch::optional const& num_valid_tokens, - torch::optional const& out_tensor) + torch::optional const& activation_type, torch::optional const& unpadded_hidden_size, + torch::optional const& num_valid_tokens, torch::optional const& out_tensor) { std::lock_guard lock(mMutex); @@ -541,7 +555,9 @@ class FusedMoeRunner : public torch::CustomClassHolder auto const num_experts_total = static_cast(num_experts_on_rank * ep_size); auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank); - ActivationType base_activation_type = ActivationType::Swiglu; + ActivationType base_activation_type = activation_type.has_value() + ? 
static_cast(activation_type.value()) + : ActivationType::Swiglu; if (swiglu_alpha.has_value()) { CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float); diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index a14d0f436e5..901e3bb5cd4 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -1,5 +1,7 @@ import torch +from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType + @torch.library.custom_op("auto_deploy::trtllm_moe_fused", mutates_args=()) def trtllm_fused_moe( @@ -8,6 +10,8 @@ def trtllm_fused_moe( routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", ) -> torch.Tensor: x_shape = x.shape x = x.view(-1, x_shape[-1]) @@ -16,21 +20,38 @@ def trtllm_fused_moe( selected_experts = selected_experts.to(torch.int32) quant_scales = [] + # Determine activation type + mlp_style = mlp_style.lower() + act_fn = act_fn.lower() + + activation_type = ActivationType.Swiglu + if mlp_style == "gated_mlp": + # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T) + if act_fn == "silu": + # activation_type = ActivationType.Silu + activation_type = ActivationType.Swiglu # need to fix this in trtllm + else: + raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") + elif mlp_style == "mlp": + # For non-gated MLP with ReLU^2 + if act_fn == "relu2": + activation_type = ActivationType.Relu2 + else: + raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") + else: + raise ValueError(f"Unknown mlp_style '{mlp_style}'. 
Use 'gated_mlp' or 'mlp'.") + return torch.ops.trtllm.fused_moe( x, selected_experts, routing_weights, - w3_w1_stacked_weight, - None, # w3_w1_stacked_bias - w2_stacked_weight, - None, # w2_stacked_bias - x.dtype, - quant_scales, - tp_size=1, - tp_rank=0, - ep_size=1, - ep_rank=0, - enable_alltoall=False, + fc1_expert_weights=w3_w1_stacked_weight, + fc1_expert_biases=None, + fc2_expert_weights=w2_stacked_weight, + fc2_expert_biases=None, + output_dtype=x.dtype, + quant_scales=quant_scales, + activation_type=activation_type, )[0].view(x_shape) @@ -41,5 +62,153 @@ def trtllm_fused_moe( routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + return torch.empty_like(x) + + +# Todo: refactor this repeating code block +def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear).""" + FP8_MIN = torch.finfo(torch.float8_e4m3fn).min + FP8_MAX = torch.finfo(torch.float8_e4m3fn).max + return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) + + +@torch.library.custom_op("auto_deploy::trtllm_quant_fp8moe_fused", mutates_args=()) +def trtllm_quant_fp8moe_fused( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights + w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights + w3_weight: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp + w1_input_scale: torch.Tensor, # [E] stacked input scales + w2_input_scale: torch.Tensor, # [E] stacked input scales + w3_input_scale: torch.Tensor, # [E] or unused + w1_weight_scale: torch.Tensor, # [E] stacked weight scales + w2_weight_scale: torch.Tensor, # [E] stacked weight scales + w3_weight_scale: torch.Tensor, # [E] or unused + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + """ + TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP. 
+ Parameters: + x: BF16/FP16 input tensor of shape (B, H) or (B, S, H) + selected_experts: Expert indices (B*S, TOP_K) + routing_weights: Routing weights (B*S, TOP_K) + w1_weight: FP8 w1 weights [E, I, H] + w2_weight: FP8 w2 weights [E, H, I] + w3_weight: FP8 w3 weights [E, I, H] (for gated_mlp) + w1_input_scale: Input scales for w1 [E] + w2_input_scale: Input scales for w2 [E] + w3_input_scale: Input scales for w3 [E] + w1_weight_scale: Weight scales for w1 [E] + w2_weight_scale: Weight scales for w2 [E] + w3_weight_scale: Weight scales for w3 [E] + mlp_style: "gated_mlp" or "mlp" + act_fn: "silu" for gated_mlp, "relu2" for mlp + """ + + # if mlp_style != "gated_mlp": + # raise NotImplementedError("FlashInfer FP8 MoE currently only supports gated_mlp") + + # Store original shape and flatten to 2D + x_shape = x.shape + x2d = x.view(-1, x_shape[-1]) + x_q_fp8 = _quantize_fp8(x2d, w1_input_scale) + + # Scales are stored in float32 + w1_weight_scale = w1_weight_scale.to(torch.float32) + w2_weight_scale = w2_weight_scale.to(torch.float32) + w1_input_scale = w1_input_scale.to(torch.float32) + w2_input_scale = w2_input_scale.to(torch.float32) + + # Prepare quant_scales for TensorRT-LLM FP8 format: + # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale] + # For gated MLP: + # - gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3) + # - gemm2_act_quant_scale: 1 / w2_input_scale + # - gemm2_dequant_scale: w2_weight_scale * w2_input_scale + # - gemm1_input_dequant_scale: w1_input_scale + + # Compute combined scales + gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze() + gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32) # [E] + gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze() + gemm1_input_dequant = w1_input_scale.contiguous() + + assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D" + assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D" + quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant] + + # Ensure contiguous tensors + selected_experts = selected_experts.contiguous() + routing_weights = routing_weights.contiguous() + + # Todo: refactor this repeating code block + + # Determine activation type + mlp_style = mlp_style.lower() + act_fn = act_fn.lower() + + activation_type = ActivationType.Swiglu + if mlp_style == "gated_mlp": + # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T) + # For gated MLP, concatenate w1 and w3 + # TensorRT-LLM expects [w3, w1] concatenated + w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous() # [E, 2*I, H] + fc1_expert_weights = w3_w1_stacked + if act_fn == "silu": + # activation_type = ActivationType.Silu + activation_type = ActivationType.Swiglu # need to fix this in trtllm + else: + raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") + elif mlp_style == "mlp": + # For non-gated MLP with ReLU^2 + fc1_expert_weights = w1_weight.contiguous() + if act_fn == "relu2": + activation_type = ActivationType.Relu2 + else: + raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") + else: + raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + + # Note! 
Outputting Float8_e4m3fn directly is not currently supported + output = torch.ops.trtllm.fused_moe( + x_q_fp8, + selected_experts, + routing_weights, + fc1_expert_weights=fc1_expert_weights, + fc1_expert_biases=None, + fc2_expert_weights=w2_weight.contiguous(), + fc2_expert_biases=None, + output_dtype=x.dtype, + quant_scales=quant_scales, + activation_type=activation_type, + ) + + # Restore original shape + return output[0].view(x_shape) + + +@trtllm_quant_fp8moe_fused.register_fake +def trtllm_quant_fp8moe_fused_fake( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight: torch.Tensor, + w2_weight: torch.Tensor, + w3_weight: torch.Tensor, + w1_input_scale: torch.Tensor, + w2_input_scale: torch.Tensor, + w3_input_scale: torch.Tensor, + w1_weight_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, + w3_weight_scale: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index abaf9c967af..48f34618643 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -172,13 +172,13 @@ def fused_moe( tune_max_num_tokens: int = 8192, tuner_num_tokens: Optional[int] = None, tuner_top_k: Optional[int] = None, - activation_type: ActivationType = ActivationType.Swiglu, + activation_type: int = ActivationType.Swiglu, unpadded_hidden_size: Optional[int] = None, out_tensor: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: tuner = AutoTuner.get() - + activation_type = ActivationType(activation_type) # Only the non-alltoall case is considered for profiling in the warmup phase. # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall. 
if enable_alltoall: diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py new file mode 100644 index 00000000000..a4d5c400f0d --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -0,0 +1,301 @@ +import pytest +import torch +from torch.nn import functional as F + +from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType + +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +FP8_DTYPE = torch.float8_e4m3fn + +# ACT_FUNC = "relu2" +# ACT_FUNC = "swiglu" +# ACT_FUNC = "silu" +# TEST_DTYPE = torch.float16 +# TEST_DTYPE = torch.bfloat16 +# TEST_DTYPE = torch.float8_e4m3fn +# O_TYPE = torch.bfloat16 if TEST_DTYPE == torch.float8_e4m3fn else TEST_DTYPE + + +def dynamic_per_tensor_fp8_quant(x: torch.tensor) -> tuple[torch.tensor, torch.tensor]: + fp8_traits_max = FLOAT8_E4M3_MAX + fp8_traits_min = -FLOAT8_E4M3_MAX + fp8_max = torch.tensor(fp8_traits_max).float() + one = torch.tensor(1.0).float() + + x_max = x.abs().max().float() + scale = x_max / fp8_max + iscale = one / scale + out = (x.float() * iscale).clamp(fp8_traits_min, fp8_traits_max).to(FP8_DTYPE) + return out, scale.view((1,)) + + +def gen_tensor(shape, dtype, stype=None, scale=1.0): + x = torch.randn(*shape, dtype=dtype).cuda() * scale # * 0.1 + return x.to(stype) if stype else x + + +def cast_to_representable(x): + x_q, x_scale = dynamic_per_tensor_fp8_quant(x) + x = x_q.to(x.dtype) * x_scale.to(x.dtype) + return x + + +def compute_routing(router_logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute routing weights and selected experts from router logits. + + Args: + router_logits (torch.Tensor): Router logits of shape [batch_size, num_experts] + top_k (int): Number of experts to route to per token + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - routing_weights: Expert weights of shape [batch_size, top_k] + - selected_experts: Expert indices of shape [batch_size, top_k] + """ + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.float() + return routing_weights, selected_experts + + +def compute_with_experts( + num_experts, + x, + w31_weight, + w2_weight, + selected_experts, + routing_weights, + alpha=None, + beta=None, + limit=None, + activation_func="silu", +): + def relu2(x: torch.Tensor) -> torch.Tensor: + return torch.square(F.relu(x)) + + results = torch.zeros_like(x) + for expert_id in range(num_experts): + mask = selected_experts == expert_id + if not mask.sum(): + continue + batch_idx, nth_expert = torch.where(mask) + w31_expert = w31_weight[expert_id] # [2 * intermediate_size, hidden_size] + w2_expert = w2_weight[expert_id] # [hidden_size, intermediate_size] + + # Split w13 into w1 and w3 + w3_expert, w1_expert = torch.chunk(w31_expert, 2, dim=0) + + expert_inputs = x[batch_idx] + if alpha is not None and limit is not None and beta is not None: + # SwiGLUBias + x1 = expert_inputs @ w1_expert.t() + x1 = x1.clamp_(min=None, max=limit) + x1_scaled = x1 * torch.sigmoid(alpha * x1) + x2 = expert_inputs @ w3_expert.t() + x2 = x2.clamp_(min=-limit, max=limit) + beta + + inter = x1_scaled * x2 + else: + if activation_func == "swiglu" or activation_func == "silu": + inter = F.silu(expert_inputs @ 
w1_expert.t()) * (expert_inputs @ w3_expert.t()) + else: + inter = relu2(expert_inputs @ w1_expert.t()) + + output = inter @ w2_expert.t() + results[batch_idx] += routing_weights[batch_idx, nth_expert, None] * output + return results.view_as(x) + + +# Test configurations +BATCH_SIZES = [ + 1, +] +HIDDEN_SIZES = [ + 128, +] +NUM_EXPERTS = [2] +TOP_K_VALUES = [2] +INTERMEDIATE_SIZES = [ + 128, +] +EP_NUM_EXPERTS = [8] +EP_TOP_K = [2] + + +TEST_DTYPES = [ + # Todo: separate to two tests: float and fp8 + # (torch.float16, torch.float16, torch.float16), + # (torch.bfloat16, torch.bfloat16, torch.bfloat16), + (torch.float16, torch.float16, torch.float8_e4m3fn), + (torch.bfloat16, torch.bfloat16, torch.float8_e4m3fn), + (torch.float8_e4m3fn, torch.bfloat16, torch.float8_e4m3fn), + (torch.float8_e4m3fn, torch.float16, torch.float8_e4m3fn), +] + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +@pytest.mark.parametrize("itype, otype, wtype", TEST_DTYPES) +@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +def test_trtllm_cutlass_fused_moe( + batch_size, + hidden_size, + num_experts, + top_k, + intermediate_size, + itype, + otype, + wtype, + activation_func, +): + # Skip invalid configurations + if top_k > num_experts: + pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})") + + torch.manual_seed(42) + input_shape = (batch_size, hidden_size) + w31_shape = (num_experts, 2 * intermediate_size, hidden_size) + w2_shape = (num_experts, hidden_size, intermediate_size) + if activation_func in ["swiglu", "silu"]: + X_GEN_SCALE = 1.0 + else: + X_GEN_SCALE = 0.5 + + x = cast_to_representable(gen_tensor(input_shape, otype, scale=X_GEN_SCALE)) + router_logits = gen_tensor((batch_size, num_experts), otype) + + # Create weight tensors + w31_weight = gen_tensor(w31_shape, otype, wtype) + w2_weight = gen_tensor(w2_shape, otype, wtype) + w31_scales = torch.empty(num_experts, 2, dtype=otype).cuda() + w2_scales = torch.empty(num_experts, 1, dtype=otype).cuda() + + w31_dequantized = gen_tensor(w31_shape, otype) + w2_dequantized = gen_tensor(w2_shape, otype) + for expert_id in range(num_experts): + w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=0.1)) + w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=0.09)) + + w31_quant, s31 = dynamic_per_tensor_fp8_quant(w31) + w2_quant, s2 = dynamic_per_tensor_fp8_quant(w2) + + w31_weight.data[expert_id].copy_(w31_quant) + w2_weight.data[expert_id].copy_(w2_quant) + w31_scales.data[expert_id].copy_(s31) + w2_scales.data[expert_id].copy_(s2) + w31_dequantized.data[expert_id].copy_(torch.mul(w31_quant.to(dtype=otype), s31)) + w2_dequantized.data[expert_id].copy_(torch.mul(w2_quant.to(dtype=otype), s2)) + + routing_weights, selected_experts = compute_routing(router_logits, top_k) + ref_output = compute_with_experts( + num_experts, + x, + w31_dequantized, + w2_dequantized, + selected_experts, + routing_weights, + activation_func=activation_func, + ) + + # For fp8, the hidden_state expects quantized. 
+ w3_scales, w1_scales = torch.chunk(w31_scales, 2, dim=-1) + + x_quant, hidden_states_scale = dynamic_per_tensor_fp8_quant(x) + hidden_states_scale = hidden_states_scale[0].detach().clone().cuda() + + w3_input_scale = torch.tensor(1.0).cuda() + w2_input_scale = torch.tensor(1.0).cuda() + quant_scales = [ + torch.squeeze(w1_scales * hidden_states_scale).float(), # gemm1 dequant scale + w3_input_scale, # gemm2 activation quant scale + torch.squeeze(1.0 * w2_scales).float(), # gemm2 dequant scale + hidden_states_scale, # gemm1 input dequant scale + ] + + # (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size) + w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1) + w3_weight_dequantized, w1_weight_dequantized = torch.chunk(w31_dequantized, 2, dim=1) + torch.cuda.synchronize() + print("before fused_moe.cutlass_fused_moe") + activation_type = ( + ActivationType.Swiglu if activation_func in ["swiglu", "silu"] else ActivationType.Relu2 + ) + + if itype == torch.bfloat16 or itype == torch.float16: + ad_test_output = torch.ops.auto_deploy.trtllm_moe_fused( + x, + selected_experts.to(torch.int), + routing_weights, + w3_w1_stacked_weight=w1_weight_dequantized.contiguous() + if activation_func == "relu2" + else w31_dequantized, + w2_stacked_weight=w2_dequantized, + mlp_style="mlp" if activation_func == "relu2" else "gated_mlp", + act_fn=activation_func, + ) + trtllm_test_output = torch.ops.trtllm.fused_moe( + x, + selected_experts.to(torch.int), + routing_weights, + fc1_expert_weights=w1_weight_dequantized.contiguous() + if activation_func == "relu2" + else w31_dequantized, + fc1_expert_biases=None, + fc2_expert_weights=w2_dequantized, + fc2_expert_biases=None, + output_dtype=otype, + quant_scales=quant_scales, + activation_type=activation_type, + )[0].view(input_shape) + + elif itype == torch.float8_e4m3fn: + # FP8 + ad_test_output = torch.ops.auto_deploy.trtllm_quant_fp8moe_fused( + x, # Note! unquantized input is expected + selected_experts.to(torch.int), + routing_weights, + w1_weight=w1_weight.contiguous(), + w2_weight=w2_weight.contiguous(), + w3_weight=w3_weight.contiguous(), + w1_input_scale=hidden_states_scale, + w2_input_scale=w2_input_scale, + w3_input_scale=w3_input_scale, + w1_weight_scale=w1_scales, + w2_weight_scale=w2_scales, + w3_weight_scale=w3_scales, + mlp_style="mlp" if activation_func == "relu2" else "gated_mlp", + act_fn=activation_func, + ) + + trtllm_test_output = torch.ops.trtllm.fused_moe( + x_quant, # Note! quantized input is expected + selected_experts.to(torch.int), + routing_weights, + fc1_expert_weights=w1_weight.contiguous() if activation_func == "relu2" else w31_weight, + fc1_expert_biases=None, + fc2_expert_weights=w2_weight, + fc2_expert_biases=None, + output_dtype=otype, + quant_scales=quant_scales, + activation_type=activation_type, + )[0].view(input_shape) + torch.cuda.synchronize() + + diff = (ref_output - ad_test_output).abs() + print(f"max diff: {diff.max()}") + assert trtllm_test_output is not None + # torch.testing.assert_close(ad_test_output, trtllm_test_output, rtol=1e-6, atol=1e-6) + + if diff.max() > 1e-1: + print("diff: " + "-" * 20) + print(f"{diff[:10]}") + print("test_output: " + "-" * 20) + print(f"{ad_test_output[:10]}") + print("ref_output: " + "-" * 20) + print(f"{ref_output[:10]}") + torch.testing.assert_close(ref_output, ad_test_output, rtol=1e-1, atol=1e-1)
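The following standalone sketch (not part of the patch) illustrates how the new non-gated ReLU^2 path added in this PR can be invoked through the auto_deploy custom op. Shapes, dtypes, and tensor names are assumptions chosen to mirror the unit test above, and it assumes a CUDA build of TensorRT-LLM with this patch applied.

import torch
import torch.nn.functional as F

# Problem sizes (assumed, mirroring the unit test): experts, top-k, tokens, hidden, intermediate.
E, TOP_K, B, H, INTER = 2, 2, 4, 128, 128

x = torch.randn(B, H, dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(B, E, dtype=torch.bfloat16, device="cuda")

# Same routing recipe as compute_routing() in the test: softmax, top-k, renormalize.
probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(probs, TOP_K, dim=-1)
routing_weights = (routing_weights / routing_weights.sum(dim=-1, keepdim=True)).float()

# Non-gated MLP weights: fc1 is [E, INTER, H], fc2 is [E, H, INTER].
w1 = torch.randn(E, INTER, H, dtype=torch.bfloat16, device="cuda") * 0.1
w2 = torch.randn(E, H, INTER, dtype=torch.bfloat16, device="cuda") * 0.1

out = torch.ops.auto_deploy.trtllm_moe_fused(
    x,
    selected_experts.to(torch.int32),
    routing_weights,
    w3_w1_stacked_weight=w1,  # for mlp_style="mlp" only the fc1 weights are passed
    w2_stacked_weight=w2,
    mlp_style="mlp",
    act_fn="relu2",  # maps to ActivationType.Relu2 inside the op
)
assert out.shape == x.shape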