3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -116,6 +116,9 @@ transforms:
fuse_moe:
stage: post_load_fusion
enabled: true
fuse_fp8_moe:
stage: post_load_fusion
enabled: true
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
243 changes: 179 additions & 64 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
@@ -228,6 +228,8 @@ def torch_quant_fp8_moe(
w1_weight_scale: List[torch.Tensor],
w2_weight_scale: List[torch.Tensor],
w3_weight_scale: List[torch.Tensor],
mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp"
act_fn: str = "silu", # silu or relu2
) -> torch.Tensor:
"""
FP8 MoE op using quantized linear operations.
@@ -239,40 +241,91 @@ def torch_quant_fp8_moe(
x: Input tensor of shape (B, H) or (B, S, H).
selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
routing_weights: Tensor of normalized routing weights.
w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
w1_weight:
List of per-expert weight tensors:
• mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
• mlp_style=="mlp": W_up with shape (I, H) — up projection.
w2_weight:
List of per-expert weight tensors:
• gated_mlp: W2 with shape (H, I) — down projection.
• mlp: W_down with shape (H, I) — down projection.
w3_weight:
List of per-expert weight tensors:
• gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP.
• mlp: pass an empty list []; ignored.
w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops.
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops.

mlp_style:
Selects the per-expert MLP computation:
• "gated_mlp" (default, Mixtral/DeepSeek-style):
y = W2( act(W1 x) * (W3 x) )
• "mlp" (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
Supported: "silu" (default), "relu2" (ReLU then square).
"""

def make_fp8_mlp(i):
def mlp(inp):
gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
)
up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
inp,
w3_weight[i],
bias=None,
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
)
prod = F.silu(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_fp8_linear(
prod,
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
)

return mlp

mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
act_fn = _resolve_activation(act_fn)
style = mlp_style.lower()

if style == "gated_mlp":

def make_fp8_mlp(i):
def mlp(inp):
gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
)
up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
inp,
w3_weight[i],
bias=None,
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
)
prod = act_fn(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_fp8_linear(
prod,
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
)

return mlp

mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]

elif style == "mlp":

def make_fp8_mlp(i):
def mlp(inp):
up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
)
return torch.ops.auto_deploy.torch_quant_fp8_linear(
act_fn(up_out),
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
)

return mlp

mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]

else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

return _template_moe(x, selected_experts, routing_weights, mlps)
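
The new code paths call a `_resolve_activation` helper that is not shown in this diff. As a point of reference only, a minimal sketch of what such a helper plausibly does, based on the docstring ("silu" is the default, "relu2" is ReLU followed by squaring); the actual helper in the repository may differ in name and details:

import torch
import torch.nn.functional as F

def _resolve_activation(name: str):
    # Map the activation string accepted by the MoE ops to an elementwise callable.
    name = name.lower()
    if name == "silu":
        return F.silu
    if name == "relu2":
        # "relu2": ReLU followed by squaring, as described in the docstring above.
        return lambda x: torch.square(F.relu(x))
    raise ValueError(f"Unknown activation '{name}'. Use 'silu' or 'relu2'.")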


@@ -290,6 +343,8 @@ def torch_quant_fp8_moe_fake(
w1_weight_scale: List[torch.Tensor],
w2_weight_scale: List[torch.Tensor],
w3_weight_scale: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)

@@ -311,6 +366,8 @@ def torch_quant_nvfp4_moe(
w1_alpha: List[torch.Tensor],
w2_alpha: List[torch.Tensor],
w3_alpha: List[torch.Tensor],
mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp"
act_fn: str = "silu", # silu or relu2
) -> torch.Tensor:
"""
FP4 MoE op using quantized linear operations.
@@ -322,45 +379,101 @@ def torch_quant_nvfp4_moe(
x: Input tensor of shape (B, H) or (B, S, H).
selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
routing_weights: Tensor of normalized routing weights.
w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
w1_weight:
List of per-expert weight tensors:
• mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
• mlp_style=="mlp": W_up with shape (I, H) — up projection.
w2_weight:
List of per-expert weight tensors:
• gated_mlp: W2 with shape (H, I) — down projection.
• mlp: W_down with shape (H, I) — down projection.
w3_weight:
List of per-expert weight tensors:
• gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP.
• mlp: pass an empty list []; ignored.
w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors.
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors.
w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization.
mlp_style:
Selects the per-expert MLP computation:
• "gated_mlp" (default, Mixtral/DeepSeek-style):
y = W2( act(W1 x) * (W3 x) )
• "mlp" (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
Supported: "silu" (default), "relu2" (ReLU then square).
"""

def make_fp4_mlp(i):
def mlp(inp):
if inp.shape[0] == 0:
return torch.zeros_like(inp)
gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
alpha=w1_alpha[i],
)
up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w3_weight[i],
bias=None,
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
alpha=w3_alpha[i],
)
prod = F.silu(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
prod,
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
alpha=w2_alpha[i],
)

return mlp

mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
act_fn = _resolve_activation(act_fn)
style = mlp_style.lower()

if style == "gated_mlp":

def make_fp4_mlp(i):
def mlp(inp):
if inp.shape[0] == 0:
return torch.zeros_like(inp)
gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
alpha=w1_alpha[i],
)
up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w3_weight[i],
bias=None,
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
alpha=w3_alpha[i],
)
prod = act_fn(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
prod,
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
alpha=w2_alpha[i],
)

return mlp

mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]

elif style == "mlp":

def make_fp4_mlp(i):
def mlp(inp):
if inp.shape[0] == 0:
return torch.zeros_like(inp)
up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
alpha=w1_alpha[i],
)
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
act_fn(up_out),
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
alpha=w2_alpha[i],
)

return mlp

mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]

else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

return _template_moe(x, selected_experts, routing_weights, mlps)
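
For clarity, the two `mlp_style` variants implement the following per-expert math. This is an unquantized, plain-PyTorch sketch (no FP8/NVFP4 scales, illustrative weight names only), not the quantized kernels used above:

import torch
import torch.nn.functional as F

def gated_mlp_expert(x, w1, w2, w3, act=F.silu):
    # "gated_mlp": y = W2( act(W1 x) * (W3 x) ), Mixtral/DeepSeek-style.
    # w1, w3: (I, H); w2: (H, I); x: (B, H).
    return F.linear(act(F.linear(x, w1)) * F.linear(x, w3), w2)

def mlp_expert(x, w_up, w_down, act=F.silu):
    # "mlp": y = W_down( act(W_up x) ), NemotronH-style 2-layer MLP.
    # w_up: (I, H); w_down: (H, I); x: (B, H).
    return F.linear(act(F.linear(x, w_up)), w_down)

# Example with arbitrary sizes: hidden H=8, intermediate I=16, batch of 4 tokens.
x = torch.randn(4, 8)
w1, w3, w2 = torch.randn(16, 8), torch.randn(16, 8), torch.randn(8, 16)
y_gated = gated_mlp_expert(x, w1, w2, w3)                                  # shape (4, 8)
y_mlp = mlp_expert(x, w1, w2, act=lambda t: torch.square(F.relu(t)))       # "relu2" variant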


@@ -381,6 +494,8 @@ def torch_quant_nvfp4_moe_fake(
w1_alpha: List[torch.Tensor],
w2_alpha: List[torch.Tensor],
w3_alpha: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)

Expand Down
Loading