3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -116,6 +116,9 @@ transforms:
  fuse_moe:
    stage: post_load_fusion
    enabled: true
  fuse_fp8_moe:
    stage: post_load_fusion
    enabled: true
  fuse_allreduce_residual_rmsnorm:
    stage: post_load_fusion
    # TODO (lucaslie): add backend selection as part of configurable inference optimizers
243 changes: 179 additions & 64 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
@@ -228,6 +228,8 @@ def torch_quant_fp8_moe(
    w1_weight_scale: List[torch.Tensor],
    w2_weight_scale: List[torch.Tensor],
    w3_weight_scale: List[torch.Tensor],
    mlp_style: str = "gated_mlp",  # "gated_mlp" (default) or "mlp"
    act_fn: str = "silu",  # silu or relu2
) -> torch.Tensor:
    """
    FP8 MoE op using quantized linear operations.
@@ -239,40 +241,91 @@ def torch_quant_fp8_moe(
        x: Input tensor of shape (B, H) or (B, S, H).
        selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
        routing_weights: Tensor of normalized routing weights.
        w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
        w1_weight:
            List of per-expert weight tensors:
              • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
              • mlp_style=="mlp": W_up with shape (I, H) — up projection.
        w2_weight:
            List of per-expert weight tensors:
              • gated_mlp: W2 with shape (H, I) — down projection.
              • mlp: W_down with shape (H, I) — down projection.
        w3_weight:
            List of per-expert weight tensors:
              • gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP.
              • mlp: pass an empty list []; ignored.
        w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops.
        w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops.

        mlp_style:
            Selects the per-expert MLP computation:
              • "gated_mlp" (default, Mixtral/DeepSeek-style):
                    y = W2( act(W1 x) * (W3 x) )
              • "mlp" (NemotronH-style 2-layer MLP):
                    y = W_down( act(W_up x) )
        act_fn:
            Elementwise activation applied inside the expert MLP.
            Supported: "silu" (default), "relu2" (ReLU then square).
    """

    def make_fp8_mlp(i):
        def mlp(inp):
            gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
                inp,
                w1_weight[i],
                bias=None,
                input_scale=w1_input_scale[i],
                weight_scale=w1_weight_scale[i],
            )
            up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
                inp,
                w3_weight[i],
                bias=None,
                input_scale=w3_input_scale[i],
                weight_scale=w3_weight_scale[i],
            )
            prod = F.silu(gate_out) * up_out
            return torch.ops.auto_deploy.torch_quant_fp8_linear(
                prod,
                w2_weight[i],
                bias=None,
                input_scale=w2_input_scale[i],
                weight_scale=w2_weight_scale[i],
            )

        return mlp

    mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
    act_fn = _resolve_activation(act_fn)
    style = mlp_style.lower()

    if style == "gated_mlp":

        def make_fp8_mlp(i):
            def mlp(inp):
                gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
                    inp,
                    w1_weight[i],
                    bias=None,
                    input_scale=w1_input_scale[i],
                    weight_scale=w1_weight_scale[i],
                )
                up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
                    inp,
                    w3_weight[i],
                    bias=None,
                    input_scale=w3_input_scale[i],
                    weight_scale=w3_weight_scale[i],
                )
                prod = act_fn(gate_out) * up_out
                return torch.ops.auto_deploy.torch_quant_fp8_linear(
                    prod,
                    w2_weight[i],
                    bias=None,
                    input_scale=w2_input_scale[i],
                    weight_scale=w2_weight_scale[i],
                )

            return mlp

        mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]

    elif style == "mlp":

        def make_fp8_mlp(i):
            def mlp(inp):
                up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
                    inp,
                    w1_weight[i],
                    bias=None,
                    input_scale=w1_input_scale[i],
                    weight_scale=w1_weight_scale[i],
                )
                return torch.ops.auto_deploy.torch_quant_fp8_linear(
                    act_fn(up_out),
                    w2_weight[i],
                    bias=None,
                    input_scale=w2_input_scale[i],
                    weight_scale=w2_weight_scale[i],
                )

            return mlp

        mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]

    else:
        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

    return _template_moe(x, selected_experts, routing_weights, mlps)
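For reference, the per-expert math selected by `mlp_style` (and the `relu2` activation) can be written as a plain, unquantized PyTorch sketch. This is only an illustration of the formulas in the docstring above; the helper names here are invented for the example and are not part of this PR:

```python
import torch
import torch.nn.functional as F


def relu2(t: torch.Tensor) -> torch.Tensor:
    # "relu2" as described in the docstring: ReLU followed by squaring.
    return torch.square(F.relu(t))


def gated_mlp_ref(x, w1, w2, w3, act=F.silu):
    # gated_mlp: y = W2( act(W1 x) * (W3 x) ); W1/W3 have shape (I, H), W2 has shape (H, I).
    return F.linear(act(F.linear(x, w1)) * F.linear(x, w3), w2)


def mlp_ref(x, w_up, w_down, act=relu2):
    # mlp: y = W_down( act(W_up x) ); W_up has shape (I, H), W_down has shape (H, I).
    return F.linear(act(F.linear(x, w_up)), w_down)
```

The quantized branches above compute these same formulas per expert, with each `F.linear` replaced by `torch_quant_fp8_linear` (or `torch_quant_nvfp4_linear` below) plus the per-expert input and weight scales.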


@@ -290,6 +343,8 @@ def torch_quant_fp8_moe_fake(
    w1_weight_scale: List[torch.Tensor],
    w2_weight_scale: List[torch.Tensor],
    w3_weight_scale: List[torch.Tensor],
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
    return torch.empty_like(x)

@@ -311,6 +366,8 @@ def torch_quant_nvfp4_moe(
    w1_alpha: List[torch.Tensor],
    w2_alpha: List[torch.Tensor],
    w3_alpha: List[torch.Tensor],
    mlp_style: str = "gated_mlp",  # "gated_mlp" (default) or "mlp"
    act_fn: str = "silu",  # silu or relu2
) -> torch.Tensor:
    """
    FP4 MoE op using quantized linear operations.
@@ -322,45 +379,101 @@ def torch_quant_nvfp4_moe(
        x: Input tensor of shape (B, H) or (B, S, H).
        selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
        routing_weights: Tensor of normalized routing weights.
        w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
        w1_weight:
            List of per-expert weight tensors:
              • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
              • mlp_style=="mlp": W_up with shape (I, H) — up projection.
        w2_weight:
            List of per-expert weight tensors:
              • gated_mlp: W2 with shape (H, I) — down projection.
              • mlp: W_down with shape (H, I) — down projection.
        w3_weight:
            List of per-expert weight tensors:
              • gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP.
              • mlp: pass an empty list []; ignored.
        w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors.
        w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors.
        w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization.
        mlp_style:
            Selects the per-expert MLP computation:
              • "gated_mlp" (default, Mixtral/DeepSeek-style):
                    y = W2( act(W1 x) * (W3 x) )
              • "mlp" (NemotronH-style 2-layer MLP):
                    y = W_down( act(W_up x) )
        act_fn:
            Elementwise activation applied inside the expert MLP.
            Supported: "silu" (default), "relu2" (ReLU then square).
    """

    def make_fp4_mlp(i):
        def mlp(inp):
            if inp.shape[0] == 0:
                return torch.zeros_like(inp)
            gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
                inp,
                w1_weight[i],
                bias=None,
                input_scale=w1_input_scale[i],
                weight_scale=w1_weight_scale[i],
                alpha=w1_alpha[i],
            )
            up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
                inp,
                w3_weight[i],
                bias=None,
                input_scale=w3_input_scale[i],
                weight_scale=w3_weight_scale[i],
                alpha=w3_alpha[i],
            )
            prod = F.silu(gate_out) * up_out
            return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
                prod,
                w2_weight[i],
                bias=None,
                input_scale=w2_input_scale[i],
                weight_scale=w2_weight_scale[i],
                alpha=w2_alpha[i],
            )

        return mlp

    mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
    act_fn = _resolve_activation(act_fn)
    style = mlp_style.lower()

    if style == "gated_mlp":

        def make_fp4_mlp(i):
            def mlp(inp):
                if inp.shape[0] == 0:
                    return torch.zeros_like(inp)
                gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
                    inp,
                    w1_weight[i],
                    bias=None,
                    input_scale=w1_input_scale[i],
                    weight_scale=w1_weight_scale[i],
                    alpha=w1_alpha[i],
                )
                up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
                    inp,
                    w3_weight[i],
                    bias=None,
                    input_scale=w3_input_scale[i],
                    weight_scale=w3_weight_scale[i],
                    alpha=w3_alpha[i],
                )
                prod = act_fn(gate_out) * up_out
                return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
                    prod,
                    w2_weight[i],
                    bias=None,
                    input_scale=w2_input_scale[i],
                    weight_scale=w2_weight_scale[i],
                    alpha=w2_alpha[i],
                )

            return mlp

        mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]

    elif style == "mlp":

        def make_fp4_mlp(i):
            def mlp(inp):
                if inp.shape[0] == 0:
                    return torch.zeros_like(inp)
                up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
                    inp,
                    w1_weight[i],
                    bias=None,
                    input_scale=w1_input_scale[i],
                    weight_scale=w1_weight_scale[i],
                    alpha=w1_alpha[i],
                )
                return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
                    act_fn(up_out),
                    w2_weight[i],
                    bias=None,
                    input_scale=w2_input_scale[i],
                    weight_scale=w2_weight_scale[i],
                    alpha=w2_alpha[i],
                )

            return mlp

        mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]

    else:
        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

    return _template_moe(x, selected_experts, routing_weights, mlps)
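Both ops call `_resolve_activation` to turn the `act_fn` string into a callable; that helper lives elsewhere in torch_moe.py and is not part of the hunks shown here. A minimal sketch consistent with the supported names ("silu" and "relu2") could look like the following; the actual implementation in the file may differ:

```python
import torch
import torch.nn.functional as F


def _resolve_activation(name: str):
    """Map an activation name to an elementwise callable (illustrative sketch only)."""
    name = name.lower()
    if name == "silu":
        return F.silu
    if name == "relu2":
        # ReLU followed by squaring, per the docstrings above.
        return lambda t: torch.square(F.relu(t))
    raise ValueError(f"Unsupported activation '{name}'. Use 'silu' or 'relu2'.")
```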


@@ -381,6 +494,8 @@ def torch_quant_nvfp4_moe_fake(
    w1_alpha: List[torch.Tensor],
    w2_alpha: List[torch.Tensor],
    w3_alpha: List[torch.Tensor],
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
    return torch.empty_like(x)
