From f8526aecc8106ed158a89b5f0963599a8d0b8bc6 Mon Sep 17 00:00:00 2001 From: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> Date: Thu, 27 Nov 2025 03:06:16 -0800 Subject: [PATCH 1/5] [#8733][Perf]: Add Llama4 MoE handling to AutoDeploy Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 2 + .../custom_ops/fused_moe/torch_moe.py | 120 ++- .../transform/library/fused_moe.py | 729 ++++++++++++++++-- .../auto_deploy/transform/library/sharding.py | 31 +- .../auto_deploy/utils/sharding_utils.py | 303 +++++++- .../library/test_ep_sharding.py | 42 + .../singlegpu/custom_ops/test_ad_moe_op.py | 97 ++- tests/unittest/_torch/helpers.py | 65 +- 8 files changed, 1293 insertions(+), 96 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 2f26b99b821..4ca6729949c 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -28,6 +28,8 @@ transforms: stage: pattern_matcher match_dense_moe_pattern: stage: pattern_matcher + match_bmm_moe_pattern: + stage: pattern_matcher match_repeat_kv: stage: pattern_matcher run_shape_prop: true diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py index 63519cf5a58..fe0ebeeb665 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py @@ -31,8 +31,20 @@ def _template_moe( selected_experts: torch.Tensor, routing_weights: torch.Tensor, mlps: List[Callable[[torch.Tensor], torch.Tensor]], + apply_routing_on_input: bool = False, ) -> torch.Tensor: - """Mixtral-style generic MoE template, dispatching tokens to expert MLPs based on routing info.""" + """Mixtral-style generic MoE template, dispatching tokens to expert MLPs based on routing info. + + Args: + x: Input tensor + selected_experts: Expert indices for each token + routing_weights: Routing weights for each expert + mlps: List of MLP functions + apply_routing_on_input: If True, multiply routing weights with INPUT before MLP (BMM-based pattern). + This means: silu(input * routing_weight) + If False, multiply routing weights with OUTPUT after MLP (standard pattern). 
+ This means: silu(input) * routing_weight + """ x_shape = x.shape hidden_dim = x_shape[-1] x = x.view(-1, hidden_dim) @@ -54,8 +66,18 @@ def _template_moe( if not tokens_for_this_expert.shape[0]: continue # input of shape [0, hidden_dim] breaks fp4 kernel - expert_out = mlps[expert_idx](tokens_for_this_expert) - current_hidden_states = expert_out * routing_weights[top_x, idx, None] + if apply_routing_on_input: + # INPUT-SIDE routing (BMM-based pattern): multiply routing weights with INPUT + # Result: silu(input * routing_weight) - routing affects activation + scaled_input = tokens_for_this_expert * routing_weights[top_x, idx, None] + expert_out = mlps[expert_idx](scaled_input) + current_hidden_states = expert_out + else: + # OUTPUT-SIDE routing (standard pattern): multiply routing weights with OUTPUT + # Result: silu(input) * routing_weight - routing applied after activation + expert_out = mlps[expert_idx](tokens_for_this_expert) + current_hidden_states = expert_out * routing_weights[top_x, idx, None] + final_hidden_states.index_add_( 0, top_x, current_hidden_states.to(final_hidden_states.dtype) ) @@ -93,7 +115,7 @@ def torch_moe( • mlp: W_down with shape (H, I) — down projection. w3_weight: List of per-expert weight tensors: - • gated_mlp: W3 with shape (I, H) — “up” (second) projection in gated MLP. + • gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP. • mlp: pass an empty list []; ignored. mlp_style: Selects the per-expert MLP computation: @@ -149,6 +171,75 @@ def torch_moe_fake( return torch.empty_like(x) +@torch.library.custom_op("auto_deploy::torch_moe_bmm", mutates_args=()) +def torch_moe_bmm( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w3_w1_stacked: torch.Tensor, + w2_stacked: torch.Tensor, + act_fn: str = "silu", + apply_routing_on_input: bool = True, +) -> torch.Tensor: + """MoE operator using stacked weights for expert computation. + + Only supports gated MLP (SwiGLU-style) architecture. + Uses stacked expert weights in TRT-LLM format (conversion from Llama4 happens during graph transformation). + + Weights are expected in TRT-LLM format: + - w3_w1_stacked: (NUM_EXPERTS, 2*INTERMEDIATE_SIZE, HIDDEN_SIZE) + - w2_stacked: (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE) + + Args: + x: Input tensor + selected_experts: Expert indices for each token + routing_weights: Routing weights for each selected expert + w3_w1_stacked: Stacked gate/up weights (TRT-LLM format) + w2_stacked: Stacked down weights (TRT-LLM format) + act_fn: Activation function name (default: "silu") + apply_routing_on_input: If True, apply routing to INPUT (default for Llama4 pattern). + If False, apply routing to OUTPUT (alternative pattern). 
+ """ + + act_fn = _resolve_activation(act_fn) + + # Standardized on TRT-LLM format (conversion happens during graph transformation) + # Only gated MLP supported: output = (act_fn(x @ W1) * (x @ W3)) @ W2 + def make_mlp(i: int): + gate_up = w3_w1_stacked[i] # (2*I, H) + intermediate_size = gate_up.shape[0] // 2 + W3 = gate_up[:intermediate_size, :] # (I, H) + W1 = gate_up[intermediate_size:, :] # (I, H) + W2 = w2_stacked[i] # (H, I) + weight_dtype = W1.dtype + return lambda inp: F.linear( + act_fn(F.linear(inp.to(weight_dtype), W1)) * F.linear(inp.to(weight_dtype), W3), + W2, + ) + + mlps = [make_mlp(i) for i in range(w3_w1_stacked.shape[0])] + + # Apply routing based on the specified pattern + # INPUT-SIDE (default): silu(input * routing_weight) - routing affects activation + # OUTPUT-SIDE (alternative): silu(input) * routing_weight - routing scales output + return _template_moe( + x, selected_experts, routing_weights, mlps, apply_routing_on_input=apply_routing_on_input + ) + + +@torch_moe_bmm.register_fake +def torch_moe_bmm_fake( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w3_w1_stacked: torch.Tensor, + w2_stacked: torch.Tensor, + act_fn: str = "silu", + apply_routing_on_input: bool = True, +) -> torch.Tensor: + return torch.empty_like(x) + + @torch.library.custom_op("auto_deploy::torch_moe_fused", mutates_args=()) def torch_fused_moe( x: torch.Tensor, @@ -159,6 +250,9 @@ def torch_fused_moe( ) -> torch.Tensor: """ A reference implementation of a fused MoE layer computation. + + All weights are in TRT-LLM format (conversion from Llama4 happens during graph transformation). + Parameters: x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size, S is the sequence length, and H is the hidden size. @@ -166,16 +260,19 @@ def torch_fused_moe( indices of the selected experts for each token. routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized routing weights for the selected experts. - w3_w1_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE) - containing the fused weights for w3 and w1 for each expert. - w2_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE) - containing the weights for w2 for each expert. + w3_w1_stacked_weight (torch.Tensor): Stacked gate/up weights in TRT-LLM format: + (NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE) + w2_stacked_weight (torch.Tensor): Stacked down weights in TRT-LLM format: + (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE) Returns: torch.Tensor: Output tensor with the same shape as the input x. 
""" x_shape = x.shape x = x.view(-1, x_shape[-1]) num_experts = w2_stacked_weight.shape[0] + + # Standardized on TRT-LLM format (conversion happens during graph transformation) + # TRT-LLM format: gate_up is (2*I, H), down is (H, I) intermediate_size = w3_w1_stacked_weight.shape[1] // 2 results = torch.zeros_like(x) @@ -190,12 +287,7 @@ def torch_fused_moe( w3 = stacked[:intermediate_size, :] w1 = stacked[intermediate_size:, :] w2 = w2_stacked_weight[expert_id] - - # Compute expert output: - # expert_out = (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t() - out_w1 = expert_inputs @ w1.t() - out_w3 = expert_inputs @ w3.t() - expert_out = (F.silu(out_w1) * out_w3) @ w2.t() + expert_out = (F.silu(expert_inputs @ w1.t()) * (expert_inputs @ w3.t())) @ w2.t() scaling = routing_weights[batch_idx, nth_expert].unsqueeze(-1) results[batch_idx] += scaling * expert_out diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index c7ea32fb0e2..48f9860eaf1 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -18,11 +18,203 @@ ) +def _process_regular_moe_node( + gm: GraphModule, + graph: torch.fx.Graph, + node: Node, + replacement_op, + mlp_style_val: str, + act_fn_val: str, + fused_key_counter: int, +) -> None: + """Process a single torch_moe node with per-expert weight lists. + + Stacks weight parameters and creates a fused MoE node. + The kernel applies routing weights to the output. + """ + hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = extract_op_args( + node, + "x", + "selected_experts", + "routing_weights", + "w1_weight", + "w2_weight", + "w3_weight", + ) + + # Stack weights based on MLP style + if mlp_style_val == "gated_mlp": + # For gated MLP, concatenate w3 and w1 then stack across experts + fused_w_up_experts = torch.stack( + [ + torch.cat( + [gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], + dim=-2, + ) + for w1_node, w3_node in zip(w1_list, w3_list) + ], + dim=0, + ) + new_key_w_up = f"fused_moe_w3_w1_stacked_{fused_key_counter}" + elif mlp_style_val == "mlp": + # For regular MLP, just stack w1 + fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0) + new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}" + else: + raise ValueError(f"Unknown mlp_style: {mlp_style_val}") + + # Stack w2/down weights + fused_w_down_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0) + new_key_w_down = f"fused_moe_w2_stacked_{fused_key_counter}" + + # Register the stacked weights as parameters + param_w_up = torch.nn.Parameter(fused_w_up_experts) + gm.register_parameter(new_key_w_up, param_w_up) + w_up_arg = graph.get_attr(new_key_w_up) + + param_w_down = torch.nn.Parameter(fused_w_down_experts) + gm.register_parameter(new_key_w_down, param_w_down) + w_down_arg = graph.get_attr(new_key_w_down) + + # Create fused MoE node - kernel applies routing to output + with graph.inserting_before(node): + new_node = graph.call_function( + replacement_op, + args=(hidden_states, selected_experts, routing_weights, w_up_arg, w_down_arg), + kwargs={ + "mlp_style": mlp_style_val, + "act_fn": act_fn_val, + }, + ) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + +def _process_llama4_stacked_moe_node( + gm: GraphModule, + graph: torch.fx.Graph, + node: Node, + replacement_op, + act_fn_val: str, + fused_key_counter: int, +) -> 
None: + """Process a single Llama4 MoE node with pre-stacked weight tensors (torch_moe_bmm). + + Only supports gated MLP (SwiGLU-style) architecture. + Converts Llama4 format weights to TRT-LLM format to standardize all downstream ops. + Applies routing weights to INPUT before the fused kernel to prevent double multiplication. + This is the Llama4 pattern where weights are already stacked across experts. + Result: silu(input * routing_weight) - routing affects activation. + """ + hidden_states, selected_experts, routing_weights, w3_w1_stacked, w2_stacked = extract_op_args( + node, + "x", + "selected_experts", + "routing_weights", + "w3_w1_stacked", + "w2_stacked", + ) + + # Convert Llama4 format to TRT-LLM format if needed + # This standardizes all downstream ops to only handle TRT-LLM format + if w3_w1_stacked.op == "get_attr" and w2_stacked.op == "get_attr": + gate_up_weight = gm.get_parameter(w3_w1_stacked.target) + down_weight = gm.get_parameter(w2_stacked.target) + + # Detect format: Llama4 has gate_up[0] == down[1] + is_llama4 = gate_up_weight.shape[1] == down_weight.shape[0] + + if is_llama4: + # Convert Llama4 (H, 2*I) -> TRT-LLM (2*I, H) + gate_up_trtllm = gate_up_weight.transpose(1, 2).contiguous() + down_trtllm = down_weight.transpose(1, 2).contiguous() + + # Register converted weights + new_key_w_up = f"llama4_to_trtllm_w3_w1_{fused_key_counter}" + new_key_w_down = f"llama4_to_trtllm_w2_{fused_key_counter}" + + gm.register_parameter(new_key_w_up, torch.nn.Parameter(gate_up_trtllm)) + gm.register_parameter(new_key_w_down, torch.nn.Parameter(down_trtllm)) + + w_up_arg = graph.get_attr(new_key_w_up) + w_down_arg = graph.get_attr(new_key_w_down) + else: + # Already TRT-LLM format, use directly + w_up_arg = w3_w1_stacked + w_down_arg = w2_stacked + else: + # Not get_attr nodes (might be intermediate ops), use directly + w_up_arg = w3_w1_stacked + w_down_arg = w2_stacked + + # Get weight dtype to ensure dtype consistency for Llama4 stacked tensors + # The fused kernel requires input and weights to have matching dtypes + weight_dtype = None + if w_up_arg.op == "get_attr": + try: + weight_tensor = gm.get_parameter(w_up_arg.target) + weight_dtype = weight_tensor.dtype + except (AttributeError, KeyError): + pass + + # Llama4 INPUT-SIDE routing: apply routing to INPUT before kernel + # Cast BOTH input and routing_weights to weight dtype if needed + # Critical: BFloat16 * Float32 → Float32 (type promotion) so we cast both to same dtype + with graph.inserting_before(node): + input_to_scale = hidden_states + routing_to_scale = routing_weights + + if weight_dtype is not None and weight_dtype != torch.float32: + input_to_scale = graph.call_function( + torch.ops.aten._to_copy.default, + args=(hidden_states,), + kwargs={"dtype": weight_dtype}, + ) + routing_to_scale = graph.call_function( + torch.ops.aten._to_copy.default, + args=(routing_weights,), + kwargs={"dtype": weight_dtype}, + ) + + # Scale input: hidden_states = hidden_states * routing_weights (both same dtype now) + scaled_input = graph.call_function( + torch.ops.aten.mul.Tensor, + args=(input_to_scale, routing_to_scale), + ) + + # Pass ones to kernel to prevent it from multiplying routing again (already applied) + ones_node = graph.call_function( + torch.ops.aten.ones_like.default, + args=(routing_weights,), + ) + + new_node = graph.call_function( + replacement_op, + args=(scaled_input, selected_experts, ones_node, w_up_arg, w_down_arg), + kwargs={ + "act_fn": act_fn_val, + }, + ) + + node.replace_all_uses_with(new_node) + 
graph.erase_node(node) + + def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int: + """Replace torch MoE ops with fused backend-specific implementations. + + Handles both: + - Standard MoE (per-expert weight lists): torch_moe + - Llama4 MoE (pre-stacked weight tensors): torch_moe_bmm + + For Llama4 stacked tensors, applies routing weights to input before the fused kernel. + """ fused_key_counter = 0 graph = gm.graph backend = backend.lower() + # Map backend to fused MoE op (handles both standard and Llama4 stacked tensor patterns) replacement_op = { "auto": torch.ops.auto_deploy.trtllm_moe_fused, "trtllm": torch.ops.auto_deploy.trtllm_moe_fused, @@ -30,71 +222,31 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t }[backend] for node in graph.nodes: - if not is_op(node, torch.ops.auto_deploy.torch_moe): + if not ( + is_op(node, torch.ops.auto_deploy.torch_moe) + or is_op(node, torch.ops.auto_deploy.torch_moe_bmm) + ): continue - (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn") - assert backend != "triton" or mlp_style_val == "mlp", ( - "Triton backend only supports mlp style." - ) + is_llama4_stacked_moe = is_op(node, torch.ops.auto_deploy.torch_moe_bmm) - hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = ( - extract_op_args( - node, - "x", - "selected_experts", - "routing_weights", - "w1_weight", - "w2_weight", - "w3_weight", + if is_llama4_stacked_moe: + # torch_moe_bmm: only supports gated MLP, no mlp_style parameter + (act_fn_val,) = extract_op_args(node, "act_fn") + _process_llama4_stacked_moe_node( + gm, graph, node, replacement_op, act_fn_val, fused_key_counter ) - ) - if mlp_style_val == "gated_mlp": - fused_w_up_experts = torch.stack( - [ - torch.cat( - [gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], dim=-2 - ) - for w1_node, w3_node in zip(w1_list, w3_list) - ], - dim=0, - ) - new_key_w_up = f"fused_moe_w3_w1_stacked_{fused_key_counter}" - - elif mlp_style_val == "mlp": - fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0) - new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}" else: - raise ValueError(f"Unknown mlp_style: {mlp_style_val}") - - fused_w_down_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0) - - new_key_w_down = f"fused_moe_w2_stacked_{fused_key_counter}" - fused_key_counter += 1 - param_w_up = torch.nn.Parameter(fused_w_up_experts) - param_w_down = torch.nn.Parameter(fused_w_down_experts) - gm.register_parameter(new_key_w_up, param_w_up) - gm.register_parameter(new_key_w_down, param_w_down) - - with graph.inserting_before(node): - new_node = graph.call_function( - # TODO(Fridah-nv): torch.ops.auto_deploy.trtllm_moe_fused for quantized models - replacement_op, - args=( - hidden_states, - selected_experts, - routing_weights, - graph.get_attr(new_key_w_up), - graph.get_attr(new_key_w_down), - ), - kwargs={ - "mlp_style": mlp_style_val, - "act_fn": act_fn_val, - }, + # torch_moe: supports both gated_mlp and mlp styles + (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn") + assert backend != "triton" or mlp_style_val == "mlp", ( + "Triton backend only supports mlp style." 
+ ) + _process_regular_moe_node( + gm, graph, node, replacement_op, mlp_style_val, act_fn_val, fused_key_counter ) - node.replace_all_uses_with(new_node) - graph.erase_node(node) + fused_key_counter += 1 # Delete the unstacked weights immediately to save GPU memory # This will happen automatically after the graph is canonicalized, but for large models we'll run out of memory @@ -585,6 +737,467 @@ def scale_keys(self) -> List[str]: return ["input_scale", "weight_scale", "alpha"] +class MatchBmmMoePatternConfig(TransformConfig): + """Configuration for MatchBmmMoePattern transform.""" + + pass + + +@TransformRegistry.register("match_bmm_moe_pattern") +class MatchBmmMoePattern(BaseTransform): + """Match and fuse Llama4 MoE pattern with pre-stacked weight tensors. + + This pattern uses batch matrix multiply (BMM) operations for parallel expert computation + with weights already stacked across the expert dimension. + + Only matches patterns where topk uses k=1 (single expert per token). + """ + + config: MatchBmmMoePatternConfig + + @classmethod + def get_config_class(cls): + return MatchBmmMoePatternConfig + + @staticmethod + def _find_gate_up_bmm(final_bmm: Node) -> Optional[Tuple[Node, Node]]: + """Find the MoE gate_up BMM and chunk node from the final BMM. + + BMM MoE pattern traces back: final_bmm <- mul(up, silu(gate)) <- chunk <- first_bmm (gate_up) + + Returns: + Tuple of (first_bmm, gate_up_weight) or None if not found + """ + # Input to final bmm should be mul(up, silu(gate)) + mul_node = final_bmm.args[0] + if not isinstance(mul_node, Node) or not is_op(mul_node, torch.ops.aten.mul): + return None + + if not mul_node.args or len(mul_node.args) < 2: + return None + + # Find silu node (one of the mul inputs) + arg_a, arg_b = mul_node.args[:2] + silu_node = ( + arg_a + if is_op(arg_a, torch.ops.aten.silu) + else arg_b + if is_op(arg_b, torch.ops.aten.silu) + else None + ) + if silu_node is None: + return None + + up_node = arg_b if arg_a is silu_node else arg_a + if not isinstance(up_node, Node): + return None + + # silu input should be gate from chunk + if not silu_node.args or not isinstance(silu_node.args[0], Node): + return None + gate_node = silu_node.args[0] + + # Both gate and up come from chunk (getitem nodes) + if gate_node.op != "call_function" or up_node.op != "call_function": + return None + + # Find the chunk node + chunk_node = None + if hasattr(gate_node, "args") and gate_node.args: + potential_chunk = gate_node.args[0] + if isinstance(potential_chunk, Node) and is_op(potential_chunk, torch.ops.aten.chunk): + chunk_node = potential_chunk + + if chunk_node is None or not chunk_node.args or chunk_node.args[1] != 2: + return None + + # chunk input is the first batched BMM for Llama4 (gate_up_proj) + first_bmm = chunk_node.args[0] + if not isinstance(first_bmm, Node) or not is_op(first_bmm, torch.ops.aten.bmm): + return None + + if not first_bmm.args or len(first_bmm.args) < 2: + return None + + # Llama4: gate_up_weight is pre-stacked [num_experts, hidden, 2*intermediate] + gate_up_weight = first_bmm.args[1] + if not isinstance(gate_up_weight, Node) or gate_up_weight.op != "get_attr": + return None + + return (first_bmm, gate_up_weight) + + @staticmethod + def _find_input_and_routing(batched_input: Node) -> Optional[Tuple[Node, Node]]: + """Find the input hidden states and routing weights from batched input. + + BMM MoE pattern traces back: batched_input <- mul(repeat(input), routing) <- repeat <- input + + Only matches patterns where topk uses k=1 (single expert per token). 
+ + Returns: + Tuple of (input_hidden_states, topk_node) or None if not found + """ + # batched_input comes from mul(repeated_input, routing_weights) + if not batched_input.args or not isinstance(batched_input.args[0], Node): + return None + + mul_routing = batched_input.args[0] + if ( + not is_op(mul_routing, torch.ops.aten.mul) + or not mul_routing.args + or len(mul_routing.args) < 2 + ): + return None + + # One arg is repeat (input), other is routing weights + repeat_node = None + routing_weight_node = None + for arg in mul_routing.args[:2]: + if isinstance(arg, Node): + if is_op(arg, torch.ops.aten.repeat): + repeat_node = arg + else: + routing_weight_node = arg + + if not repeat_node or not routing_weight_node: + return None + + # Get original input from repeat + if not repeat_node.args or not isinstance(repeat_node.args[0], Node): + return None + input_hidden_states = repeat_node.args[0] + + # Trace back from routing_weight to find topk + try: + topk_node, _ = bfs( + routing_weight_node, + lambda n: is_op(n, torch.ops.aten.topk), + attr_next="all_input_nodes", + ) + router_logits = topk_node.args[0] if topk_node.args else None + if not router_logits: + return None + + # Verify topk is using k=1 (only match single-expert-per-token routing) + if len(topk_node.args) < 2: + return None + k_value = topk_node.args[1] + if k_value != 1: + return None + + except RuntimeError: + return None + + return (input_hidden_states, topk_node) + + @staticmethod + def _find_output_and_routing_flavor(final_bmm: Node) -> Optional[Tuple[Node, bool]]: + """Find the output node and detect routing application method. + + Llama4 stacked MoE pattern traces forward: final_bmm -> view -> reshape -> sum [-> mul?] + + Returns: + Tuple of (output_node, apply_routing_on_input) or None if not found + apply_routing_on_input is True if routing is applied to input, False if applied to output + """ + # Llama4 pattern: bmm -> view([-1, hidden]) -> reshape([num_experts, -1, hidden]) -> sum(dim=0) + output_view = None + for user in final_bmm.users: + if is_op(user, torch.ops.aten.view): + output_view = user + break + + if not output_view: + return None + + # Find reshape after view + reshape_node = None + for user in output_view.users: + if is_op(user, torch.ops.aten.reshape): + reshape_node = user + break + + if not reshape_node: + return None + + # Find sum after reshape + sum_node = None + for user in reshape_node.users: + if is_op(user, torch.ops.aten.sum): + sum_node = user + break + + if not sum_node: + return None + + # Detect routing application method: check if routing is applied after sum (OUTPUT) + apply_routing_on_input = True # Default for Llama4 (routing already applied before BMM) + output_node = sum_node + + for user in sum_node.users: + if is_op(user, torch.ops.aten.mul): + # Found multiplication after sum - routing is applied to OUTPUT + apply_routing_on_input = False + output_node = user + break + + return (output_node, apply_routing_on_input) + + @staticmethod + def _match_bmm_moe_pattern( + start_boundary: Node, + end_boundary: Node, + ): + """ + Match the BMM MoE pattern (ONE pattern per layer, not per expert). + + This BMM MoE pattern uses batch matrix multiply (BMM) operations for parallel expert + computation with pre-stacked weight tensors. + + Only matches patterns where topk uses k=1 (single expert per token). + + Supports TWO routing flavors: + + 1. 
INPUT-SIDE routing (most common): + - Pattern: mul(input, routing) -> bmm -> silu -> bmm -> sum + - Result: silu(input * routing_weight) - routing affects activation + - Routing multiplication happens BEFORE BMM operations + + 2. OUTPUT-SIDE routing (alternative): + - Pattern: bmm -> silu -> bmm -> sum -> mul(output, routing) + - Result: silu(input) * routing_weight) - routing scales output + - Routing multiplication happens AFTER sum + + The function auto-detects which flavor is present and returns metadata. + + The BMM MoE pattern corresponds to: + repeated_input = repeat(input, [num_experts, 1]) + routing_weights = reshape(transpose(sigmoid(scatter(topk(router_logits))))) + routed_input = mul(repeated_input, routing_weights) # <-- INPUT-SIDE MULTIPLICATION + batched_input = view(routed_input, [num_experts, -1, hidden]) + gate_up = bmm(batched_input, gate_up_proj) # gate_up_proj: pre-stacked [num_experts, hidden, 2*inter] + gate, up = chunk(gate_up, 2) + output = bmm(up * silu(gate), down_proj) # down_proj: pre-stacked [num_experts, inter, hidden] + final_output = view(output, [-1, hidden]) + + Returns: + List of dicts, one per MoE layer found: + { + "input": Node, # Unmultiplied input tensor + "router_logits": Node, # Router logits tensor + "gate_up_weight": Node, # Stacked gate+up weights + "down_weight": Node, # Stacked down weights + "output": Node, # Output node to replace (sum or mul) + "topk": Node, # TopK node for routing + "apply_routing_on_input": bool, # True if routing on input, False if on output + } + """ + moe_layers = [] + + nodes = list(start_boundary.graph.nodes) + region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)] + + for node in region_nodes: + # Look for the final bmm (down_proj) - this is the BATCHED bmm + if not is_op(node, torch.ops.aten.bmm): + continue + + final_bmm = node + if not final_bmm.args or len(final_bmm.args) < 2: + continue + + # Step 1: Get down_proj weight + down_weight = final_bmm.args[1] + if not isinstance(down_weight, Node) or down_weight.op != "get_attr": + continue + + # Step 2: Find the first BMM (gate_up) by tracing back through chunk and mul + result = MatchBmmMoePattern._find_gate_up_bmm(final_bmm) + if result is None: + continue + first_bmm, gate_up_weight = result + + # Step 3: Get batched input and trace back to original input and routing + batched_input = first_bmm.args[0] + if not isinstance(batched_input, Node) or not is_op(batched_input, torch.ops.aten.view): + continue + + result = MatchBmmMoePattern._find_input_and_routing(batched_input) + if result is None: + continue + input_hidden_states, topk_node = result + + # Get router_logits for metadata + router_logits = topk_node.args[0] if topk_node.args else None + if not router_logits: + continue + + # Step 4: Find output node and detect routing application method + result = MatchBmmMoePattern._find_output_and_routing_flavor(final_bmm) + if result is None: + continue + output_node, apply_routing_on_input = result + + # Step 5: Add the matched layer + moe_layers.append( + { + "input": input_hidden_states, + "router_logits": router_logits, + "gate_up_weight": gate_up_weight, + "down_weight": down_weight, + "output": output_node, + "topk": topk_node, + "apply_routing_on_input": apply_routing_on_input, + } + ) + + return moe_layers + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + graph = gm.graph + + # Preprocessing: Identify boundary 
nodes (e.g. residual connections) in the graph. + boundary_nodes = identify_regions_between_residuals(gm) + + num_moe_patterns = 0 + + for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]): + # Step 1: Match BMM MoE patterns (one pattern per MoE layer) + moe_layers = self._match_bmm_moe_pattern(start_boundary, end_boundary) + + if not moe_layers: + continue + + # Process each MoE layer + for layer_info in moe_layers: + input_hidden_states = layer_info["input"] + gate_up_weight = layer_info["gate_up_weight"] + down_weight = layer_info["down_weight"] + output_node = layer_info["output"] + topk_node = layer_info["topk"] + # Get routing application method from pattern matcher + # Default to True (apply on input) which is the common Llama4 pattern + input_routing = layer_info.get("apply_routing_on_input", True) + + # Step 2: Extract routing information + # selected_experts: topk indices [tokens, top_k] + # routing_weights: topk values [tokens, top_k] + selected_experts = None + routing_weights_node = None + + for user in topk_node.users: + if ( + user.op == "call_function" + and hasattr(user.target, "__name__") + and user.target.__name__ == "getitem" + ): + if len(user.args) >= 2: + if user.args[1] == 1: # indices + selected_experts = user + elif user.args[1] == 0: # values + # For the fused MoE op, we use topk values directly + # (shape: tokens x top_k), NOT the scattered version + routing_weights_node = user + + if not selected_experts: + continue + + if not routing_weights_node: + continue + + # Find scatter and sigmoid nodes - we need the sigmoid output for routing weights! + # The original pattern: sigmoid(scatter(topk_values)) normalizes routing to [0,1] + # The fused op must use the same normalized values, not raw topk logits. 
+ scatter_node = None + sigmoid_node = None + + # Find scatter operation + for user in routing_weights_node.users: + if is_op(user, {torch.ops.aten.scatter, torch.ops.aten.scatter_}): + scatter_node = user + # Check if sigmoid exists after scatter (may have to.dtype in between) + current = user + for _ in range(2): # Allow up to 2 hops + for next_user in current.users: + if is_op(next_user, torch.ops.aten.sigmoid): + sigmoid_node = next_user + break + elif is_op(next_user, torch.ops.aten.to): + current = next_user + break + if sigmoid_node: + break + break + + if not scatter_node: + continue + + if not sigmoid_node: + continue + + # Extract normalized routing weights from the sigmoid output + # The sigmoid output has shape [tokens, num_experts] (scattered) + # We need to gather it back to [tokens, top_k] using selected_experts indices + # This gives us sigmoid(scatter(topk_values)) in the compact [tokens, top_k] format + graph = gm.graph + with graph.inserting_after(sigmoid_node): + # Create gather operation: routing_weights_normalized = sigmoid_output.gather(1, selected_experts) + routing_weights_normalized = graph.call_function( + torch.ops.aten.gather, + args=(sigmoid_node, 1, selected_experts), + ) + + # Use the normalized routing weights instead of raw topk values + routing_weights_node = routing_weights_normalized + + # Step 4: Apply MoE fusion + # If input_routing is True: kernel applies routing to input + # If input_routing is False: kernel applies routing to output + apply_routing_on_input = input_routing + + with graph.inserting_before(output_node): + fused_moe_node = graph.call_function( + torch.ops.auto_deploy.torch_moe_bmm, + args=( + input_hidden_states, + selected_experts, + routing_weights_node, + gate_up_weight, + down_weight, + ), + kwargs={ + "apply_routing_on_input": apply_routing_on_input, + }, + ) + + # Replace the output node with fused MoE + output_node.replace_all_uses_with(fused_moe_node) + graph.erase_node(output_node) + + # Clean up dead nodes + gm.graph.eliminate_dead_code() + + # Clean up dead inplace nodes in the region + while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary): + gm.graph.eliminate_dead_code() + + # Delete unused submodules/parameters + gm.delete_all_unused_submodules() + + num_moe_patterns += 1 + + info = TransformInfo( + skipped=False, num_matches=num_moe_patterns, is_clean=False, has_valid_shapes=False + ) + return gm, info + + def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int: """ Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters. 
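[Reviewer aid] A minimal sketch — illustrative, not taken from the patch — of how the new `auto_deploy::torch_moe_bmm` op is expected to be invoked with TRT-LLM-format stacked weights and top-k=1 routing. Shapes and tensor names are assumptions, mirroring the unit test added in test_ad_moe_op.py below:

    import torch
    import torch.nn.functional as F
    import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401  (registers the auto_deploy ops)

    E, H, I, S = 4, 64, 32, 8  # experts, hidden, intermediate, tokens (illustrative sizes)
    x = torch.randn(S, H, dtype=torch.bfloat16, device="cuda")
    router_logits = torch.randn(S, E, device="cuda")
    routing_weights, selected_experts = torch.topk(F.softmax(router_logits, dim=-1), k=1, dim=-1)
    w3_w1_stacked = torch.randn(E, 2 * I, H, dtype=torch.bfloat16, device="cuda")  # TRT-LLM layout (E, 2*I, H)
    w2_stacked = torch.randn(E, H, I, dtype=torch.bfloat16, device="cuda")         # TRT-LLM layout (E, H, I)

    # INPUT-side routing (Llama4 flavor): the routing weight scales the input before the
    # expert MLP, i.e. silu(x * w) rather than silu(x) * w.
    out = torch.ops.auto_deploy.torch_moe_bmm(
        x, selected_experts, routing_weights.to(x.dtype), w3_w1_stacked, w2_stacked,
        act_fn="silu", apply_routing_on_input=True,
    )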
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 1ed5652c628..486c17b1ca1 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -39,6 +39,7 @@ subgraph, ) from ...utils.sharding_utils import ( + BmmEPShardingInfo, BMMShardingInfo, EPShardingInfo, LayerType, @@ -137,6 +138,9 @@ def check_and_apply(transform: ShardingTransformInfo) -> bool: for ep_transform in transforms.ep_transforms: if check_and_apply(ep_transform): num_matches += 1 + for stacked_ep_transform in transforms.stacked_ep_transforms: + if check_and_apply(stacked_ep_transform): + num_matches += 1 # post-sharding cleanup transformations for update_transform in transforms.parameter_update_transforms: @@ -1058,20 +1062,31 @@ def detect_ep_shard(gm: GraphModule, sharding_config: ShardingTransformContainer node, ( torch.ops.auto_deploy.torch_moe, + torch.ops.auto_deploy.torch_moe_bmm, torch.ops.auto_deploy.torch_quant_fp8_moe, torch.ops.auto_deploy.torch_quant_nvfp4_moe, torch.ops.auto_deploy.triton_mxfp4_moe, ), ): continue - if sharding_config.add( - EPShardingInfo.from_node( - node, - rank=rank, - world_size=world_size, - ) - ): - num_moe_patterns += 1 + if is_op(node, torch.ops.auto_deploy.torch_moe_bmm): + if sharding_config.add( + BmmEPShardingInfo.from_node( + node, + rank=rank, + world_size=world_size, + ) + ): + num_moe_patterns += 1 + else: + if sharding_config.add( + EPShardingInfo.from_node( + node, + rank=rank, + world_size=world_size, + ) + ): + num_moe_patterns += 1 ad_logger.info(f"Found {num_moe_patterns} MoE patterns") diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 08441632c39..affb138fbb7 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -6,10 +6,18 @@ from abc import ABC, abstractmethod from enum import Enum, IntEnum from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple - -if TYPE_CHECKING: - from ..transform.library.sharding import ShardingTransformConfig +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, +) import torch import torch.nn as nn @@ -32,6 +40,9 @@ modelopt_fp4_scale_to_cutlass_fp4_scale, ) +if TYPE_CHECKING: + from ..transform.library.sharding import ShardingTransformConfig + def validate_allreduce_strategy(v): """Convert string names like 'AUTO' to AllReduceStrategy enum. @@ -1082,14 +1093,268 @@ def get_partition(lst, world_size, rank): dist_node.replace_input_with(dist_node, node) -def _slice_expert_dim(gm: GraphModule, tensor_node: Node, lo: int, hi: int) -> Node: - """Return tensor_node[lo:hi, ...] via aten.slice along dim 0.""" - with gm.graph.inserting_after(tensor_node): - # aten.slice.Tensor(self, dim, start, end, step) - return gm.graph.call_function( - torch.ops.aten.slice.Tensor, - args=(tensor_node, 0, lo, hi, 1), +def _slice_expert_dim( + gm: GraphModule, tensor_node_or_tensor: Union[Node, torch.Tensor], lo: int, hi: int +) -> Union[Node, torch.Tensor]: + """Slice expert weights along dim 0 and register load hook. 
+ Args: + gm: The graph module + tensor_node_or_tensor: Either a Node (from FX graph) or a Tensor + lo: Start index for slicing + hi: End index for slicing + + Returns: + Node or Tensor depending on input type + """ + # Handle raw tensor case + if isinstance(tensor_node_or_tensor, torch.Tensor): + return tensor_node_or_tensor[lo:hi] + + # Handle Node case + tensor_node = tensor_node_or_tensor + + if tensor_node.op != "get_attr": + # If not a parameter node, just add a runtime slice node + with gm.graph.inserting_after(tensor_node): + return gm.graph.call_function( + torch.ops.aten.slice.Tensor, + args=(tensor_node, 0, lo, hi, 1), + ) + + # Get the parameter + param_key = str(tensor_node.target) + modname, _, param_name = param_key.rpartition(".") + submod = gm.get_submodule(modname) if modname else gm + full_param = getattr(submod, param_name) + + # Slice the parameter + sliced_param = full_param[lo:hi].detach().clone() + sliced_shape = sliced_param.shape + + # Define slice function for load hook + def slice_expert_tensor(t: torch.Tensor) -> torch.Tensor: + return t[lo:hi] + + # Register load hook to slice during checkpoint loading + gm._register_load_state_dict_pre_hook( + partial( + _load_hook, + f_split=slice_expert_tensor, + param_key=param_key, + param_shape=sliced_shape, + ) + ) + + # Replace the parameter with the sliced version + new_param = nn.Parameter(sliced_param, requires_grad=full_param.requires_grad) + setattr(submod, param_name, new_param) + + # Return the same node (it now points to the sliced parameter) + return tensor_node + + +def _slice_and_transpose_expert_dim( + gm: GraphModule, + tensor_node_or_tensor: Union[Node, torch.Tensor], + lo: int, + hi: int, + transpose_dim1: int, + transpose_dim2: int, + swap_gate_up: bool = False, +) -> Union[Node, torch.Tensor]: + """Slice expert weights along dim 0 and transpose, with load hook registration. + + This is specifically for converting Llama4 MoE weight format to TRT-LLM format. + Llama4 stores weights as (E, H, 2*I) and (E, I, H), but TRT-LLM expects (E, 2*I, H) and (E, H, I). + For gate_up weights, Llama4 has [W1, W3] but TRT-LLM expects [W3, W1], so we swap them. 
+ + Args: + gm: The graph module + tensor_node_or_tensor: Either a Node (from FX graph) or a Tensor + lo: Start index for slicing along expert dimension + hi: End index for slicing along expert dimension + transpose_dim1: First dimension to transpose + transpose_dim2: Second dimension to transpose + swap_gate_up: If True, swap W1 and W3 in the fused gate_up dimension + + Returns: + Node or Tensor depending on input type, sliced and transposed + """ + # Handle raw tensor case + if isinstance(tensor_node_or_tensor, torch.Tensor): + sliced = tensor_node_or_tensor[lo:hi] + if swap_gate_up: + # For gate_up: (E_shard, H, 2*I) -> swap last dim -> transpose + intermediate_size = sliced.shape[2] // 2 + w1 = sliced[:, :, :intermediate_size] + w3 = sliced[:, :, intermediate_size:] + sliced = torch.cat([w3, w1], dim=2) + return sliced.transpose(transpose_dim1, transpose_dim2).contiguous() + + # Handle Node case + tensor_node = tensor_node_or_tensor + + if tensor_node.op != "get_attr": + # If not a parameter node, add runtime operations + with gm.graph.inserting_after(tensor_node): + sliced_node = gm.graph.call_function( + torch.ops.aten.slice.Tensor, + args=(tensor_node, 0, lo, hi, 1), + ) + if swap_gate_up: + # Would need complex graph operations, skip for now + ad_logger.warning(f"Cannot swap gate_up for non-parameter node {tensor_node}") + return gm.graph.call_function( + torch.ops.aten.transpose.int, + args=(sliced_node, transpose_dim1, transpose_dim2), + ) + + # Get the parameter + param_key = str(tensor_node.target) + modname, _, param_name = param_key.rpartition(".") + submod = gm.get_submodule(modname) if modname else gm + full_param = getattr(submod, param_name) + + # Slice the parameter + sliced_param = full_param[lo:hi].detach().clone() + + # Swap W1 and W3 if needed (for gate_up weights) + if swap_gate_up: + # Llama4: (E, H, 2*I) with [W1, W3], TRT-LLM wants [W3, W1] + intermediate_size = sliced_param.shape[2] // 2 + w1 = sliced_param[:, :, :intermediate_size] + w3 = sliced_param[:, :, intermediate_size:] + sliced_param = torch.cat([w3, w1], dim=2) + + # Transpose the parameter + transposed_param = sliced_param.transpose(transpose_dim1, transpose_dim2).contiguous() + transposed_shape = transposed_param.shape + + # Define slice+swap+transpose function for load hook + def slice_swap_and_transpose(t: torch.Tensor) -> torch.Tensor: + t_sliced = t[lo:hi] + if swap_gate_up: + intermediate_size = t_sliced.shape[2] // 2 + w1 = t_sliced[:, :, :intermediate_size] + w3 = t_sliced[:, :, intermediate_size:] + t_sliced = torch.cat([w3, w1], dim=2) + return t_sliced.transpose(transpose_dim1, transpose_dim2).contiguous() + + # Register load hook to slice+swap+transpose during checkpoint loading + gm._register_load_state_dict_pre_hook( + partial( + _load_hook, + f_split=slice_swap_and_transpose, + param_key=param_key, + param_shape=transposed_shape, + ) + ) + + # Replace the parameter with the sliced+transposed version + new_param = nn.Parameter(transposed_param, requires_grad=full_param.requires_grad) + setattr(submod, param_name, new_param) + + # Return the same node (it now points to the transformed parameter) + return tensor_node + + +def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int: + """Helper to get the first dimension size of an argument (Node or Tensor).""" + if isinstance(arg, torch.Tensor): + return arg.shape[0] + if isinstance(arg, Node): + if arg.op == "get_attr": + # Traverse attributes to find the tensor + obj = gm + for atom in arg.target.split("."): + obj = 
getattr(obj, atom) + return obj.shape[0] + if "val" in arg.meta: + return arg.meta["val"].shape[0] + raise ValueError(f"Cannot determine shape[0] for {arg}") + + +def _insert_sharded_moe_bmm( + gm: GraphModule, + node: Node, + rank: int, + world_size: int, + allreduce_strategy: AllReduceStrategy, + scale_names: Sequence[str] = (), +): + """Update the torch_moe_bmm node with sliced weight tensors, + sharded `selected_experts` and `final_scales(router_logics)`. + Add an all_reduce node after the moe node. + Similar to _insert_sharded_moe but for BMM-based (stacked) tensors instead of lists. + + NOTE: allreduce_strategy is MANDATORY and must be explicitly provided. + """ + if allreduce_strategy is None: + raise ValueError(f"allreduce_strategy must be set for MoE BMM sharding on node {node.name}") + scale_names = list(scale_names) + + # Get num_experts from the first weight tensor (w3_w1_stacked) + w3_w1_tensor_arg = node.args[3] + num_experts = _get_dim0_from_arg(gm, w3_w1_tensor_arg) + + args = list(node.args) + + # -- Handle selected_experts and final_scales sharding -- + selected_experts = args[1] + final_scales = args[2] + + experts_per_rank = num_experts // world_size + + with gm.graph.inserting_before(node): + lower = experts_per_rank * rank + # selected_experts_local = selected_experts - low + selected_experts_local = gm.graph.create_node( + "call_function", operator.sub, args=(selected_experts, lower), kwargs={} + ) + + # For num_experts % world_size != 0 case, + # assign the last (num_experts % world_size) experts to the last rank + div_node = gm.graph.create_node( + "call_function", operator.floordiv, args=(selected_experts, experts_per_rank), kwargs={} + ) + + comp_op = torch.ge if rank == world_size - 1 else torch.eq + rank_mask = gm.graph.create_node("call_function", comp_op, args=(div_node, rank), kwargs={}) + + # final_scales_local = final_scales * rank_mask + final_scales_local = gm.graph.create_node( + "call_function", operator.mul, args=(final_scales, rank_mask), kwargs={} + ) + + # -- Slice expert weights along dim 0 (expert dimension) -- + local_lo, local_hi = _split_range_last_remainder(num_experts, world_size, rank) + + # -- Update args -- + args[1] = selected_experts_local + args[2] = final_scales_local + # Slice, swap [W1,W3]->[W3,W1], and transpose for TRT-LLM format: Llama4 (E, H, 2*I) -> TRT-LLM (E, 2*I, H) + args[3] = _slice_and_transpose_expert_dim( + gm, args[3], local_lo, local_hi, 1, 2, swap_gate_up=True + ) # w3_w1_stacked + # Slice and transpose for TRT-LLM format: Llama4 (E, I, H) -> TRT-LLM (E, H, I) + args[4] = _slice_and_transpose_expert_dim( + gm, args[4], local_lo, local_hi, 1, 2, swap_gate_up=False + ) # w2_stacked + + ad_logger.debug( + f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}." 
+ ) + + node.args = tuple(args) + + # -- add an all_reduce node -- + with gm.graph.inserting_after(node): + dist_node = gm.graph.call_function( + torch.ops.auto_deploy.torch_dist_all_reduce.default, + args=(node, allreduce_strategy.name), ) + node.replace_all_uses_with(dist_node) + dist_node.replace_input_with(dist_node, node) def _split_range_last_remainder(n: int, world_size: int, rank: int): @@ -1240,10 +1505,24 @@ def apply(self, gm: GraphModule, node: Node) -> None: ) +class BmmEPShardingInfo(EPShardingInfo): + """EP sharding behavior for torch_moe_bmm (BMM-based MoE with stacked weights).""" + + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: + if not is_op(node, torch.ops.auto_deploy.torch_moe_bmm): + ad_logger.warning(f"EP sharding is only supported for MOE nodes. Skipping {self}.") + return False + return True + + def apply(self, gm: GraphModule, node: Node) -> None: + _insert_sharded_moe_bmm(gm, node, self.rank, self.world_size, self.allreduce_strategy) + + EP_SHARDING_RULES = [ (lambda n: is_op(n, torch.ops.auto_deploy.torch_quant_fp8_moe), FP8EPShardingInfo), (lambda n: is_op(n, torch.ops.auto_deploy.torch_quant_nvfp4_moe), NVFP4EPShardingInfo), (lambda n: is_op(n, torch.ops.auto_deploy.torch_moe), EPShardingInfo), + (lambda n: is_op(n, torch.ops.auto_deploy.torch_moe_bmm), BmmEPShardingInfo), (lambda n: is_op(n, torch.ops.auto_deploy.triton_mxfp4_moe), MXFP4EPShardingInfo), ] @@ -1301,6 +1580,7 @@ class ShardingTransformContainer(BaseModel): parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list) bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) ep_transforms: List[EPShardingInfo] = Field(default_factory=list) + stacked_ep_transforms: List[BmmEPShardingInfo] = Field(default_factory=list) @field_validator("allreduce_strategy", mode="before") @classmethod @@ -1315,6 +1595,7 @@ def __init__(self, **kwargs): BMMShardingInfo: self.bmm_transforms, EPShardingInfo: self.ep_transforms, ParameterUpdateInfo: self.parameter_update_transforms, + BmmEPShardingInfo: self.stacked_ep_transforms, } def init_params( diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 6b7c79de051..171487fff7f 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -157,3 +157,45 @@ def test_sharding_pattern_detection(world_size: int, num_experts: int): rank=0, world_size=world_size, ) + + +def test_llama4_stacked_moe_pattern_detection(): + """Minimal test: verify torch_moe_bmm is detected for EP sharding.""" + # Create a simple graph with torch_moe_bmm node + gm = torch.fx.GraphModule({}, torch.fx.Graph()) + graph = gm.graph + + with graph.inserting_after(): + x = graph.placeholder("x") + selected_experts = graph.placeholder("selected_experts") + routing_weights = graph.placeholder("routing_weights") + w3_w1 = graph.placeholder("w3_w1_stacked") + w2 = graph.placeholder("w2_stacked") + + moe_node = graph.call_function( + torch.ops.auto_deploy.torch_moe_bmm, + args=(x, selected_experts, routing_weights, w3_w1, w2), + ) + graph.output(moe_node) + + graph.lint() + gm.recompile() + + # Run pattern detection for EP + optimizer = InferenceOptimizer( + None, + { + "detect_sharding": { + "stage": "sharding", + "use_sharding_from_factory": False, + }, + }, 
+ ) + optimizer.shared_config.local_rank = 0 + optimizer.shared_config.world_size = 2 + _ = optimizer(None, gm) + + # Verify torch_moe_bmm is detected for EP sharding + detected = optimizer.shared_config.sharding_config.stacked_ep_transforms + assert len(detected) == 1, f"Expected 1 EP transform, got {len(detected)}" + assert detected[0].target_node == moe_node.name diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py index da99e08bbe2..9d398588e10 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py @@ -1,7 +1,7 @@ import pytest import torch import torch.nn.functional as F -from _torch.helpers import reference_moe_torch +from _torch.helpers import reference_bmm_moe_torch, reference_moe_torch from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 @@ -62,6 +62,50 @@ def setup_moe_test(dtype, num_experts): ) +def setup_bmm_moe_test(dtype, num_experts): + """Setup for stacked MoE (torch_moe_bmm) with topk=1 in TRT-LLM format.""" + SEQ_LEN = 8 + HIDDEN_SIZE = 64 + INTERMEDIATE_SIZE = 32 + NUM_EXPERTS = num_experts + TOP_K = 1 # Llama4 stacked pattern requires topk=1 + + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + x = torch.rand(SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1 + + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=torch.float32).cuda() + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + final_scales, selected_experts = torch.topk(routing_weights, TOP_K, dim=-1) + final_scales = final_scales / final_scales.sum(dim=-1, keepdim=True) + final_scales = final_scales.to(x.dtype) + + # TRT-LLM format: gate_up is (2*I, H), down is (H, I) + w3_w1_stacked_weight = torch.empty( + (NUM_EXPERTS, INTERMEDIATE_SIZE * 2, HIDDEN_SIZE), dtype=dtype + ).cuda() + w2_stacked_weight = torch.empty( + (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE), dtype=dtype + ).cuda() + + for expert_id in range(NUM_EXPERTS): + w1 = torch.rand(INTERMEDIATE_SIZE, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1 + w2 = torch.rand(HIDDEN_SIZE, INTERMEDIATE_SIZE, dtype=dtype).cuda() * 0.1 + w3 = torch.rand(INTERMEDIATE_SIZE, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1 + + # TRT-LLM format: concat w3 and w1 along intermediate dim + w3_w1_stacked_weight.data[expert_id].copy_(torch.cat([w3, w1], dim=0)) # (2*I, H) + w2_stacked_weight.data[expert_id].copy_(w2) # (H, I) + + return ( + x, + selected_experts, + final_scales, + w3_w1_stacked_weight, + w2_stacked_weight, + ) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_moe_op_run(dtype): num_experts = 3 @@ -109,6 +153,57 @@ def test_moe_op_run(dtype): torch.testing.assert_close(output_torch_moe, ref_output, rtol=1e-5, atol=1e-5) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_bmm_based_moe_op_run(dtype): + num_experts = 3 + ( + x, + selected_experts, + final_scales, + fused_w3_w1_stacked_weight, + fused_w2_weight, + ) = setup_bmm_moe_test(dtype, num_experts) + + with torch.inference_mode(): + x = final_scales * x + selected_experts = torch.ones_like(selected_experts) + output_torch_moe = torch.ops.auto_deploy.torch_moe_bmm( + x, + selected_experts, + final_scales, + fused_w3_w1_stacked_weight, + fused_w2_weight, + ) + output_torch_fused_moe = 
torch.ops.auto_deploy.torch_moe_fused( + x, + selected_experts, + final_scales, + fused_w3_w1_stacked_weight, + fused_w2_weight, + ) + output_trt_fused_moe = torch.ops.auto_deploy.trtllm_moe_fused( + x, + selected_experts, + final_scales, + fused_w3_w1_stacked_weight, + fused_w2_weight, + ) + ref_output = reference_bmm_moe_torch( + x, + selected_experts, + final_scales, + fused_w3_w1_stacked_weight, + fused_w2_weight, + apply_routing_on_input=True, + ) + + torch.cuda.synchronize() + torch.testing.assert_close(output_trt_fused_moe, output_torch_fused_moe, rtol=5e-2, atol=5e-2) + torch.testing.assert_close(output_trt_fused_moe, ref_output, rtol=5e-2, atol=5e-2) + torch.testing.assert_close(output_torch_fused_moe, ref_output, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(output_torch_moe, ref_output, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support") def test_fp8_moe_op_run(dtype): diff --git a/tests/unittest/_torch/helpers.py b/tests/unittest/_torch/helpers.py index 163309cb1f4..e5dd67effa5 100644 --- a/tests/unittest/_torch/helpers.py +++ b/tests/unittest/_torch/helpers.py @@ -93,9 +93,12 @@ def calc_woq_tolerence(x: torch.Tensor, weight_dtype: torch.dtype): return atol -def reference_moe_torch(x: torch.Tensor, selected_experts: torch.Tensor, - final_scales: torch.Tensor, num_experts: int, - weights: Dict[str, torch.Tensor]) -> torch.Tensor: +def reference_moe_torch(x: torch.Tensor, + selected_experts: torch.Tensor, + final_scales: torch.Tensor, + num_experts: int, + weights: Dict[str, torch.Tensor], + apply_routing_on_input: bool = False) -> torch.Tensor: # cast back to the input dtype results = torch.zeros_like(x) @@ -106,9 +109,63 @@ def reference_moe_torch(x: torch.Tensor, selected_experts: torch.Tensor, w2_weight = weights[f"{expert_id}.w2.weight"] w3_weight = weights[f"{expert_id}.w3.weight"] expert_inputs = x[batch_idx] + if apply_routing_on_input: + expert_inputs = expert_inputs * final_scales[batch_idx, nth_expert, + None] output = (F.silu(expert_inputs @ w1_weight.t()) * (expert_inputs @ w3_weight.t())) @ w2_weight.t() - results[batch_idx] += final_scales[batch_idx, nth_expert, None] * output + if not apply_routing_on_input: + output = output * final_scales[batch_idx, nth_expert, None] + results[batch_idx] += output + + return results.view_as(x) + + +def reference_bmm_moe_torch( + x: torch.Tensor, + selected_experts: torch.Tensor, + final_scales: torch.Tensor, + w3_w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, + apply_routing_on_input: bool = True) -> torch.Tensor: + """Reference for stacked MoE (torch_moe_bmm) in TRT-LLM format. 
+ + Args: + x: (seq_len, hidden_size) + selected_experts: (seq_len, topk) + final_scales: (seq_len, topk) + w3_w1_stacked_weight: (num_experts, 2*intermediate_size, hidden_size) - TRT-LLM format + w2_stacked_weight: (num_experts, hidden_size, intermediate_size) - TRT-LLM format + """ + num_experts = w3_w1_stacked_weight.shape[0] + intermediate_size = w3_w1_stacked_weight.shape[1] // 2 + results = torch.zeros_like(x) + + # Loop over experts (matches reference_moe_torch pattern) + for expert_id in range(num_experts): + batch_idx, nth_expert = torch.where(selected_experts == expert_id) + if len(batch_idx) == 0: + continue + + # Get weights for this expert (TRT-LLM format) + gate_up = w3_w1_stacked_weight[expert_id] # (2*I, H) + w3_weight = gate_up[:intermediate_size, :] # (I, H) + w1_weight = gate_up[intermediate_size:, :] # (I, H) + w2_weight = w2_stacked_weight[expert_id] # (H, I) + + expert_inputs = x[batch_idx] + if apply_routing_on_input: + expert_inputs = expert_inputs * final_scales[batch_idx, nth_expert, + None] + + # Gated MLP computation - TRT-LLM format uses .t() + output = (F.silu(expert_inputs @ w1_weight.t()) * + (expert_inputs @ w3_weight.t())) @ w2_weight.t() + + if not apply_routing_on_input: + output = output * final_scales[batch_idx, nth_expert, None] + + results[batch_idx] += output return results.view_as(x) From aef538c3b2dd0fe682fc86d8e36ece00f13b68c5 Mon Sep 17 00:00:00 2001 From: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> Date: Mon, 1 Dec 2025 00:25:13 -0800 Subject: [PATCH 2/5] fix(auto_deploy): Fix Llama4 MoE format detection and graph node ordering Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> --- .../transform/library/fused_moe.py | 47 ++++--- .../auto_deploy/utils/sharding_utils.py | 124 +++++++----------- 2 files changed, 77 insertions(+), 94 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 48f9860eaf1..76802ea551c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -70,14 +70,15 @@ def _process_regular_moe_node( # Register the stacked weights as parameters param_w_up = torch.nn.Parameter(fused_w_up_experts) gm.register_parameter(new_key_w_up, param_w_up) - w_up_arg = graph.get_attr(new_key_w_up) param_w_down = torch.nn.Parameter(fused_w_down_experts) gm.register_parameter(new_key_w_down, param_w_down) - w_down_arg = graph.get_attr(new_key_w_down) # Create fused MoE node - kernel applies routing to output with graph.inserting_before(node): + w_up_arg = graph.get_attr(new_key_w_up) + w_down_arg = graph.get_attr(new_key_w_down) + new_node = graph.call_function( replacement_op, args=(hidden_states, selected_experts, routing_weights, w_up_arg, w_down_arg), @@ -122,12 +123,16 @@ def _process_llama4_stacked_moe_node( gate_up_weight = gm.get_parameter(w3_w1_stacked.target) down_weight = gm.get_parameter(w2_stacked.target) - # Detect format: Llama4 has gate_up[0] == down[1] - is_llama4 = gate_up_weight.shape[1] == down_weight.shape[0] + # Detect format: + # - Llama4: gate_up is (E, H, 2*I) and down is (E, I, H) + # - TRT-LLM: gate_up is (E, 2*I, H) and down is (E, H, I) + # If both have H in middle dimension, they're Llama4 format + is_llama4 = gate_up_weight.shape[1] == down_weight.shape[2] if is_llama4: - # Convert Llama4 (H, 2*I) -> TRT-LLM (2*I, H) + # Convert Llama4 (E, H, 2*I) -> TRT-LLM (E, 2*I, H) 
gate_up_trtllm = gate_up_weight.transpose(1, 2).contiguous() + # Convert Llama4 (E, I, H) -> TRT-LLM (E, H, I) down_trtllm = down_weight.transpose(1, 2).contiguous() # Register converted weights @@ -137,31 +142,39 @@ def _process_llama4_stacked_moe_node( gm.register_parameter(new_key_w_up, torch.nn.Parameter(gate_up_trtllm)) gm.register_parameter(new_key_w_down, torch.nn.Parameter(down_trtllm)) - w_up_arg = graph.get_attr(new_key_w_up) - w_down_arg = graph.get_attr(new_key_w_down) + # Store keys to create get_attr nodes later in insertion context + needs_get_attr = True + w_up_key = new_key_w_up + w_down_key = new_key_w_down else: # Already TRT-LLM format, use directly + needs_get_attr = False w_up_arg = w3_w1_stacked w_down_arg = w2_stacked else: # Not get_attr nodes (might be intermediate ops), use directly + needs_get_attr = False w_up_arg = w3_w1_stacked w_down_arg = w2_stacked - # Get weight dtype to ensure dtype consistency for Llama4 stacked tensors - # The fused kernel requires input and weights to have matching dtypes - weight_dtype = None - if w_up_arg.op == "get_attr": - try: - weight_tensor = gm.get_parameter(w_up_arg.target) - weight_dtype = weight_tensor.dtype - except (AttributeError, KeyError): - pass - # Llama4 INPUT-SIDE routing: apply routing to INPUT before kernel # Cast BOTH input and routing_weights to weight dtype if needed # Critical: BFloat16 * Float32 → Float32 (type promotion) so we cast both to same dtype with graph.inserting_before(node): + # Create get_attr nodes INSIDE insertion context for proper topological ordering + if needs_get_attr: + w_up_arg = graph.get_attr(w_up_key) + w_down_arg = graph.get_attr(w_down_key) + + # Get weight dtype to ensure dtype consistency for Llama4 stacked tensors + # The fused kernel requires input and weights to have matching dtypes + weight_dtype = None + if w_up_arg.op == "get_attr": + try: + weight_tensor = gm.get_parameter(w_up_arg.target) + weight_dtype = weight_tensor.dtype + except (AttributeError, KeyError): + pass input_to_scale = hidden_states routing_to_scale = routing_weights diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index affb138fbb7..dce7f2e46dc 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -1094,9 +1094,16 @@ def get_partition(lst, world_size, rank): def _slice_expert_dim( - gm: GraphModule, tensor_node_or_tensor: Union[Node, torch.Tensor], lo: int, hi: int + gm: GraphModule, + tensor_node_or_tensor: Union[Node, torch.Tensor], + lo: int, + hi: int, ) -> Union[Node, torch.Tensor]: - """Slice expert weights along dim 0 and register load hook. + """Slice expert weights along dim 0 and register load hook (simple version). + + This is the original simple slicing function used by MXFP4 EP sharding. + For parameters, it modifies them in-place and returns the same node. 
+ Args: gm: The graph module tensor_node_or_tensor: Either a Node (from FX graph) or a Tensor @@ -1114,7 +1121,7 @@ def _slice_expert_dim( tensor_node = tensor_node_or_tensor if tensor_node.op != "get_attr": - # If not a parameter node, just add a runtime slice node + # If not a parameter node, just add a runtime slice node after it with gm.graph.inserting_after(tensor_node): return gm.graph.call_function( torch.ops.aten.slice.Tensor, @@ -1153,110 +1160,72 @@ def slice_expert_tensor(t: torch.Tensor) -> torch.Tensor: return tensor_node -def _slice_and_transpose_expert_dim( +def _transform_bmm_moe_weight_param( gm: GraphModule, - tensor_node_or_tensor: Union[Node, torch.Tensor], + param_node: Node, lo: int, hi: int, - transpose_dim1: int, - transpose_dim2: int, swap_gate_up: bool = False, -) -> Union[Node, torch.Tensor]: - """Slice expert weights along dim 0 and transpose, with load hook registration. +) -> None: + """Transform a parameter for BMM MoE: slice experts, optionally swap gate/up, transpose. - This is specifically for converting Llama4 MoE weight format to TRT-LLM format. - Llama4 stores weights as (E, H, 2*I) and (E, I, H), but TRT-LLM expects (E, 2*I, H) and (E, H, I). - For gate_up weights, Llama4 has [W1, W3] but TRT-LLM expects [W3, W1], so we swap them. + This modifies the parameter in-place and registers a load hook. + Does NOT create graph nodes - those should be created separately by the caller. Args: - gm: The graph module - tensor_node_or_tensor: Either a Node (from FX graph) or a Tensor - lo: Start index for slicing along expert dimension - hi: End index for slicing along expert dimension - transpose_dim1: First dimension to transpose - transpose_dim2: Second dimension to transpose - swap_gate_up: If True, swap W1 and W3 in the fused gate_up dimension - - Returns: - Node or Tensor depending on input type, sliced and transposed + gm: Graph module + param_node: The get_attr node for the parameter + lo: Start index for expert slicing + hi: End index for expert slicing + swap_gate_up: If True, swap W1 and W3 (Llama4 -> TRT-LLM format) """ - # Handle raw tensor case - if isinstance(tensor_node_or_tensor, torch.Tensor): - sliced = tensor_node_or_tensor[lo:hi] - if swap_gate_up: - # For gate_up: (E_shard, H, 2*I) -> swap last dim -> transpose - intermediate_size = sliced.shape[2] // 2 - w1 = sliced[:, :, :intermediate_size] - w3 = sliced[:, :, intermediate_size:] - sliced = torch.cat([w3, w1], dim=2) - return sliced.transpose(transpose_dim1, transpose_dim2).contiguous() - - # Handle Node case - tensor_node = tensor_node_or_tensor + if param_node.op != "get_attr": + return # Only works on parameters - if tensor_node.op != "get_attr": - # If not a parameter node, add runtime operations - with gm.graph.inserting_after(tensor_node): - sliced_node = gm.graph.call_function( - torch.ops.aten.slice.Tensor, - args=(tensor_node, 0, lo, hi, 1), - ) - if swap_gate_up: - # Would need complex graph operations, skip for now - ad_logger.warning(f"Cannot swap gate_up for non-parameter node {tensor_node}") - return gm.graph.call_function( - torch.ops.aten.transpose.int, - args=(sliced_node, transpose_dim1, transpose_dim2), - ) - - # Get the parameter - param_key = str(tensor_node.target) + param_key = str(param_node.target) modname, _, param_name = param_key.rpartition(".") submod = gm.get_submodule(modname) if modname else gm full_param = getattr(submod, param_name) - # Slice the parameter + # Slice the parameter along expert dimension (dim 0) sliced_param = full_param[lo:hi].detach().clone() 
# Swap W1 and W3 if needed (for gate_up weights) - if swap_gate_up: - # Llama4: (E, H, 2*I) with [W1, W3], TRT-LLM wants [W3, W1] + # Llama4: (E, H, 2*I) with [W1, W3], TRT-LLM wants [W3, W1] + if swap_gate_up and sliced_param.ndim == 3: intermediate_size = sliced_param.shape[2] // 2 w1 = sliced_param[:, :, :intermediate_size] w3 = sliced_param[:, :, intermediate_size:] sliced_param = torch.cat([w3, w1], dim=2) - # Transpose the parameter - transposed_param = sliced_param.transpose(transpose_dim1, transpose_dim2).contiguous() + # Transpose: Llama4 (E, H, X) -> TRT-LLM (E, X, H) + transposed_param = sliced_param.transpose(1, 2).contiguous() transposed_shape = transposed_param.shape - # Define slice+swap+transpose function for load hook - def slice_swap_and_transpose(t: torch.Tensor) -> torch.Tensor: + # Define transformation function for load hook + def transform_tensor(t: torch.Tensor) -> torch.Tensor: t_sliced = t[lo:hi] - if swap_gate_up: + if swap_gate_up and t_sliced.ndim == 3: intermediate_size = t_sliced.shape[2] // 2 w1 = t_sliced[:, :, :intermediate_size] w3 = t_sliced[:, :, intermediate_size:] t_sliced = torch.cat([w3, w1], dim=2) - return t_sliced.transpose(transpose_dim1, transpose_dim2).contiguous() + return t_sliced.transpose(1, 2).contiguous() - # Register load hook to slice+swap+transpose during checkpoint loading + # Register load hook gm._register_load_state_dict_pre_hook( partial( _load_hook, - f_split=slice_swap_and_transpose, + f_split=transform_tensor, param_key=param_key, param_shape=transposed_shape, ) ) - # Replace the parameter with the sliced+transposed version + # Replace the parameter with the transformed version new_param = nn.Parameter(transposed_param, requires_grad=full_param.requires_grad) setattr(submod, param_name, new_param) - # Return the same node (it now points to the transformed parameter) - return tensor_node - def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int: """Helper to get the first dimension size of an argument (Node or Tensor).""" @@ -1326,20 +1295,21 @@ def _insert_sharded_moe_bmm( "call_function", operator.mul, args=(final_scales, rank_mask), kwargs={} ) - # -- Slice expert weights along dim 0 (expert dimension) -- + # -- Transform expert weight parameters -- local_lo, local_hi = _split_range_last_remainder(num_experts, world_size, rank) - # -- Update args -- + # Transform w3_w1_stacked: slice experts, swap [W1,W3]->[W3,W1], transpose (E,H,2I)->(E,2I,H) + if isinstance(args[3], Node): + _transform_bmm_moe_weight_param(gm, args[3], local_lo, local_hi, swap_gate_up=True) + + # Transform w2_stacked: slice experts, transpose (E,I,H)->(E,H,I) + if isinstance(args[4], Node): + _transform_bmm_moe_weight_param(gm, args[4], local_lo, local_hi, swap_gate_up=False) + + # -- Update args (keep same nodes, just with transformed parameters) -- args[1] = selected_experts_local args[2] = final_scales_local - # Slice, swap [W1,W3]->[W3,W1], and transpose for TRT-LLM format: Llama4 (E, H, 2*I) -> TRT-LLM (E, 2*I, H) - args[3] = _slice_and_transpose_expert_dim( - gm, args[3], local_lo, local_hi, 1, 2, swap_gate_up=True - ) # w3_w1_stacked - # Slice and transpose for TRT-LLM format: Llama4 (E, I, H) -> TRT-LLM (E, H, I) - args[4] = _slice_and_transpose_expert_dim( - gm, args[4], local_lo, local_hi, 1, 2, swap_gate_up=False - ) # w2_stacked + # args[3] and args[4] stay the same - we modified the parameters in-place ad_logger.debug( f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}." 
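
For readers following the weight-layout bookkeeping in the two patches above: the sketch below is an illustrative, standalone example (separate from the patch series itself; the expert/hidden/intermediate sizes are made up). It mirrors the shape check used for Llama4-vs-TRT-LLM format detection and the swap-then-transpose conversion that `_transform_bmm_moe_weight_param` performs on each expert shard.

    import torch

    E, H, I = 4, 8, 16  # experts, hidden size, intermediate size (illustrative only)

    # Llama4 layout: gate/up fused on the last dim as [W1, W3], down stored as (E, I, H)
    gate_up_llama4 = torch.randn(E, H, 2 * I)
    down_llama4 = torch.randn(E, I, H)

    # Format detection from the transform: Llama4 iff gate_up.shape[1] == down.shape[2] (both equal H)
    assert gate_up_llama4.shape[1] == down_llama4.shape[2]

    # Swap [W1, W3] -> [W3, W1], then transpose to the TRT-LLM layouts (E, 2*I, H) and (E, H, I)
    w1, w3 = gate_up_llama4[..., :I], gate_up_llama4[..., I:]
    gate_up_trtllm = torch.cat([w3, w1], dim=2).transpose(1, 2).contiguous()
    down_trtllm = down_llama4.transpose(1, 2).contiguous()

    assert gate_up_trtllm.shape == (E, 2 * I, H)
    assert down_trtllm.shape == (E, H, I)

The same per-expert slicing plus this conversion is what the load hook replays at checkpoint-load time, so the stored parameter and the loaded weights always end up in the TRT-LLM layout.
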
From 440e8dcf2ce4caa255b45ed97127713c2130e970 Mon Sep 17 00:00:00 2001 From: tcherckez-nvidia <127761168+tcherckez-nvidia@users.noreply.github.com> Date: Mon, 1 Dec 2025 16:42:16 +0200 Subject: [PATCH 3/5] Update tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py Co-authored-by: Neta Zmora Signed-off-by: tcherckez-nvidia <127761168+tcherckez-nvidia@users.noreply.github.com> --- .../auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py index 9d398588e10..a1aaa47f8e9 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py @@ -89,12 +89,10 @@ def setup_bmm_moe_test(dtype, num_experts): ).cuda() for expert_id in range(NUM_EXPERTS): - w1 = torch.rand(INTERMEDIATE_SIZE, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1 + w31 = torch.rand(INTERMEDIATE_SIZE * 2, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1 w2 = torch.rand(HIDDEN_SIZE, INTERMEDIATE_SIZE, dtype=dtype).cuda() * 0.1 - w3 = torch.rand(INTERMEDIATE_SIZE, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1 - # TRT-LLM format: concat w3 and w1 along intermediate dim - w3_w1_stacked_weight.data[expert_id].copy_(torch.cat([w3, w1], dim=0)) # (2*I, H) + w3_w1_stacked_weight.data[expert_id].copy_(w31, dim=0)) # (2*I, H) w2_stacked_weight.data[expert_id].copy_(w2) # (H, I) return ( From 607a862ee2e107dd2e3022fa8bf8ba64eefbd21d Mon Sep 17 00:00:00 2001 From: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> Date: Mon, 1 Dec 2025 06:48:59 -0800 Subject: [PATCH 4/5] set requires_grad to False, remove contiguous Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index dce7f2e46dc..ff2a3f87f17 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -254,7 +254,7 @@ def split_fused_tensor( if update_param: modname, _, param_name = param_key.rpartition(".") submod = gm.get_submodule(modname) - param_new = nn.Parameter(sharded_weight.detach().clone(), requires_grad=requires_grad) + param_new = nn.Parameter(sharded_weight.detach().clone(), requires_grad=False) setattr(submod, param_name, param_new) return sharded_weight, sharded_shape @@ -1153,7 +1153,7 @@ def slice_expert_tensor(t: torch.Tensor) -> torch.Tensor: ) # Replace the parameter with the sliced version - new_param = nn.Parameter(sliced_param, requires_grad=full_param.requires_grad) + new_param = nn.Parameter(sliced_param, requires_grad=False) setattr(submod, param_name, new_param) # Return the same node (it now points to the sliced parameter) @@ -1199,7 +1199,7 @@ def _transform_bmm_moe_weight_param( sliced_param = torch.cat([w3, w1], dim=2) # Transpose: Llama4 (E, H, X) -> TRT-LLM (E, X, H) - transposed_param = sliced_param.transpose(1, 2).contiguous() + transposed_param = sliced_param.transpose(1, 2) transposed_shape = transposed_param.shape # Define transformation function for load hook @@ -1223,7 +1223,7 @@ def transform_tensor(t: torch.Tensor) -> torch.Tensor: ) # Replace the 
parameter with the transformed version - new_param = nn.Parameter(transposed_param, requires_grad=full_param.requires_grad) + new_param = nn.Parameter(transposed_param, requires_grad=False) setattr(submod, param_name, new_param) From 8eb2fff3013e2751a050a45bbb930d72fdd1f493 Mon Sep 17 00:00:00 2001 From: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> Date: Tue, 2 Dec 2025 01:29:22 -0800 Subject: [PATCH 5/5] consolidate moe_bmm to moe Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> --- .../custom_ops/fused_moe/torch_moe.py | 142 ++++++-------- .../transform/library/fused_moe.py | 182 ++++++++++++------ .../auto_deploy/transform/library/sharding.py | 31 +-- .../auto_deploy/utils/sharding_utils.py | 114 ++++++++--- .../library/test_ep_sharding.py | 18 +- .../singlegpu/custom_ops/test_ad_moe_op.py | 15 +- tests/unittest/_torch/helpers.py | 2 +- 7 files changed, 290 insertions(+), 214 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py index fe0ebeeb665..f023062df15 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py @@ -68,13 +68,13 @@ def _template_moe( if apply_routing_on_input: # INPUT-SIDE routing (BMM-based pattern): multiply routing weights with INPUT - # Result: silu(input * routing_weight) - routing affects activation + # Result: silu(input * routing_weight) scaled_input = tokens_for_this_expert * routing_weights[top_x, idx, None] expert_out = mlps[expert_idx](scaled_input) current_hidden_states = expert_out else: # OUTPUT-SIDE routing (standard pattern): multiply routing weights with OUTPUT - # Result: silu(input) * routing_weight - routing applied after activation + # Result: silu(input) * routing_weight expert_out = mlps[expert_idx](tokens_for_this_expert) current_hidden_states = expert_out * routing_weights[top_x, idx, None] @@ -94,10 +94,16 @@ def torch_moe( w3_weight: List[torch.Tensor], mlp_style: str = "gated_mlp", act_fn: str = "silu", + apply_routing_on_input: bool = False, ) -> torch.Tensor: """ Unified Mixture-of-Experts (MoE) operator that uses a Mixtral-style dispatch (token routing + index_add_ accumulation) and a selectable per-expert MLP. + + Supports both: + - Standard MoE with per-expert weight lists (apply_routing_on_input=False) + - Llama4 MoE with stacked weight tensors (apply_routing_on_input=True) + Parameters: x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size, S is the sequence length, and H is the hidden size. @@ -105,35 +111,71 @@ def torch_moe( of the selected experts for each token. Only experts within range [0,num_experts) is processed routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized routing weights for the selected experts. + - Standard MoE: softmax normalized weights + - Llama4 MoE: sigmoid activated weights w1_weight: - List of per-expert weight tensors: - • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection. - • mlp_style=="mlp": W_up with shape (I, H) — up projection. + For per-expert lists: + • mlp_style=="gated_mlp": List of W1 with shape (I, H) — "gate" projection. + • mlp_style=="mlp": List of W_up with shape (I, H) — up projection. 
+ For stacked tensors (Llama4): + • Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format w2_weight: - List of per-expert weight tensors: - • gated_mlp: W2 with shape (H, I) — down projection. - • mlp: W_down with shape (H, I) — down projection. + For per-expert lists: + • List of W2/W_down with shape (H, I) — down projection. + For stacked tensors (Llama4): + • Single-element list containing stacked w2 tensor with shape (E, H, I) in TRT-LLM format w3_weight: - List of per-expert weight tensors: - • gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP. - • mlp: pass an empty list []; ignored. + For per-expert lists with gated_mlp: + • List of W3 with shape (I, H) — "up" (second) projection in gated MLP. + For mlp style or stacked tensors: + • pass an empty list []; ignored. mlp_style: Selects the per-expert MLP computation: - • "gated_mlp" (default, Mixtral/DeepSeek-style): + • "gated_mlp" (default, Mixtral/DeepSeek/Llama4-style): y = W2( act(W1 x) * (W3 x) ) • "mlp" (NemotronH-style 2-layer MLP): y = W_down( act(W_up x) ) act_fn: Elementwise activation applied inside the expert MLP. Supported: "silu" (default), "relu2" (ReLU then square). + apply_routing_on_input: + If True (Llama4 pattern): multiply routing weights with INPUT before MLP + Result: act(input * routing_weight) - routing affects activation + If False (standard pattern): multiply routing weights with OUTPUT after MLP + Result: act(input) * routing_weight - routing scales output Returns: torch.Tensor: Output tensor with the same shape as the input x. """ act_fn = _resolve_activation(act_fn) style = mlp_style.lower() - if style == "gated_mlp": + # Detect if using stacked tensor format (Llama4) vs per-expert lists (standard) + is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3 + + if is_stacked: + # Llama4 stacked tensor format - only supports gated_mlp + if style != "gated_mlp": + raise ValueError("Stacked tensor format only supports 'gated_mlp' style") + + w3_w1_stacked = w1_weight[0] # (E, 2*I, H) + w2_stacked = w2_weight[0] # (E, H, I) + def make_mlp(i: int): + gate_up = w3_w1_stacked[i] # (2*I, H) + intermediate_size = gate_up.shape[0] // 2 + W3 = gate_up[:intermediate_size, :] # (I, H) + W1 = gate_up[intermediate_size:, :] # (I, H) + W2 = w2_stacked[i] # (H, I) + weight_dtype = W1.dtype + return lambda inp: F.linear( + act_fn(F.linear(inp.to(weight_dtype), W1)) * F.linear(inp.to(weight_dtype), W3), + W2, + ) + + mlps = [make_mlp(i) for i in range(w3_w1_stacked.shape[0])] + + elif style == "gated_mlp": + # Standard per-expert list format with gated MLP def make_mlp(i: int): W1 = w1_weight[i] # (I, H) W2 = w2_weight[i] # (H, I) @@ -143,7 +185,7 @@ def make_mlp(i: int): mlps = [make_mlp(i) for i in range(len(w1_weight))] elif style == "mlp": - + # Standard per-expert list format with simple MLP def make_mlp(i: int): W_up = w1_weight[i] # (I, H) W_down = w2_weight[i] # (H, I) @@ -154,7 +196,7 @@ def make_mlp(i: int): else: raise ValueError(f"Unknown mlp_style '{mlp_style}'. 
Use 'gated_mlp' or 'mlp'.") - return _template_moe(x, selected_experts, routing_weights, mlps) + return _template_moe(x, selected_experts, routing_weights, mlps, apply_routing_on_input) @torch_moe.register_fake @@ -167,75 +209,7 @@ def torch_moe_fake( w3_weight: List[torch.Tensor], mlp_style: str = "gated_mlp", act_fn: str = "silu", -) -> torch.Tensor: - return torch.empty_like(x) - - -@torch.library.custom_op("auto_deploy::torch_moe_bmm", mutates_args=()) -def torch_moe_bmm( - x: torch.Tensor, - selected_experts: torch.Tensor, - routing_weights: torch.Tensor, - w3_w1_stacked: torch.Tensor, - w2_stacked: torch.Tensor, - act_fn: str = "silu", - apply_routing_on_input: bool = True, -) -> torch.Tensor: - """MoE operator using stacked weights for expert computation. - - Only supports gated MLP (SwiGLU-style) architecture. - Uses stacked expert weights in TRT-LLM format (conversion from Llama4 happens during graph transformation). - - Weights are expected in TRT-LLM format: - - w3_w1_stacked: (NUM_EXPERTS, 2*INTERMEDIATE_SIZE, HIDDEN_SIZE) - - w2_stacked: (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE) - - Args: - x: Input tensor - selected_experts: Expert indices for each token - routing_weights: Routing weights for each selected expert - w3_w1_stacked: Stacked gate/up weights (TRT-LLM format) - w2_stacked: Stacked down weights (TRT-LLM format) - act_fn: Activation function name (default: "silu") - apply_routing_on_input: If True, apply routing to INPUT (default for Llama4 pattern). - If False, apply routing to OUTPUT (alternative pattern). - """ - - act_fn = _resolve_activation(act_fn) - - # Standardized on TRT-LLM format (conversion happens during graph transformation) - # Only gated MLP supported: output = (act_fn(x @ W1) * (x @ W3)) @ W2 - def make_mlp(i: int): - gate_up = w3_w1_stacked[i] # (2*I, H) - intermediate_size = gate_up.shape[0] // 2 - W3 = gate_up[:intermediate_size, :] # (I, H) - W1 = gate_up[intermediate_size:, :] # (I, H) - W2 = w2_stacked[i] # (H, I) - weight_dtype = W1.dtype - return lambda inp: F.linear( - act_fn(F.linear(inp.to(weight_dtype), W1)) * F.linear(inp.to(weight_dtype), W3), - W2, - ) - - mlps = [make_mlp(i) for i in range(w3_w1_stacked.shape[0])] - - # Apply routing based on the specified pattern - # INPUT-SIDE (default): silu(input * routing_weight) - routing affects activation - # OUTPUT-SIDE (alternative): silu(input) * routing_weight - routing scales output - return _template_moe( - x, selected_experts, routing_weights, mlps, apply_routing_on_input=apply_routing_on_input - ) - - -@torch_moe_bmm.register_fake -def torch_moe_bmm_fake( - x: torch.Tensor, - selected_experts: torch.Tensor, - routing_weights: torch.Tensor, - w3_w1_stacked: torch.Tensor, - w2_stacked: torch.Tensor, - act_fn: str = "silu", - apply_routing_on_input: bool = True, + apply_routing_on_input: bool = False, ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 76802ea551c..af0865c183f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -18,6 +18,85 @@ ) +def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int: + """Replace torch MoE ops with fused backend-specific implementations. 
+ + Handles both: + - Standard MoE (per-expert weight lists): torch_moe with apply_routing_on_input=False + - Llama4 MoE (pre-stacked weight tensors): torch_moe with apply_routing_on_input=True + + For Llama4 stacked tensors, applies routing weights to input before the fused kernel. + """ + fused_key_counter = 0 + graph = gm.graph + backend = backend.lower() + + # Map backend to fused MoE op (handles both standard and Llama4 stacked tensor patterns) + replacement_op = { + "auto": torch.ops.auto_deploy.trtllm_moe_fused, + "trtllm": torch.ops.auto_deploy.trtllm_moe_fused, + "triton": torch.ops.auto_deploy.triton_moe_fused, + }[backend] + + for node in graph.nodes: + if not is_op(node, torch.ops.auto_deploy.torch_moe): + continue + + # Detect if this is a stacked MoE (Llama4 pattern) or per-expert list (standard pattern) + (apply_routing_val, w1_weight_list) = extract_op_args( + node, "apply_routing_on_input", "w1_weight" + ) + + # Check if it's stacked format: single-element list with 3D tensor + is_stacked_moe = False + if apply_routing_val: + # In FX graphs, w1_weight_list might be a Node representing a list() call + list_content = None + if isinstance(w1_weight_list, Node) and w1_weight_list.target is list: + # Extract from list() call node + if w1_weight_list.args: + list_content = w1_weight_list.args[0] + elif isinstance(w1_weight_list, (list, tuple)): + # Direct Python list + list_content = w1_weight_list + + # Check if it's a single-element list with a 3D tensor + if list_content is not None and len(list_content) == 1: + w1_node = list_content[0] + if isinstance(w1_node, Node) and w1_node.op == "get_attr": + try: + w1_tensor = gm.get_parameter(w1_node.target) + is_stacked_moe = w1_tensor.ndim == 3 + except (AttributeError, KeyError): + pass + + if is_stacked_moe: + # Stacked MoE (Llama4 pattern): only supports gated MLP + (act_fn_val,) = extract_op_args(node, "act_fn") + _process_llama4_stacked_moe_node( + gm, graph, node, replacement_op, act_fn_val, fused_key_counter + ) + else: + # Standard MoE with per-expert weight lists + (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn") + assert backend != "triton" or mlp_style_val == "mlp", ( + "Triton backend only supports mlp style." + ) + _process_regular_moe_node( + gm, graph, node, replacement_op, mlp_style_val, act_fn_val, fused_key_counter + ) + + fused_key_counter += 1 + + # Delete the unstacked weights immediately to save GPU memory + # This will happen automatically after the graph is canonicalized, but for large models we'll run out of memory + # during the transformation itself. + gm.graph.eliminate_dead_code() + gm.delete_all_unused_submodules() + + return fused_key_counter + + def _process_regular_moe_node( gm: GraphModule, graph: torch.fx.Graph, @@ -100,7 +179,7 @@ def _process_llama4_stacked_moe_node( act_fn_val: str, fused_key_counter: int, ) -> None: - """Process a single Llama4 MoE node with pre-stacked weight tensors (torch_moe_bmm). + """Process a single Llama4 MoE node with pre-stacked weight tensors. Only supports gated MLP (SwiGLU-style) architecture. Converts Llama4 format weights to TRT-LLM format to standardize all downstream ops. @@ -108,15 +187,31 @@ def _process_llama4_stacked_moe_node( This is the Llama4 pattern where weights are already stacked across experts. Result: silu(input * routing_weight) - routing affects activation. 
""" - hidden_states, selected_experts, routing_weights, w3_w1_stacked, w2_stacked = extract_op_args( + # torch_moe with stacked format: weights are in single-element lists + hidden_states, selected_experts, routing_weights, w1_list, w2_list = extract_op_args( node, "x", "selected_experts", "routing_weights", - "w3_w1_stacked", - "w2_stacked", + "w1_weight", + "w2_weight", ) + # Extract the single stacked tensor from each list + # Handle both FX graph Nodes (list() calls) and direct Python lists + def extract_from_list_arg(list_arg): + if isinstance(list_arg, Node) and list_arg.target is list: + # Extract from list() call node + return list_arg.args[0][0] if list_arg.args else None + elif isinstance(list_arg, (list, tuple)): + # Direct Python list + return list_arg[0] + else: + raise ValueError(f"Unexpected list format: {type(list_arg)}") + + w3_w1_stacked = extract_from_list_arg(w1_list) + w2_stacked = extract_from_list_arg(w2_list) + # Convert Llama4 format to TRT-LLM format if needed # This standardizes all downstream ops to only handle TRT-LLM format if w3_w1_stacked.op == "get_attr" and w2_stacked.op == "get_attr": @@ -214,62 +309,6 @@ def _process_llama4_stacked_moe_node( graph.erase_node(node) -def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int: - """Replace torch MoE ops with fused backend-specific implementations. - - Handles both: - - Standard MoE (per-expert weight lists): torch_moe - - Llama4 MoE (pre-stacked weight tensors): torch_moe_bmm - - For Llama4 stacked tensors, applies routing weights to input before the fused kernel. - """ - fused_key_counter = 0 - graph = gm.graph - backend = backend.lower() - - # Map backend to fused MoE op (handles both standard and Llama4 stacked tensor patterns) - replacement_op = { - "auto": torch.ops.auto_deploy.trtllm_moe_fused, - "trtllm": torch.ops.auto_deploy.trtllm_moe_fused, - "triton": torch.ops.auto_deploy.triton_moe_fused, - }[backend] - - for node in graph.nodes: - if not ( - is_op(node, torch.ops.auto_deploy.torch_moe) - or is_op(node, torch.ops.auto_deploy.torch_moe_bmm) - ): - continue - - is_llama4_stacked_moe = is_op(node, torch.ops.auto_deploy.torch_moe_bmm) - - if is_llama4_stacked_moe: - # torch_moe_bmm: only supports gated MLP, no mlp_style parameter - (act_fn_val,) = extract_op_args(node, "act_fn") - _process_llama4_stacked_moe_node( - gm, graph, node, replacement_op, act_fn_val, fused_key_counter - ) - else: - # torch_moe: supports both gated_mlp and mlp styles - (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn") - assert backend != "triton" or mlp_style_val == "mlp", ( - "Triton backend only supports mlp style." - ) - _process_regular_moe_node( - gm, graph, node, replacement_op, mlp_style_val, act_fn_val, fused_key_counter - ) - - fused_key_counter += 1 - - # Delete the unstacked weights immediately to save GPU memory - # This will happen automatically after the graph is canonicalized, but for large models we'll run out of memory - # during the transformation itself. 
- gm.graph.eliminate_dead_code() - gm.delete_all_unused_submodules() - - return fused_key_counter - - def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]: """ Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following @@ -1174,17 +1213,34 @@ def _apply( # If input_routing is False: kernel applies routing to output apply_routing_on_input = input_routing + # Wrap stacked tensors in single-element lists for torch_moe unified interface with graph.inserting_before(output_node): + # Create list nodes for stacked weights + w1_list_node = graph.call_function( + list, + args=([gate_up_weight],), + ) + w2_list_node = graph.call_function( + list, + args=([down_weight],), + ) + w3_list_node = graph.call_function( + list, + args=([],), # Empty list for stacked gated MLP + ) + fused_moe_node = graph.call_function( - torch.ops.auto_deploy.torch_moe_bmm, + torch.ops.auto_deploy.torch_moe, args=( input_hidden_states, selected_experts, routing_weights_node, - gate_up_weight, - down_weight, + w1_list_node, + w2_list_node, + w3_list_node, ), kwargs={ + "mlp_style": "gated_mlp", "apply_routing_on_input": apply_routing_on_input, }, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 486c17b1ca1..1ed5652c628 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -39,7 +39,6 @@ subgraph, ) from ...utils.sharding_utils import ( - BmmEPShardingInfo, BMMShardingInfo, EPShardingInfo, LayerType, @@ -138,9 +137,6 @@ def check_and_apply(transform: ShardingTransformInfo) -> bool: for ep_transform in transforms.ep_transforms: if check_and_apply(ep_transform): num_matches += 1 - for stacked_ep_transform in transforms.stacked_ep_transforms: - if check_and_apply(stacked_ep_transform): - num_matches += 1 # post-sharding cleanup transformations for update_transform in transforms.parameter_update_transforms: @@ -1062,31 +1058,20 @@ def detect_ep_shard(gm: GraphModule, sharding_config: ShardingTransformContainer node, ( torch.ops.auto_deploy.torch_moe, - torch.ops.auto_deploy.torch_moe_bmm, torch.ops.auto_deploy.torch_quant_fp8_moe, torch.ops.auto_deploy.torch_quant_nvfp4_moe, torch.ops.auto_deploy.triton_mxfp4_moe, ), ): continue - if is_op(node, torch.ops.auto_deploy.torch_moe_bmm): - if sharding_config.add( - BmmEPShardingInfo.from_node( - node, - rank=rank, - world_size=world_size, - ) - ): - num_moe_patterns += 1 - else: - if sharding_config.add( - EPShardingInfo.from_node( - node, - rank=rank, - world_size=world_size, - ) - ): - num_moe_patterns += 1 + if sharding_config.add( + EPShardingInfo.from_node( + node, + rank=rank, + world_size=world_size, + ) + ): + num_moe_patterns += 1 ad_logger.info(f"Found {num_moe_patterns} MoE patterns") diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index ff2a3f87f17..64698911ddb 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -254,7 +254,7 @@ def split_fused_tensor( if update_param: modname, _, param_name = param_key.rpartition(".") submod = gm.get_submodule(modname) - param_new = nn.Parameter(sharded_weight.detach().clone(), requires_grad=False) + param_new = nn.Parameter(sharded_weight.detach().clone(), requires_grad=requires_grad) setattr(submod, param_name, param_new) return sharded_weight, 
sharded_shape @@ -1008,17 +1008,66 @@ def _insert_sharded_moe( allreduce_strategy: AllReduceStrategy, scale_names: Sequence[str] = (), ): - """Update the torch_moe node with sharded weight lists, + """Update the torch_moe node with sharded weight lists or stacked tensors, sharded `selected_experts` and `final_scales(router_logics)`. Add an all_reduce node after the moe node. + Handles both: + - Standard format: per-expert weight lists + - Stacked format: single-element lists containing stacked 3D tensors (Llama4 pattern) + NOTE: allreduce_strategy is MANDATORY. """ if allreduce_strategy is None: raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}") scale_names = list(scale_names) - num_experts = len(node.args[3]) + # Detect format: check if w1_weight is a single-element list with a 3D tensor (stacked format) + w1_weight_arg = node.args[3] + is_stacked = False + + # In FX graphs, the list might be a Node representing a list() call + if isinstance(w1_weight_arg, Node): + # Check if this is a list() call node + if w1_weight_arg.target is list and len(w1_weight_arg.args) > 0: + # Get the actual list content from the args + list_content = w1_weight_arg.args[0] + if isinstance(list_content, (list, tuple)) and len(list_content) == 1: + first_elem = list_content[0] + if isinstance(first_elem, Node) and first_elem.op == "get_attr": + try: + tensor = gm.get_parameter(first_elem.target) + is_stacked = tensor.ndim == 3 + except (AttributeError, KeyError): + pass + elif isinstance(first_elem, torch.Tensor): + is_stacked = first_elem.ndim == 3 + # Handle case where it's a direct Python list (not in FX graph context) + elif isinstance(w1_weight_arg, (list, tuple)) and len(w1_weight_arg) == 1: + first_elem = w1_weight_arg[0] + if isinstance(first_elem, Node) and first_elem.op == "get_attr": + try: + tensor = gm.get_parameter(first_elem.target) + is_stacked = tensor.ndim == 3 + except (AttributeError, KeyError): + pass + elif isinstance(first_elem, torch.Tensor): + is_stacked = first_elem.ndim == 3 + + if is_stacked: + # Use stacked tensor sharding logic (similar to _insert_sharded_moe_bmm) + _insert_sharded_moe_stacked(gm, node, rank, world_size, allreduce_strategy, scale_names) + return + + # Standard per-expert list sharding + # For FX graphs, get the list from the Node; for direct calls, use the list directly + if isinstance(w1_weight_arg, Node) and w1_weight_arg.target is list: + # Extract the list content from the list() call node + num_experts = len(w1_weight_arg.args[0]) if w1_weight_arg.args else 0 + elif isinstance(w1_weight_arg, (list, tuple)): + num_experts = len(w1_weight_arg) + else: + raise ValueError(f"Unexpected w1_weight format in node {node.name}: {type(w1_weight_arg)}") args = list(node.args) # -- Handle selected_experts and final_scales sharding -- @@ -1243,7 +1292,7 @@ def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int: raise ValueError(f"Cannot determine shape[0] for {arg}") -def _insert_sharded_moe_bmm( +def _insert_sharded_moe_stacked( gm: GraphModule, node: Node, rank: int, @@ -1251,20 +1300,35 @@ def _insert_sharded_moe_bmm( allreduce_strategy: AllReduceStrategy, scale_names: Sequence[str] = (), ): - """Update the torch_moe_bmm node with sliced weight tensors, + """Update the torch_moe node with sliced stacked weight tensors, sharded `selected_experts` and `final_scales(router_logics)`. Add an all_reduce node after the moe node. - Similar to _insert_sharded_moe but for BMM-based (stacked) tensors instead of lists. 
+ + For torch_moe with stacked tensor format (single-element lists containing 3D tensors). NOTE: allreduce_strategy is MANDATORY and must be explicitly provided. """ if allreduce_strategy is None: - raise ValueError(f"allreduce_strategy must be set for MoE BMM sharding on node {node.name}") - scale_names = list(scale_names) + raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}") - # Get num_experts from the first weight tensor (w3_w1_stacked) - w3_w1_tensor_arg = node.args[3] - num_experts = _get_dim0_from_arg(gm, w3_w1_tensor_arg) + # Extract the stacked tensors from single-element lists + # args[3] = w1_weight (Node representing list with one 3D tensor, or direct list) + # args[4] = w2_weight (Node representing list with one 3D tensor, or direct list) + + # Helper to extract tensor node from list (handles both Node and direct list) + def extract_tensor_from_list_arg(list_arg): + if isinstance(list_arg, Node) and list_arg.target is list: + # It's a list() call node - extract from its args + return list_arg.args[0][0] # args[0] is the list content, [0] is first element + elif isinstance(list_arg, (list, tuple)): + # Direct list + return list_arg[0] + else: + raise ValueError(f"Unexpected list format: {type(list_arg)}") + + w3_w1_tensor_node = extract_tensor_from_list_arg(node.args[3]) + w2_tensor_node = extract_tensor_from_list_arg(node.args[4]) + num_experts = _get_dim0_from_arg(gm, w3_w1_tensor_node) args = list(node.args) @@ -1299,14 +1363,16 @@ def _insert_sharded_moe_bmm( local_lo, local_hi = _split_range_last_remainder(num_experts, world_size, rank) # Transform w3_w1_stacked: slice experts, swap [W1,W3]->[W3,W1], transpose (E,H,2I)->(E,2I,H) - if isinstance(args[3], Node): - _transform_bmm_moe_weight_param(gm, args[3], local_lo, local_hi, swap_gate_up=True) + if isinstance(w3_w1_tensor_node, Node): + _transform_bmm_moe_weight_param( + gm, w3_w1_tensor_node, local_lo, local_hi, swap_gate_up=True + ) # Transform w2_stacked: slice experts, transpose (E,I,H)->(E,H,I) - if isinstance(args[4], Node): - _transform_bmm_moe_weight_param(gm, args[4], local_lo, local_hi, swap_gate_up=False) + if isinstance(w2_tensor_node, Node): + _transform_bmm_moe_weight_param(gm, w2_tensor_node, local_lo, local_hi, swap_gate_up=False) - # -- Update args (keep same nodes, just with transformed parameters) -- + # -- Update args (keep same lists/nodes, just with transformed parameters) -- args[1] = selected_experts_local args[2] = final_scales_local # args[3] and args[4] stay the same - we modified the parameters in-place @@ -1475,24 +1541,10 @@ def apply(self, gm: GraphModule, node: Node) -> None: ) -class BmmEPShardingInfo(EPShardingInfo): - """EP sharding behavior for torch_moe_bmm (BMM-based MoE with stacked weights).""" - - def validate(self, gm: GraphModule = None, node: Node = None) -> bool: - if not is_op(node, torch.ops.auto_deploy.torch_moe_bmm): - ad_logger.warning(f"EP sharding is only supported for MOE nodes. 
Skipping {self}.") - return False - return True - - def apply(self, gm: GraphModule, node: Node) -> None: - _insert_sharded_moe_bmm(gm, node, self.rank, self.world_size, self.allreduce_strategy) - - EP_SHARDING_RULES = [ (lambda n: is_op(n, torch.ops.auto_deploy.torch_quant_fp8_moe), FP8EPShardingInfo), (lambda n: is_op(n, torch.ops.auto_deploy.torch_quant_nvfp4_moe), NVFP4EPShardingInfo), (lambda n: is_op(n, torch.ops.auto_deploy.torch_moe), EPShardingInfo), - (lambda n: is_op(n, torch.ops.auto_deploy.torch_moe_bmm), BmmEPShardingInfo), (lambda n: is_op(n, torch.ops.auto_deploy.triton_mxfp4_moe), MXFP4EPShardingInfo), ] @@ -1550,7 +1602,6 @@ class ShardingTransformContainer(BaseModel): parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list) bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) ep_transforms: List[EPShardingInfo] = Field(default_factory=list) - stacked_ep_transforms: List[BmmEPShardingInfo] = Field(default_factory=list) @field_validator("allreduce_strategy", mode="before") @classmethod @@ -1565,7 +1616,6 @@ def __init__(self, **kwargs): BMMShardingInfo: self.bmm_transforms, EPShardingInfo: self.ep_transforms, ParameterUpdateInfo: self.parameter_update_transforms, - BmmEPShardingInfo: self.stacked_ep_transforms, } def init_params( diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 171487fff7f..8112889a870 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -160,8 +160,8 @@ def test_sharding_pattern_detection(world_size: int, num_experts: int): def test_llama4_stacked_moe_pattern_detection(): - """Minimal test: verify torch_moe_bmm is detected for EP sharding.""" - # Create a simple graph with torch_moe_bmm node + """Minimal test: verify torch_moe with stacked format is detected for EP sharding.""" + # Create a simple graph with torch_moe node using stacked tensor format gm = torch.fx.GraphModule({}, torch.fx.Graph()) graph = gm.graph @@ -172,9 +172,15 @@ def test_llama4_stacked_moe_pattern_detection(): w3_w1 = graph.placeholder("w3_w1_stacked") w2 = graph.placeholder("w2_stacked") + # Create single-element lists for stacked tensor format + w1_list = graph.call_function(list, args=([w3_w1],)) + w2_list = graph.call_function(list, args=([w2],)) + w3_list = graph.call_function(list, args=([],)) + moe_node = graph.call_function( - torch.ops.auto_deploy.torch_moe_bmm, - args=(x, selected_experts, routing_weights, w3_w1, w2), + torch.ops.auto_deploy.torch_moe, + args=(x, selected_experts, routing_weights, w1_list, w2_list, w3_list), + kwargs={"mlp_style": "gated_mlp", "apply_routing_on_input": True}, ) graph.output(moe_node) @@ -195,7 +201,7 @@ def test_llama4_stacked_moe_pattern_detection(): optimizer.shared_config.world_size = 2 _ = optimizer(None, gm) - # Verify torch_moe_bmm is detected for EP sharding - detected = optimizer.shared_config.sharding_config.stacked_ep_transforms + # Verify torch_moe with stacked format is detected for EP sharding + detected = optimizer.shared_config.sharding_transform_container.ep_transforms assert len(detected) == 1, f"Expected 1 EP transform, got {len(detected)}" assert detected[0].target_node == moe_node.name diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py 
b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py index a1aaa47f8e9..99fccfab304 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py @@ -63,7 +63,7 @@ def setup_moe_test(dtype, num_experts): def setup_bmm_moe_test(dtype, num_experts): - """Setup for stacked MoE (torch_moe_bmm) with topk=1 in TRT-LLM format.""" + """Setup for stacked MoE with topk=1 in TRT-LLM format.""" SEQ_LEN = 8 HIDDEN_SIZE = 64 INTERMEDIATE_SIZE = 32 @@ -92,7 +92,7 @@ def setup_bmm_moe_test(dtype, num_experts): w31 = torch.rand(INTERMEDIATE_SIZE * 2, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1 w2 = torch.rand(HIDDEN_SIZE, INTERMEDIATE_SIZE, dtype=dtype).cuda() * 0.1 # TRT-LLM format: concat w3 and w1 along intermediate dim - w3_w1_stacked_weight.data[expert_id].copy_(w31, dim=0)) # (2*I, H) + w3_w1_stacked_weight.data[expert_id].copy_(w31) # (2*I, H) w2_stacked_weight.data[expert_id].copy_(w2) # (H, I) return ( @@ -165,12 +165,17 @@ def test_bmm_based_moe_op_run(dtype): with torch.inference_mode(): x = final_scales * x selected_experts = torch.ones_like(selected_experts) - output_torch_moe = torch.ops.auto_deploy.torch_moe_bmm( + # Use torch_moe with stacked tensor format (single-element lists) + output_torch_moe = torch.ops.auto_deploy.torch_moe( x, selected_experts, final_scales, - fused_w3_w1_stacked_weight, - fused_w2_weight, + [fused_w3_w1_stacked_weight], # Wrap in list for unified interface + [fused_w2_weight], # Wrap in list for unified interface + [], # Empty w3_weight list for stacked gated MLP + mlp_style="gated_mlp", + act_fn="silu", + apply_routing_on_input=True, ) output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused( x, diff --git a/tests/unittest/_torch/helpers.py b/tests/unittest/_torch/helpers.py index e5dd67effa5..ae7aa5b1c52 100644 --- a/tests/unittest/_torch/helpers.py +++ b/tests/unittest/_torch/helpers.py @@ -128,7 +128,7 @@ def reference_bmm_moe_torch( w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, apply_routing_on_input: bool = True) -> torch.Tensor: - """Reference for stacked MoE (torch_moe_bmm) in TRT-LLM format. + """Reference for stacked MoE in TRT-LLM format. Args: x: (seq_len, hidden_size)
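
A closing illustration of the routing-placement flag that this series threads through `torch_moe`: for a gated MLP, scaling the input (Llama4 pattern, `apply_routing_on_input=True`) is not equivalent to scaling the output (standard pattern), which is why the flag must be carried from pattern matching through sharding to the fused kernels. The sketch below is plain PyTorch with arbitrary sizes, not a call into the custom op, and only restates the formulas given in the docstrings above.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    H, I = 8, 16  # hidden and intermediate sizes, chosen arbitrarily
    x = torch.randn(1, H)
    w1, w3, w2 = torch.randn(I, H), torch.randn(I, H), torch.randn(H, I)
    routing_weight = 0.3

    def gated_mlp(inp):
        # Same gated-MLP math as the reference helpers: (silu(x @ W1^T) * (x @ W3^T)) @ W2^T
        return (F.silu(inp @ w1.t()) * (inp @ w3.t())) @ w2.t()

    out_input_side = gated_mlp(x * routing_weight)   # apply_routing_on_input=True  -> silu(x * w) ...
    out_output_side = gated_mlp(x) * routing_weight  # apply_routing_on_input=False -> silu(x) ... * w

    # The two differ in general because silu is nonlinear and W3 also sees the scaled input.
    print(torch.allclose(out_input_side, out_output_side))  # expected: False
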