2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -28,6 +28,8 @@ transforms:
stage: pattern_matcher
match_dense_moe_pattern:
stage: pattern_matcher
match_bmm_moe_pattern:
stage: pattern_matcher
match_repeat_kv:
stage: pattern_matcher
run_shape_prop: true
118 changes: 92 additions & 26 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
@@ -31,8 +31,20 @@ def _template_moe(
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
mlps: List[Callable[[torch.Tensor], torch.Tensor]],
apply_routing_on_input: bool = False,
) -> torch.Tensor:
"""Mixtral-style generic MoE template, dispatching tokens to expert MLPs based on routing info."""
"""Mixtral-style generic MoE template, dispatching tokens to expert MLPs based on routing info.

Args:
x: Input tensor
selected_experts: Expert indices for each token
routing_weights: Routing weights for each expert
mlps: List of MLP functions
apply_routing_on_input: If True, multiply routing weights with INPUT before MLP (BMM-based pattern).
This means: silu(input * routing_weight)
If False, multiply routing weights with OUTPUT after MLP (standard pattern).
This means: silu(input) * routing_weight
"""
x_shape = x.shape
hidden_dim = x_shape[-1]
x = x.view(-1, hidden_dim)
@@ -54,8 +66,18 @@ def _template_moe(
if not tokens_for_this_expert.shape[0]:
continue # input of shape [0, hidden_dim] breaks fp4 kernel

expert_out = mlps[expert_idx](tokens_for_this_expert)
current_hidden_states = expert_out * routing_weights[top_x, idx, None]
if apply_routing_on_input:
# INPUT-SIDE routing (BMM-based pattern): multiply routing weights with INPUT
# Result: silu(input * routing_weight)
scaled_input = tokens_for_this_expert * routing_weights[top_x, idx, None]
expert_out = mlps[expert_idx](scaled_input)
current_hidden_states = expert_out
else:
# OUTPUT-SIDE routing (standard pattern): multiply routing weights with OUTPUT
# Result: silu(input) * routing_weight
expert_out = mlps[expert_idx](tokens_for_this_expert)
current_hidden_states = expert_out * routing_weights[top_x, idx, None]

final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(final_hidden_states.dtype)
)
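A minimal sketch of how the two routing modes differ for a single expert; the tensors and the stand-in MLP below are illustrative, not part of this change:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)        # tokens routed to one expert
w = torch.rand(4, 1)         # per-token routing weight for that expert
mlp = lambda t: F.silu(t)    # stand-in for the expert MLP

out_output_side = mlp(x) * w   # apply_routing_on_input=False: silu(input) * routing_weight
out_input_side = mlp(x * w)    # apply_routing_on_input=True:  silu(input * routing_weight)
# Because silu is nonlinear, the two results differ, so the flag changes numerics,
# not just the order of operations.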
@@ -72,46 +94,88 @@ def torch_moe(
w3_weight: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
apply_routing_on_input: bool = False,
) -> torch.Tensor:
"""
Unified Mixture-of-Experts (MoE) operator that uses a Mixtral-style dispatch
(token routing + index_add_ accumulation) and a selectable per-expert MLP.

Supports both:
- Standard MoE with per-expert weight lists (apply_routing_on_input=False)
- Llama4 MoE with stacked weight tensors (apply_routing_on_input=True)

Parameters:
x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size,
S is the sequence length, and H is the hidden size.
selected_experts (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the indices
of the selected experts for each token. Only experts within the range [0, num_experts) are processed.
routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized
routing weights for the selected experts.
- Standard MoE: softmax normalized weights
- Llama4 MoE: sigmoid activated weights
w1_weight:
List of per-expert weight tensors:
• mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
• mlp_style=="mlp": W_up with shape (I, H) — up projection.
For per-expert lists:
• mlp_style=="gated_mlp": List of W1 with shape (I, H) — "gate" projection.
• mlp_style=="mlp": List of W_up with shape (I, H) — up projection.
For stacked tensors (Llama4):
• Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format
w2_weight:
List of per-expert weight tensors:
• gated_mlp: W2 with shape (H, I) — down projection.
• mlp: W_down with shape (H, I) — down projection.
For per-expert lists:
• List of W2/W_down with shape (H, I) — down projection.
For stacked tensors (Llama4):
• Single-element list containing stacked w2 tensor with shape (E, H, I) in TRT-LLM format
w3_weight:
List of per-expert weight tensors:
• gated_mlp: W3 with shape (I, H) — “up” (second) projection in gated MLP.
• mlp: pass an empty list []; ignored.
For per-expert lists with gated_mlp:
• List of W3 with shape (I, H) — "up" (second) projection in gated MLP.
For mlp style or stacked tensors:
• pass an empty list []; ignored.
mlp_style:
Selects the per-expert MLP computation:
• "gated_mlp" (default, Mixtral/DeepSeek-style):
• "gated_mlp" (default, Mixtral/DeepSeek/Llama4-style):
y = W2( act(W1 x) * (W3 x) )
• "mlp" (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
Supported: "silu" (default), "relu2" (ReLU then square).
apply_routing_on_input:
If True (Llama4 pattern): multiply routing weights with INPUT before MLP
Result: act(input * routing_weight) - routing affects activation
If False (standard pattern): multiply routing weights with OUTPUT after MLP
Result: act(input) * routing_weight - routing scales output
Returns:
torch.Tensor: Output tensor with the same shape as the input x.
"""
act_fn = _resolve_activation(act_fn)
style = mlp_style.lower()

if style == "gated_mlp":
# Detect if using stacked tensor format (Llama4) vs per-expert lists (standard)
is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3

if is_stacked:
# Llama4 stacked tensor format - only supports gated_mlp
if style != "gated_mlp":
raise ValueError("Stacked tensor format only supports 'gated_mlp' style")

w3_w1_stacked = w1_weight[0] # (E, 2*I, H)
w2_stacked = w2_weight[0] # (E, H, I)

def make_mlp(i: int):
gate_up = w3_w1_stacked[i] # (2*I, H)
intermediate_size = gate_up.shape[0] // 2
W3 = gate_up[:intermediate_size, :] # (I, H)
W1 = gate_up[intermediate_size:, :] # (I, H)
W2 = w2_stacked[i] # (H, I)
weight_dtype = W1.dtype
return lambda inp: F.linear(
act_fn(F.linear(inp.to(weight_dtype), W1)) * F.linear(inp.to(weight_dtype), W3),
W2,
)

mlps = [make_mlp(i) for i in range(w3_w1_stacked.shape[0])]

elif style == "gated_mlp":
# Standard per-expert list format with gated MLP
def make_mlp(i: int):
W1 = w1_weight[i] # (I, H)
W2 = w2_weight[i] # (H, I)
@@ -121,7 +185,7 @@ def make_mlp(i: int):
mlps = [make_mlp(i) for i in range(len(w1_weight))]

elif style == "mlp":

# Standard per-expert list format with simple MLP
def make_mlp(i: int):
W_up = w1_weight[i] # (I, H)
W_down = w2_weight[i] # (H, I)
@@ -132,7 +196,7 @@ def make_mlp(i: int):
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

return _template_moe(x, selected_experts, routing_weights, mlps)
return _template_moe(x, selected_experts, routing_weights, mlps, apply_routing_on_input)


@torch_moe.register_fake
@@ -145,6 +209,7 @@ def torch_moe_fake(
w3_weight: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
apply_routing_on_input: bool = False,
) -> torch.Tensor:
return torch.empty_like(x)
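A hedged usage sketch of the two weight layouts accepted by torch_moe; the routing tensors and the direct function call are illustrative assumptions (in a deployed graph the op is invoked through its registered custom-op namespace):

import torch

E, H, I, TOP_K, B = 4, 16, 32, 2, 8
x = torch.randn(B, H)
selected_experts = torch.randint(0, E, (B, TOP_K))
routing_weights = torch.rand(B, TOP_K)

# Standard per-expert lists (output-side routing):
w1 = [torch.randn(I, H) for _ in range(E)]
w2 = [torch.randn(H, I) for _ in range(E)]
w3 = [torch.randn(I, H) for _ in range(E)]
y_std = torch_moe(x, selected_experts, routing_weights, w1, w2, w3)

# Llama4 stacked tensors in TRT-LLM layout (input-side routing):
w3_w1_stacked = [torch.randn(E, 2 * I, H)]
w2_stacked = [torch.randn(E, H, I)]
y_llama4 = torch_moe(x, selected_experts, routing_weights, w3_w1_stacked, w2_stacked, [],
                     apply_routing_on_input=True)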

Expand All @@ -159,23 +224,29 @@ def torch_fused_moe(
) -> torch.Tensor:
"""
A reference implementation of a fused MoE layer computation.

All weights are in TRT-LLM format (conversion from Llama4 happens during graph transformation).

Parameters:
x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size,
S is the sequence length, and H is the hidden size.
selected_experts (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the
indices of the selected experts for each token.
routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized
routing weights for the selected experts.
w3_w1_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)
containing the fused weights for w3 and w1 for each expert.
w2_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
containing the weights for w2 for each expert.
w3_w1_stacked_weight (torch.Tensor): Stacked gate/up weights in TRT-LLM format:
(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)
w2_stacked_weight (torch.Tensor): Stacked down weights in TRT-LLM format:
(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
Returns:
torch.Tensor: Output tensor with the same shape as the input x.
"""
x_shape = x.shape
x = x.view(-1, x_shape[-1])
num_experts = w2_stacked_weight.shape[0]

# Standardized on TRT-LLM format (conversion happens during graph transformation)
# TRT-LLM format: gate_up is (2*I, H), down is (H, I)
intermediate_size = w3_w1_stacked_weight.shape[1] // 2
results = torch.zeros_like(x)

@@ -190,12 +261,7 @@
w3 = stacked[:intermediate_size, :]
w1 = stacked[intermediate_size:, :]
w2 = w2_stacked_weight[expert_id]

# Compute expert output:
# expert_out = (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
out_w1 = expert_inputs @ w1.t()
out_w3 = expert_inputs @ w3.t()
expert_out = (F.silu(out_w1) * out_w3) @ w2.t()
expert_out = (F.silu(expert_inputs @ w1.t()) * (expert_inputs @ w3.t())) @ w2.t()

scaling = routing_weights[batch_idx, nth_expert].unsqueeze(-1)
results[batch_idx] += scaling * expert_out
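For reference, a sketch of how per-expert weights could be packed into the stacked TRT-LLM layout consumed above; the helper name is hypothetical, and the packing order mirrors the split in this function (w3 rows first, then w1 rows):

import torch

def stack_gated_weights(w1_list, w2_list, w3_list):
    # Concatenate each expert's gate/up weights as [w3; w1] along dim 0, then stack across experts.
    w3_w1_stacked = torch.stack(
        [torch.cat([w3, w1], dim=0) for w1, w3 in zip(w1_list, w3_list)]
    )                                   # (E, 2*I, H)
    w2_stacked = torch.stack(w2_list)   # (E, H, I)
    return w3_w1_stacked, w2_stacked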