Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ transforms:
stage: pattern_matcher
match_dense_moe_pattern:
stage: pattern_matcher
match_bmm_moe_pattern:
stage: pattern_matcher
match_repeat_kv:
stage: pattern_matcher
run_shape_prop: true
Expand Down
120 changes: 106 additions & 14 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,20 @@ def _template_moe(
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
mlps: List[Callable[[torch.Tensor], torch.Tensor]],
apply_routing_on_input: bool = False,
) -> torch.Tensor:
"""Mixtral-style generic MoE template, dispatching tokens to expert MLPs based on routing info."""
"""Mixtral-style generic MoE template, dispatching tokens to expert MLPs based on routing info.

Args:
x: Input tensor
selected_experts: Expert indices for each token
routing_weights: Routing weights for each expert
mlps: List of MLP functions
apply_routing_on_input: If True, multiply routing weights with INPUT before MLP (BMM-based pattern).
This means: silu(input * routing_weight)
If False, multiply routing weights with OUTPUT after MLP (standard pattern).
This means: silu(input) * routing_weight
"""
x_shape = x.shape
hidden_dim = x_shape[-1]
x = x.view(-1, hidden_dim)
Expand All @@ -54,8 +66,18 @@ def _template_moe(
if not tokens_for_this_expert.shape[0]:
continue # input of shape [0, hidden_dim] breaks fp4 kernel

expert_out = mlps[expert_idx](tokens_for_this_expert)
current_hidden_states = expert_out * routing_weights[top_x, idx, None]
if apply_routing_on_input:
# INPUT-SIDE routing (BMM-based pattern): multiply routing weights with INPUT
# Result: silu(input * routing_weight) - routing affects activation
Comment thread
tcherckez-nvidia marked this conversation as resolved.
Outdated
scaled_input = tokens_for_this_expert * routing_weights[top_x, idx, None]
expert_out = mlps[expert_idx](scaled_input)
current_hidden_states = expert_out
else:
# OUTPUT-SIDE routing (standard pattern): multiply routing weights with OUTPUT
# Result: silu(input) * routing_weight - routing applied after activation
expert_out = mlps[expert_idx](tokens_for_this_expert)
current_hidden_states = expert_out * routing_weights[top_x, idx, None]

final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(final_hidden_states.dtype)
)
Expand Down Expand Up @@ -93,7 +115,7 @@ def torch_moe(
• mlp: W_down with shape (H, I) — down projection.
w3_weight:
List of per-expert weight tensors:
• gated_mlp: W3 with shape (I, H) — “up” (second) projection in gated MLP.
• gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP.
• mlp: pass an empty list []; ignored.
mlp_style:
Selects the per-expert MLP computation:
Expand Down Expand Up @@ -149,6 +171,75 @@ def torch_moe_fake(
return torch.empty_like(x)


@torch.library.custom_op("auto_deploy::torch_moe_bmm", mutates_args=())
def torch_moe_bmm(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked: torch.Tensor,
w2_stacked: torch.Tensor,
act_fn: str = "silu",
apply_routing_on_input: bool = True,
) -> torch.Tensor:
"""MoE operator using stacked weights for expert computation.

Only supports gated MLP (SwiGLU-style) architecture.
Uses stacked expert weights in TRT-LLM format (conversion from Llama4 happens during graph transformation).

Weights are expected in TRT-LLM format:
- w3_w1_stacked: (NUM_EXPERTS, 2*INTERMEDIATE_SIZE, HIDDEN_SIZE)
- w2_stacked: (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)

Args:
x: Input tensor
selected_experts: Expert indices for each token
routing_weights: Routing weights for each selected expert
w3_w1_stacked: Stacked gate/up weights (TRT-LLM format)
w2_stacked: Stacked down weights (TRT-LLM format)
act_fn: Activation function name (default: "silu")
apply_routing_on_input: If True, apply routing to INPUT (default for Llama4 pattern).
If False, apply routing to OUTPUT (alternative pattern).
"""

act_fn = _resolve_activation(act_fn)

# Standardized on TRT-LLM format (conversion happens during graph transformation)
# Only gated MLP supported: output = (act_fn(x @ W1) * (x @ W3)) @ W2
def make_mlp(i: int):
gate_up = w3_w1_stacked[i] # (2*I, H)
intermediate_size = gate_up.shape[0] // 2
W3 = gate_up[:intermediate_size, :] # (I, H)
W1 = gate_up[intermediate_size:, :] # (I, H)
W2 = w2_stacked[i] # (H, I)
weight_dtype = W1.dtype
return lambda inp: F.linear(
act_fn(F.linear(inp.to(weight_dtype), W1)) * F.linear(inp.to(weight_dtype), W3),
W2,
)

mlps = [make_mlp(i) for i in range(w3_w1_stacked.shape[0])]

# Apply routing based on the specified pattern
# INPUT-SIDE (default): silu(input * routing_weight) - routing affects activation
# OUTPUT-SIDE (alternative): silu(input) * routing_weight - routing scales output
return _template_moe(
x, selected_experts, routing_weights, mlps, apply_routing_on_input=apply_routing_on_input
)


@torch_moe_bmm.register_fake
def torch_moe_bmm_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked: torch.Tensor,
w2_stacked: torch.Tensor,
act_fn: str = "silu",
apply_routing_on_input: bool = True,
) -> torch.Tensor:
return torch.empty_like(x)


@torch.library.custom_op("auto_deploy::torch_moe_fused", mutates_args=())
def torch_fused_moe(
x: torch.Tensor,
Expand All @@ -159,23 +250,29 @@ def torch_fused_moe(
) -> torch.Tensor:
"""
A reference implementation of a fused MoE layer computation.

All weights are in TRT-LLM format (conversion from Llama4 happens during graph transformation).

Parameters:
x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size,
S is the sequence length, and H is the hidden size.
selected_experts (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the
indices of the selected experts for each token.
routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized
routing weights for the selected experts.
w3_w1_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)
containing the fused weights for w3 and w1 for each expert.
w2_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
containing the weights for w2 for each expert.
w3_w1_stacked_weight (torch.Tensor): Stacked gate/up weights in TRT-LLM format:
(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)
w2_stacked_weight (torch.Tensor): Stacked down weights in TRT-LLM format:
(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
Returns:
torch.Tensor: Output tensor with the same shape as the input x.
"""
x_shape = x.shape
x = x.view(-1, x_shape[-1])
num_experts = w2_stacked_weight.shape[0]

# Standardized on TRT-LLM format (conversion happens during graph transformation)
# TRT-LLM format: gate_up is (2*I, H), down is (H, I)
intermediate_size = w3_w1_stacked_weight.shape[1] // 2
results = torch.zeros_like(x)

Expand All @@ -190,12 +287,7 @@ def torch_fused_moe(
w3 = stacked[:intermediate_size, :]
w1 = stacked[intermediate_size:, :]
w2 = w2_stacked_weight[expert_id]

# Compute expert output:
# expert_out = (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
out_w1 = expert_inputs @ w1.t()
out_w3 = expert_inputs @ w3.t()
expert_out = (F.silu(out_w1) * out_w3) @ w2.t()
expert_out = (F.silu(expert_inputs @ w1.t()) * (expert_inputs @ w3.t())) @ w2.t()

scaling = routing_weights[batch_idx, nth_expert].unsqueeze(-1)
results[batch_idx] += scaling * expert_out
Expand Down
Loading
Loading