
Commit 90e748e

Code refactoring
Signed-off-by: Neta Zmora <[email protected]>
1 parent f6964d8 commit 90e748e

File tree: 2 files changed (+271, -132 lines)


tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 37 additions & 13 deletions
@@ -76,8 +76,26 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
 
 
-@torch.library.custom_op("auto_deploy::trtllm_quant_fp8moe_fused", mutates_args=())
-def trtllm_quant_fp8moe_fused(
+def _validate_mlp_style_and_act_fn(mlp_style: str, act_fn: str) -> None:
+    supported_combinations = {
+        "gated_mlp": ["silu"],
+        "mlp": ["relu2"],
+    }
+    supported_act_fns = [
+        act_fn for act_fn_list in supported_combinations.values() for act_fn in act_fn_list
+    ]
+    assert mlp_style in supported_combinations.keys(), (
+        f"Unknown mlp_style '{mlp_style}'. Use {supported_combinations.keys()}."
+    )
+    assert act_fn in supported_act_fns, f"Unknown act_fn '{act_fn}'. Use {supported_act_fns}."
+    assert act_fn in supported_combinations[mlp_style], (
+        f"Unsupported combination: mlp_style='{mlp_style}', act_fn='{act_fn}'. "
+        f"Supported combinations: {supported_combinations}"
+    )
+
+
+@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused", mutates_args=())
+def trtllm_quant_fp8_moe_fused(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
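
For reference, a minimal usage sketch of the new validation helper. It assumes the private helper is importable from the module shown in this diff; the calls and printed message are illustrative only.

# Hypothetical usage of the validation helper added above.
from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.trtllm_moe import (
    _validate_mlp_style_and_act_fn,
)

_validate_mlp_style_and_act_fn("gated_mlp", "silu")  # supported pair: passes silently
_validate_mlp_style_and_act_fn("mlp", "relu2")       # supported pair: passes silently

try:
    _validate_mlp_style_and_act_fn("gated_mlp", "relu2")  # known style, mismatched act_fn
except AssertionError as err:
    print(err)  # "Unsupported combination: mlp_style='gated_mlp', act_fn='relu2'. ..."
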
@@ -110,21 +128,27 @@ def trtllm_quant_fp8moe_fused(
         w3_weight_scale: Weight scales for w3 [E]
         mlp_style: "gated_mlp" or "mlp"
         act_fn: "silu" for gated_mlp, "relu2" for mlp
+
+    Non-Gated MLP:
+        activation_fn(expert_inputs @ w1_expert.t()) @ w2_expert.t()
+
+    Gated MLP:
+        activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t()
     """
 
-    # if mlp_style != "gated_mlp":
-    #     raise NotImplementedError("FlashInfer FP8 MoE currently only supports gated_mlp")
+    _validate_mlp_style_and_act_fn(mlp_style, act_fn)
 
     # Store original shape and flatten to 2D
     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
-    x_q_fp8 = _quantize_fp8(x2d, w1_input_scale)
+    # Quantize input
+    x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0])
 
     # Scales are stored in float32
     w1_weight_scale = w1_weight_scale.to(torch.float32)
     w2_weight_scale = w2_weight_scale.to(torch.float32)
-    w1_input_scale = w1_input_scale.to(torch.float32)
-    w2_input_scale = w2_input_scale.to(torch.float32)
+    w1_input_scale = w1_input_scale.to(torch.float32)[0]
+    w2_input_scale = w2_input_scale.to(torch.float32)[0]
 
     # Prepare quant_scales for TensorRT-LLM FP8 format:
     # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale]
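
To make the two docstring formulas concrete, here is an unfused, per-expert reference sketch in plain PyTorch. It is illustrative only: the function name and signature are invented for this note, the activation is passed in abstractly, and the fused TensorRT-LLM kernel does not execute Python like this.

import torch
import torch.nn.functional as F

def expert_forward_reference(expert_inputs, w1_expert, w2_expert, w3_expert=None,
                             mlp_style="gated_mlp", act_fn=F.silu):
    # Illustrative restatement of the docstring math above, not the fused kernel.
    if mlp_style == "mlp":
        # Non-Gated MLP: activation_fn(x @ w1.T) @ w2.T
        return act_fn(expert_inputs @ w1_expert.t()) @ w2_expert.t()
    # Gated MLP: (activation_fn(x @ w1.T) * (x @ w3.T)) @ w2.T
    return (act_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t())) @ w2_expert.t()
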
@@ -136,7 +160,7 @@ def trtllm_quant_fp8moe_fused(
 
     # Compute combined scales
     gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze()
-    gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32)  # [E]
+    gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32)
     gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze()
     gemm1_input_dequant = w1_input_scale.contiguous()
 
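
As a sanity check on the scale algebra these combined factors encode, a small standalone sketch: per-tensor FP8 quantization divides by a scale, so a GEMM on two quantized operands is dequantized by the product of the two scales, which is what gemm1_dequant and gemm2_dequant carry above. The sketch emulates the FP8 GEMM in float32 and does not call the TRT-LLM kernel; all names are illustrative.

import torch

FP8_MAX = 448.0  # max representable magnitude of float8_e4m3fn

def quantize_fp8(x, scale):
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

x = torch.randn(4, 8)
w = torch.randn(16, 8)
x_scale = x.abs().max() / FP8_MAX
w_scale = w.abs().max() / FP8_MAX

# Emulate the FP8 GEMM in float32; the real kernel consumes quant_scales instead.
x_q = quantize_fp8(x, x_scale).to(torch.float32)
w_q = quantize_fp8(w, w_scale).to(torch.float32)

dequant = x_scale * w_scale            # analogue of weight_scale * input_scale above
y = (x_q @ w_q.t()) * dequant          # dequantized GEMM output
print((y - x @ w.t()).abs().max())     # small residual error from FP8 rounding
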
@@ -145,7 +169,7 @@ def trtllm_quant_fp8moe_fused(
     quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant]
 
     # Ensure contiguous tensors
-    selected_experts = selected_experts.contiguous()
+    selected_experts = selected_experts.int().contiguous()
     routing_weights = routing_weights.contiguous()
 
     # Todo: refactor this repeating code block
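
For context on the new .int() cast: selected_experts typically comes from a top-k over router probabilities, and PyTorch's topk returns int64 indices. The routing sketch below is hypothetical (it is not code from this file) and only shows where such a tensor and its dtype usually come from.

import torch

logits = torch.randn(6, 8)  # [num_tokens, num_experts], illustrative sizes
probs = torch.softmax(logits, dim=-1)
routing_weights, selected_experts = torch.topk(probs, k=2, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
print(selected_experts.dtype)  # torch.int64 -> hence the .int() cast above
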
@@ -194,8 +218,8 @@ def trtllm_quant_fp8moe_fused(
     return output[0].view(x_shape)
 
 
-@trtllm_quant_fp8moe_fused.register_fake
-def trtllm_quant_fp8moe_fused_fake(
+@trtllm_quant_fp8_moe_fused.register_fake
+def trtllm_quant_fp8_moe_fused_fake(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
@@ -208,7 +232,7 @@ def trtllm_quant_fp8moe_fused_fake(
     w1_weight_scale: torch.Tensor,
     w2_weight_scale: torch.Tensor,
     w3_weight_scale: torch.Tensor,
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    mlp_style: str,
+    act_fn: str,
 ) -> torch.Tensor:
     return torch.empty_like(x)
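
The register_fake hook above is what lets the renamed custom op participate in tracing (torch.export / torch.compile) without running the real FP8 kernel: under FakeTensor only shapes and dtypes are propagated, so returning empty_like(x) is enough. A minimal standalone sketch of the same torch.library pattern, with invented op and function names:

import torch

@torch.library.custom_op("example::noop_moe", mutates_args=())
def noop_moe(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

@noop_moe.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Only shape/dtype matter under FakeTensor; no real compute happens here.
    return torch.empty_like(x)

print(noop_moe(torch.randn(2, 3)).shape)  # torch.Size([2, 3])
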
