@@ -76,8 +76,26 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


-@torch.library.custom_op("auto_deploy::trtllm_quant_fp8moe_fused", mutates_args=())
-def trtllm_quant_fp8moe_fused(
+def _validate_mlp_style_and_act_fn(mlp_style: str, act_fn: str) -> None:
+    supported_combinations = {
+        "gated_mlp": ["silu"],
+        "mlp": ["relu2"],
+    }
+    supported_act_fns = [
+        act_fn for act_fn_list in supported_combinations.values() for act_fn in act_fn_list
+    ]
+    assert mlp_style in supported_combinations.keys(), (
+        f"Unknown mlp_style '{mlp_style}'. Use {supported_combinations.keys()}."
+    )
+    assert act_fn in supported_act_fns, f"Unknown act_fn '{act_fn}'. Use {supported_act_fns}."
+    assert act_fn in supported_combinations[mlp_style], (
+        f"Unsupported combination: mlp_style='{mlp_style}', act_fn='{act_fn}'. "
+        f"Supported combinations: {supported_combinations}"
+    )
+
+
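For reviewers, a minimal sketch of how the new helper behaves when called directly (illustrative calls only, not part of the diff):

_validate_mlp_style_and_act_fn("gated_mlp", "silu")  # passes
_validate_mlp_style_and_act_fn("mlp", "relu2")       # passes
try:
    _validate_mlp_style_and_act_fn("gated_mlp", "relu2")
except AssertionError as e:
    print(e)  # Unsupported combination: mlp_style='gated_mlp', act_fn='relu2'. ...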
+@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused", mutates_args=())
+def trtllm_quant_fp8_moe_fused(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
@@ -110,21 +128,27 @@ def trtllm_quant_fp8moe_fused(
         w3_weight_scale: Weight scales for w3 [E]
         mlp_style: "gated_mlp" or "mlp"
         act_fn: "silu" for gated_mlp, "relu2" for mlp
+
+    Non-Gated MLP:
+        activation_fn(expert_inputs @ w1_expert.t()) @ w2_expert.t()
+
+    Gated MLP:
+        (activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t())) @ w2_expert.t()
+        (A plain-PyTorch reference sketch of both styles follows this op below.)
     """

-    # if mlp_style != "gated_mlp":
-    #     raise NotImplementedError("FlashInfer FP8 MoE currently only supports gated_mlp")
+    _validate_mlp_style_and_act_fn(mlp_style, act_fn)

     # Store original shape and flatten to 2D
     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
-    x_q_fp8 = _quantize_fp8(x2d, w1_input_scale)
+    # Quantize input
+    x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0])

     # Scales are stored in float32
     w1_weight_scale = w1_weight_scale.to(torch.float32)
     w2_weight_scale = w2_weight_scale.to(torch.float32)
-    w1_input_scale = w1_input_scale.to(torch.float32)
-    w2_input_scale = w2_input_scale.to(torch.float32)
+    w1_input_scale = w1_input_scale.to(torch.float32)[0]
+    w2_input_scale = w2_input_scale.to(torch.float32)[0]

     # Prepare quant_scales for TensorRT-LLM FP8 format:
     # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale]
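The combined scales computed in the next hunk follow from the per-tensor FP8 dequantization identity x @ w.T ≈ (x_q @ w_q.T) * (s_x * s_w). A small self-contained sketch of that identity (toy values, integer rounding standing in for FP8 casting; not code from this PR):

import torch

s_x, s_w = 0.02, 0.05                      # per-tensor input / weight scales
x, w = torch.randn(4, 8), torch.randn(16, 8)
x_q = (x / s_x).round()                    # stand-in for FP8 quantization
w_q = (w / s_w).round()
ref = x @ w.t()
dequant = (x_q @ w_q.t()) * (s_x * s_w)    # plays the role of gemm1_dequant
print((ref - dequant).abs().max())         # small quantization error

The same reasoning gives gemm2_dequant = w2_weight_scale * w2_input_scale, while gemm2_act_quant = 1 / w2_input_scale re-quantizes the first GEMM's activations before the second GEMM.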
@@ -136,7 +160,7 @@ def trtllm_quant_fp8moe_fused(

     # Compute combined scales
     gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze()
-    gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32)  # [E]
+    gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32)
     gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze()
     gemm1_input_dequant = w1_input_scale.contiguous()

@@ -145,7 +169,7 @@ def trtllm_quant_fp8moe_fused(
     quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant]

     # Ensure contiguous tensors
-    selected_experts = selected_experts.contiguous()
+    selected_experts = selected_experts.int().contiguous()
     routing_weights = routing_weights.contiguous()

     # Todo: refactor this repeating code block
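The new `.int()` cast matters because routers typically produce int64 expert indices; a rough caller-side sketch of the routing inputs (hypothetical names and shapes, not from this PR):

import torch

tokens, num_experts, top_k = 8, 64, 2
router_logits = torch.randn(tokens, num_experts)
probs = torch.softmax(router_logits, dim=-1)
routing_weights, selected_experts = torch.topk(probs, k=top_k)  # [tokens, top_k]
print(selected_experts.dtype)  # torch.int64, now cast to int32 by the op above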
@@ -194,8 +218,8 @@ def trtllm_quant_fp8moe_fused(
     return output[0].view(x_shape)


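As a cross-check on the docstring formulas above, a plain-PyTorch, unfused and unquantized reference of the per-expert math (expert_inputs and w*_expert are illustrative names, not arguments of this op; "relu2" is assumed to mean squared ReLU):

import torch
import torch.nn.functional as F

def _ref_expert_mlp(expert_inputs, w1_expert, w2_expert, w3_expert, mlp_style, act_fn):
    # Reference math only; the fused op above additionally handles routing and FP8 scales.
    act = F.silu if act_fn == "silu" else (lambda t: torch.square(F.relu(t)))  # "relu2"
    if mlp_style == "gated_mlp":
        return (act(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t())) @ w2_expert.t()
    return act(expert_inputs @ w1_expert.t()) @ w2_expert.t()  # plain "mlp"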
-@trtllm_quant_fp8moe_fused.register_fake
-def trtllm_quant_fp8moe_fused_fake(
+@trtllm_quant_fp8_moe_fused.register_fake
+def trtllm_quant_fp8_moe_fused_fake(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
@@ -208,7 +232,7 @@ def trtllm_quant_fp8moe_fused_fake(
     w1_weight_scale: torch.Tensor,
     w2_weight_scale: torch.Tensor,
     w3_weight_scale: torch.Tensor,
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    mlp_style: str,
+    act_fn: str,
 ) -> torch.Tensor:
     return torch.empty_like(x)
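Beyond the renaming, the fake registration above is what lets torch.compile and export trace the op without launching the TensorRT-LLM kernel; a toy illustration of the same pattern (demo op, not part of this PR):

import torch

@torch.library.custom_op("demo::noop_scale", mutates_args=())
def noop_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
    return x * scale

@noop_scale.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Same idea as trtllm_quant_fp8_moe_fused_fake: only shape/dtype matter here.
    return torch.empty_like(x)

out = noop_scale(torch.randn(2, 3), 2.0)  # eager call still runs the real body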