diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 2f26b99b821..4ca6729949c 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -28,6 +28,8 @@ transforms: stage: pattern_matcher match_dense_moe_pattern: stage: pattern_matcher + match_bmm_moe_pattern: + stage: pattern_matcher match_repeat_kv: stage: pattern_matcher run_shape_prop: true diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py index 63519cf5a58..f023062df15 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py @@ -31,8 +31,20 @@ def _template_moe( selected_experts: torch.Tensor, routing_weights: torch.Tensor, mlps: List[Callable[[torch.Tensor], torch.Tensor]], + apply_routing_on_input: bool = False, ) -> torch.Tensor: - """Mixtral-style generic MoE template, dispatching tokens to expert MLPs based on routing info.""" + """Mixtral-style generic MoE template, dispatching tokens to expert MLPs based on routing info. + + Args: + x: Input tensor + selected_experts: Expert indices for each token + routing_weights: Routing weights for each expert + mlps: List of MLP functions + apply_routing_on_input: If True, multiply routing weights with INPUT before MLP (BMM-based pattern). + This means: silu(input * routing_weight) + If False, multiply routing weights with OUTPUT after MLP (standard pattern). + This means: silu(input) * routing_weight + """ x_shape = x.shape hidden_dim = x_shape[-1] x = x.view(-1, hidden_dim) @@ -54,8 +66,18 @@ def _template_moe( if not tokens_for_this_expert.shape[0]: continue # input of shape [0, hidden_dim] breaks fp4 kernel - expert_out = mlps[expert_idx](tokens_for_this_expert) - current_hidden_states = expert_out * routing_weights[top_x, idx, None] + if apply_routing_on_input: + # INPUT-SIDE routing (BMM-based pattern): multiply routing weights with INPUT + # Result: silu(input * routing_weight) + scaled_input = tokens_for_this_expert * routing_weights[top_x, idx, None] + expert_out = mlps[expert_idx](scaled_input) + current_hidden_states = expert_out + else: + # OUTPUT-SIDE routing (standard pattern): multiply routing weights with OUTPUT + # Result: silu(input) * routing_weight + expert_out = mlps[expert_idx](tokens_for_this_expert) + current_hidden_states = expert_out * routing_weights[top_x, idx, None] + final_hidden_states.index_add_( 0, top_x, current_hidden_states.to(final_hidden_states.dtype) ) @@ -72,10 +94,16 @@ def torch_moe( w3_weight: List[torch.Tensor], mlp_style: str = "gated_mlp", act_fn: str = "silu", + apply_routing_on_input: bool = False, ) -> torch.Tensor: """ Unified Mixture-of-Experts (MoE) operator that uses a Mixtral-style dispatch (token routing + index_add_ accumulation) and a selectable per-expert MLP. + + Supports both: + - Standard MoE with per-expert weight lists (apply_routing_on_input=False) + - Llama4 MoE with stacked weight tensors (apply_routing_on_input=True) + Parameters: x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size, S is the sequence length, and H is the hidden size. @@ -83,35 +111,71 @@ def torch_moe( of the selected experts for each token. 
Only experts within range [0,num_experts) is processed routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized routing weights for the selected experts. + - Standard MoE: softmax normalized weights + - Llama4 MoE: sigmoid activated weights w1_weight: - List of per-expert weight tensors: - • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection. - • mlp_style=="mlp": W_up with shape (I, H) — up projection. + For per-expert lists: + • mlp_style=="gated_mlp": List of W1 with shape (I, H) — "gate" projection. + • mlp_style=="mlp": List of W_up with shape (I, H) — up projection. + For stacked tensors (Llama4): + • Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format w2_weight: - List of per-expert weight tensors: - • gated_mlp: W2 with shape (H, I) — down projection. - • mlp: W_down with shape (H, I) — down projection. + For per-expert lists: + • List of W2/W_down with shape (H, I) — down projection. + For stacked tensors (Llama4): + • Single-element list containing stacked w2 tensor with shape (E, H, I) in TRT-LLM format w3_weight: - List of per-expert weight tensors: - • gated_mlp: W3 with shape (I, H) — “up” (second) projection in gated MLP. - • mlp: pass an empty list []; ignored. + For per-expert lists with gated_mlp: + • List of W3 with shape (I, H) — "up" (second) projection in gated MLP. + For mlp style or stacked tensors: + • pass an empty list []; ignored. mlp_style: Selects the per-expert MLP computation: - • "gated_mlp" (default, Mixtral/DeepSeek-style): + • "gated_mlp" (default, Mixtral/DeepSeek/Llama4-style): y = W2( act(W1 x) * (W3 x) ) • "mlp" (NemotronH-style 2-layer MLP): y = W_down( act(W_up x) ) act_fn: Elementwise activation applied inside the expert MLP. Supported: "silu" (default), "relu2" (ReLU then square). + apply_routing_on_input: + If True (Llama4 pattern): multiply routing weights with INPUT before MLP + Result: act(input * routing_weight) - routing affects activation + If False (standard pattern): multiply routing weights with OUTPUT after MLP + Result: act(input) * routing_weight - routing scales output Returns: torch.Tensor: Output tensor with the same shape as the input x. 
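+
+    Example (illustrative sketch only; the expert count and the placeholder
+    tensor names such as ``w1_e0`` are hypothetical and not tied to any
+    particular model checkpoint):
+
+        # Standard MoE: per-expert 2-D weight lists, routing scales the output.
+        y = torch.ops.auto_deploy.torch_moe(
+            x, selected_experts, routing_weights,
+            [w1_e0, w1_e1], [w2_e0, w2_e1], [w3_e0, w3_e1],
+            mlp_style="gated_mlp", act_fn="silu", apply_routing_on_input=False,
+        )
+
+        # Llama4-style stacked MoE: single-element lists holding 3-D stacked
+        # tensors, routing scales the input before the expert MLP.
+        y = torch.ops.auto_deploy.torch_moe(
+            x, selected_experts, routing_weights,
+            [w3_w1_stacked], [w2_stacked], [],
+            mlp_style="gated_mlp", act_fn="silu", apply_routing_on_input=True,
+        )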
""" act_fn = _resolve_activation(act_fn) style = mlp_style.lower() - if style == "gated_mlp": + # Detect if using stacked tensor format (Llama4) vs per-expert lists (standard) + is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3 + if is_stacked: + # Llama4 stacked tensor format - only supports gated_mlp + if style != "gated_mlp": + raise ValueError("Stacked tensor format only supports 'gated_mlp' style") + + w3_w1_stacked = w1_weight[0] # (E, 2*I, H) + w2_stacked = w2_weight[0] # (E, H, I) + + def make_mlp(i: int): + gate_up = w3_w1_stacked[i] # (2*I, H) + intermediate_size = gate_up.shape[0] // 2 + W3 = gate_up[:intermediate_size, :] # (I, H) + W1 = gate_up[intermediate_size:, :] # (I, H) + W2 = w2_stacked[i] # (H, I) + weight_dtype = W1.dtype + return lambda inp: F.linear( + act_fn(F.linear(inp.to(weight_dtype), W1)) * F.linear(inp.to(weight_dtype), W3), + W2, + ) + + mlps = [make_mlp(i) for i in range(w3_w1_stacked.shape[0])] + + elif style == "gated_mlp": + # Standard per-expert list format with gated MLP def make_mlp(i: int): W1 = w1_weight[i] # (I, H) W2 = w2_weight[i] # (H, I) @@ -121,7 +185,7 @@ def make_mlp(i: int): mlps = [make_mlp(i) for i in range(len(w1_weight))] elif style == "mlp": - + # Standard per-expert list format with simple MLP def make_mlp(i: int): W_up = w1_weight[i] # (I, H) W_down = w2_weight[i] # (H, I) @@ -132,7 +196,7 @@ def make_mlp(i: int): else: raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") - return _template_moe(x, selected_experts, routing_weights, mlps) + return _template_moe(x, selected_experts, routing_weights, mlps, apply_routing_on_input) @torch_moe.register_fake @@ -145,6 +209,7 @@ def torch_moe_fake( w3_weight: List[torch.Tensor], mlp_style: str = "gated_mlp", act_fn: str = "silu", + apply_routing_on_input: bool = False, ) -> torch.Tensor: return torch.empty_like(x) @@ -159,6 +224,9 @@ def torch_fused_moe( ) -> torch.Tensor: """ A reference implementation of a fused MoE layer computation. + + All weights are in TRT-LLM format (conversion from Llama4 happens during graph transformation). + Parameters: x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size, S is the sequence length, and H is the hidden size. @@ -166,16 +234,19 @@ def torch_fused_moe( indices of the selected experts for each token. routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized routing weights for the selected experts. - w3_w1_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE) - containing the fused weights for w3 and w1 for each expert. - w2_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE) - containing the weights for w2 for each expert. + w3_w1_stacked_weight (torch.Tensor): Stacked gate/up weights in TRT-LLM format: + (NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE) + w2_stacked_weight (torch.Tensor): Stacked down weights in TRT-LLM format: + (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE) Returns: torch.Tensor: Output tensor with the same shape as the input x. 
""" x_shape = x.shape x = x.view(-1, x_shape[-1]) num_experts = w2_stacked_weight.shape[0] + + # Standardized on TRT-LLM format (conversion happens during graph transformation) + # TRT-LLM format: gate_up is (2*I, H), down is (H, I) intermediate_size = w3_w1_stacked_weight.shape[1] // 2 results = torch.zeros_like(x) @@ -190,12 +261,7 @@ def torch_fused_moe( w3 = stacked[:intermediate_size, :] w1 = stacked[intermediate_size:, :] w2 = w2_stacked_weight[expert_id] - - # Compute expert output: - # expert_out = (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t() - out_w1 = expert_inputs @ w1.t() - out_w3 = expert_inputs @ w3.t() - expert_out = (F.silu(out_w1) * out_w3) @ w2.t() + expert_out = (F.silu(expert_inputs @ w1.t()) * (expert_inputs @ w3.t())) @ w2.t() scaling = routing_weights[batch_idx, nth_expert].unsqueeze(-1) results[batch_idx] += scaling * expert_out diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index c7ea32fb0e2..af0865c183f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -19,10 +19,19 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int: + """Replace torch MoE ops with fused backend-specific implementations. + + Handles both: + - Standard MoE (per-expert weight lists): torch_moe with apply_routing_on_input=False + - Llama4 MoE (pre-stacked weight tensors): torch_moe with apply_routing_on_input=True + + For Llama4 stacked tensors, applies routing weights to input before the fused kernel. + """ fused_key_counter = 0 graph = gm.graph backend = backend.lower() + # Map backend to fused MoE op (handles both standard and Llama4 stacked tensor patterns) replacement_op = { "auto": torch.ops.auto_deploy.trtllm_moe_fused, "trtllm": torch.ops.auto_deploy.trtllm_moe_fused, @@ -33,68 +42,51 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t if not is_op(node, torch.ops.auto_deploy.torch_moe): continue - (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn") - assert backend != "triton" or mlp_style_val == "mlp", ( - "Triton backend only supports mlp style." 
+ # Detect if this is a stacked MoE (Llama4 pattern) or per-expert list (standard pattern) + (apply_routing_val, w1_weight_list) = extract_op_args( + node, "apply_routing_on_input", "w1_weight" ) - hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = ( - extract_op_args( - node, - "x", - "selected_experts", - "routing_weights", - "w1_weight", - "w2_weight", - "w3_weight", - ) - ) - if mlp_style_val == "gated_mlp": - fused_w_up_experts = torch.stack( - [ - torch.cat( - [gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], dim=-2 - ) - for w1_node, w3_node in zip(w1_list, w3_list) - ], - dim=0, + # Check if it's stacked format: single-element list with 3D tensor + is_stacked_moe = False + if apply_routing_val: + # In FX graphs, w1_weight_list might be a Node representing a list() call + list_content = None + if isinstance(w1_weight_list, Node) and w1_weight_list.target is list: + # Extract from list() call node + if w1_weight_list.args: + list_content = w1_weight_list.args[0] + elif isinstance(w1_weight_list, (list, tuple)): + # Direct Python list + list_content = w1_weight_list + + # Check if it's a single-element list with a 3D tensor + if list_content is not None and len(list_content) == 1: + w1_node = list_content[0] + if isinstance(w1_node, Node) and w1_node.op == "get_attr": + try: + w1_tensor = gm.get_parameter(w1_node.target) + is_stacked_moe = w1_tensor.ndim == 3 + except (AttributeError, KeyError): + pass + + if is_stacked_moe: + # Stacked MoE (Llama4 pattern): only supports gated MLP + (act_fn_val,) = extract_op_args(node, "act_fn") + _process_llama4_stacked_moe_node( + gm, graph, node, replacement_op, act_fn_val, fused_key_counter ) - new_key_w_up = f"fused_moe_w3_w1_stacked_{fused_key_counter}" - - elif mlp_style_val == "mlp": - fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0) - new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}" else: - raise ValueError(f"Unknown mlp_style: {mlp_style_val}") - - fused_w_down_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0) - - new_key_w_down = f"fused_moe_w2_stacked_{fused_key_counter}" - fused_key_counter += 1 - param_w_up = torch.nn.Parameter(fused_w_up_experts) - param_w_down = torch.nn.Parameter(fused_w_down_experts) - gm.register_parameter(new_key_w_up, param_w_up) - gm.register_parameter(new_key_w_down, param_w_down) - - with graph.inserting_before(node): - new_node = graph.call_function( - # TODO(Fridah-nv): torch.ops.auto_deploy.trtllm_moe_fused for quantized models - replacement_op, - args=( - hidden_states, - selected_experts, - routing_weights, - graph.get_attr(new_key_w_up), - graph.get_attr(new_key_w_down), - ), - kwargs={ - "mlp_style": mlp_style_val, - "act_fn": act_fn_val, - }, + # Standard MoE with per-expert weight lists + (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn") + assert backend != "triton" or mlp_style_val == "mlp", ( + "Triton backend only supports mlp style." 
+ ) + _process_regular_moe_node( + gm, graph, node, replacement_op, mlp_style_val, act_fn_val, fused_key_counter ) - node.replace_all_uses_with(new_node) - graph.erase_node(node) + fused_key_counter += 1 # Delete the unstacked weights immediately to save GPU memory # This will happen automatically after the graph is canonicalized, but for large models we'll run out of memory @@ -105,6 +97,218 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t return fused_key_counter +def _process_regular_moe_node( + gm: GraphModule, + graph: torch.fx.Graph, + node: Node, + replacement_op, + mlp_style_val: str, + act_fn_val: str, + fused_key_counter: int, +) -> None: + """Process a single torch_moe node with per-expert weight lists. + + Stacks weight parameters and creates a fused MoE node. + The kernel applies routing weights to the output. + """ + hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = extract_op_args( + node, + "x", + "selected_experts", + "routing_weights", + "w1_weight", + "w2_weight", + "w3_weight", + ) + + # Stack weights based on MLP style + if mlp_style_val == "gated_mlp": + # For gated MLP, concatenate w3 and w1 then stack across experts + fused_w_up_experts = torch.stack( + [ + torch.cat( + [gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], + dim=-2, + ) + for w1_node, w3_node in zip(w1_list, w3_list) + ], + dim=0, + ) + new_key_w_up = f"fused_moe_w3_w1_stacked_{fused_key_counter}" + elif mlp_style_val == "mlp": + # For regular MLP, just stack w1 + fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0) + new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}" + else: + raise ValueError(f"Unknown mlp_style: {mlp_style_val}") + + # Stack w2/down weights + fused_w_down_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0) + new_key_w_down = f"fused_moe_w2_stacked_{fused_key_counter}" + + # Register the stacked weights as parameters + param_w_up = torch.nn.Parameter(fused_w_up_experts) + gm.register_parameter(new_key_w_up, param_w_up) + + param_w_down = torch.nn.Parameter(fused_w_down_experts) + gm.register_parameter(new_key_w_down, param_w_down) + + # Create fused MoE node - kernel applies routing to output + with graph.inserting_before(node): + w_up_arg = graph.get_attr(new_key_w_up) + w_down_arg = graph.get_attr(new_key_w_down) + + new_node = graph.call_function( + replacement_op, + args=(hidden_states, selected_experts, routing_weights, w_up_arg, w_down_arg), + kwargs={ + "mlp_style": mlp_style_val, + "act_fn": act_fn_val, + }, + ) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + +def _process_llama4_stacked_moe_node( + gm: GraphModule, + graph: torch.fx.Graph, + node: Node, + replacement_op, + act_fn_val: str, + fused_key_counter: int, +) -> None: + """Process a single Llama4 MoE node with pre-stacked weight tensors. + + Only supports gated MLP (SwiGLU-style) architecture. + Converts Llama4 format weights to TRT-LLM format to standardize all downstream ops. + Applies routing weights to INPUT before the fused kernel to prevent double multiplication. + This is the Llama4 pattern where weights are already stacked across experts. + Result: silu(input * routing_weight) - routing affects activation. 
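+
+    Equivalence sketch (assuming the fused kernel computes, per selected expert,
+    ``r * W2_proj(act(x @ W1.T) * (x @ W3.T))`` with routing weight ``r``, as in
+    the torch_fused_moe reference above): feeding the pre-scaled input
+    ``x_scaled = r * x`` together with a routing weight of 1 yields
+    ``W2_proj(act(x_scaled @ W1.T) * (x_scaled @ W3.T))``, which is the Llama4
+    input-side-routing result and avoids applying ``r`` twice.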
+ """ + # torch_moe with stacked format: weights are in single-element lists + hidden_states, selected_experts, routing_weights, w1_list, w2_list = extract_op_args( + node, + "x", + "selected_experts", + "routing_weights", + "w1_weight", + "w2_weight", + ) + + # Extract the single stacked tensor from each list + # Handle both FX graph Nodes (list() calls) and direct Python lists + def extract_from_list_arg(list_arg): + if isinstance(list_arg, Node) and list_arg.target is list: + # Extract from list() call node + return list_arg.args[0][0] if list_arg.args else None + elif isinstance(list_arg, (list, tuple)): + # Direct Python list + return list_arg[0] + else: + raise ValueError(f"Unexpected list format: {type(list_arg)}") + + w3_w1_stacked = extract_from_list_arg(w1_list) + w2_stacked = extract_from_list_arg(w2_list) + + # Convert Llama4 format to TRT-LLM format if needed + # This standardizes all downstream ops to only handle TRT-LLM format + if w3_w1_stacked.op == "get_attr" and w2_stacked.op == "get_attr": + gate_up_weight = gm.get_parameter(w3_w1_stacked.target) + down_weight = gm.get_parameter(w2_stacked.target) + + # Detect format: + # - Llama4: gate_up is (E, H, 2*I) and down is (E, I, H) + # - TRT-LLM: gate_up is (E, 2*I, H) and down is (E, H, I) + # If both have H in middle dimension, they're Llama4 format + is_llama4 = gate_up_weight.shape[1] == down_weight.shape[2] + + if is_llama4: + # Convert Llama4 (E, H, 2*I) -> TRT-LLM (E, 2*I, H) + gate_up_trtllm = gate_up_weight.transpose(1, 2).contiguous() + # Convert Llama4 (E, I, H) -> TRT-LLM (E, H, I) + down_trtllm = down_weight.transpose(1, 2).contiguous() + + # Register converted weights + new_key_w_up = f"llama4_to_trtllm_w3_w1_{fused_key_counter}" + new_key_w_down = f"llama4_to_trtllm_w2_{fused_key_counter}" + + gm.register_parameter(new_key_w_up, torch.nn.Parameter(gate_up_trtllm)) + gm.register_parameter(new_key_w_down, torch.nn.Parameter(down_trtllm)) + + # Store keys to create get_attr nodes later in insertion context + needs_get_attr = True + w_up_key = new_key_w_up + w_down_key = new_key_w_down + else: + # Already TRT-LLM format, use directly + needs_get_attr = False + w_up_arg = w3_w1_stacked + w_down_arg = w2_stacked + else: + # Not get_attr nodes (might be intermediate ops), use directly + needs_get_attr = False + w_up_arg = w3_w1_stacked + w_down_arg = w2_stacked + + # Llama4 INPUT-SIDE routing: apply routing to INPUT before kernel + # Cast BOTH input and routing_weights to weight dtype if needed + # Critical: BFloat16 * Float32 → Float32 (type promotion) so we cast both to same dtype + with graph.inserting_before(node): + # Create get_attr nodes INSIDE insertion context for proper topological ordering + if needs_get_attr: + w_up_arg = graph.get_attr(w_up_key) + w_down_arg = graph.get_attr(w_down_key) + + # Get weight dtype to ensure dtype consistency for Llama4 stacked tensors + # The fused kernel requires input and weights to have matching dtypes + weight_dtype = None + if w_up_arg.op == "get_attr": + try: + weight_tensor = gm.get_parameter(w_up_arg.target) + weight_dtype = weight_tensor.dtype + except (AttributeError, KeyError): + pass + input_to_scale = hidden_states + routing_to_scale = routing_weights + + if weight_dtype is not None and weight_dtype != torch.float32: + input_to_scale = graph.call_function( + torch.ops.aten._to_copy.default, + args=(hidden_states,), + kwargs={"dtype": weight_dtype}, + ) + routing_to_scale = graph.call_function( + torch.ops.aten._to_copy.default, + args=(routing_weights,), + 
kwargs={"dtype": weight_dtype}, + ) + + # Scale input: hidden_states = hidden_states * routing_weights (both same dtype now) + scaled_input = graph.call_function( + torch.ops.aten.mul.Tensor, + args=(input_to_scale, routing_to_scale), + ) + + # Pass ones to kernel to prevent it from multiplying routing again (already applied) + ones_node = graph.call_function( + torch.ops.aten.ones_like.default, + args=(routing_weights,), + ) + + new_node = graph.call_function( + replacement_op, + args=(scaled_input, selected_experts, ones_node, w_up_arg, w_down_arg), + kwargs={ + "act_fn": act_fn_val, + }, + ) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]: """ Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following @@ -585,6 +789,484 @@ def scale_keys(self) -> List[str]: return ["input_scale", "weight_scale", "alpha"] +class MatchBmmMoePatternConfig(TransformConfig): + """Configuration for MatchBmmMoePattern transform.""" + + pass + + +@TransformRegistry.register("match_bmm_moe_pattern") +class MatchBmmMoePattern(BaseTransform): + """Match and fuse Llama4 MoE pattern with pre-stacked weight tensors. + + This pattern uses batch matrix multiply (BMM) operations for parallel expert computation + with weights already stacked across the expert dimension. + + Only matches patterns where topk uses k=1 (single expert per token). + """ + + config: MatchBmmMoePatternConfig + + @classmethod + def get_config_class(cls): + return MatchBmmMoePatternConfig + + @staticmethod + def _find_gate_up_bmm(final_bmm: Node) -> Optional[Tuple[Node, Node]]: + """Find the MoE gate_up BMM and chunk node from the final BMM. + + BMM MoE pattern traces back: final_bmm <- mul(up, silu(gate)) <- chunk <- first_bmm (gate_up) + + Returns: + Tuple of (first_bmm, gate_up_weight) or None if not found + """ + # Input to final bmm should be mul(up, silu(gate)) + mul_node = final_bmm.args[0] + if not isinstance(mul_node, Node) or not is_op(mul_node, torch.ops.aten.mul): + return None + + if not mul_node.args or len(mul_node.args) < 2: + return None + + # Find silu node (one of the mul inputs) + arg_a, arg_b = mul_node.args[:2] + silu_node = ( + arg_a + if is_op(arg_a, torch.ops.aten.silu) + else arg_b + if is_op(arg_b, torch.ops.aten.silu) + else None + ) + if silu_node is None: + return None + + up_node = arg_b if arg_a is silu_node else arg_a + if not isinstance(up_node, Node): + return None + + # silu input should be gate from chunk + if not silu_node.args or not isinstance(silu_node.args[0], Node): + return None + gate_node = silu_node.args[0] + + # Both gate and up come from chunk (getitem nodes) + if gate_node.op != "call_function" or up_node.op != "call_function": + return None + + # Find the chunk node + chunk_node = None + if hasattr(gate_node, "args") and gate_node.args: + potential_chunk = gate_node.args[0] + if isinstance(potential_chunk, Node) and is_op(potential_chunk, torch.ops.aten.chunk): + chunk_node = potential_chunk + + if chunk_node is None or not chunk_node.args or chunk_node.args[1] != 2: + return None + + # chunk input is the first batched BMM for Llama4 (gate_up_proj) + first_bmm = chunk_node.args[0] + if not isinstance(first_bmm, Node) or not is_op(first_bmm, torch.ops.aten.bmm): + return None + + if not first_bmm.args or len(first_bmm.args) < 2: + return None + + # Llama4: gate_up_weight is pre-stacked [num_experts, hidden, 2*intermediate] + gate_up_weight = first_bmm.args[1] + if not 
isinstance(gate_up_weight, Node) or gate_up_weight.op != "get_attr": + return None + + return (first_bmm, gate_up_weight) + + @staticmethod + def _find_input_and_routing(batched_input: Node) -> Optional[Tuple[Node, Node]]: + """Find the input hidden states and routing weights from batched input. + + BMM MoE pattern traces back: batched_input <- mul(repeat(input), routing) <- repeat <- input + + Only matches patterns where topk uses k=1 (single expert per token). + + Returns: + Tuple of (input_hidden_states, topk_node) or None if not found + """ + # batched_input comes from mul(repeated_input, routing_weights) + if not batched_input.args or not isinstance(batched_input.args[0], Node): + return None + + mul_routing = batched_input.args[0] + if ( + not is_op(mul_routing, torch.ops.aten.mul) + or not mul_routing.args + or len(mul_routing.args) < 2 + ): + return None + + # One arg is repeat (input), other is routing weights + repeat_node = None + routing_weight_node = None + for arg in mul_routing.args[:2]: + if isinstance(arg, Node): + if is_op(arg, torch.ops.aten.repeat): + repeat_node = arg + else: + routing_weight_node = arg + + if not repeat_node or not routing_weight_node: + return None + + # Get original input from repeat + if not repeat_node.args or not isinstance(repeat_node.args[0], Node): + return None + input_hidden_states = repeat_node.args[0] + + # Trace back from routing_weight to find topk + try: + topk_node, _ = bfs( + routing_weight_node, + lambda n: is_op(n, torch.ops.aten.topk), + attr_next="all_input_nodes", + ) + router_logits = topk_node.args[0] if topk_node.args else None + if not router_logits: + return None + + # Verify topk is using k=1 (only match single-expert-per-token routing) + if len(topk_node.args) < 2: + return None + k_value = topk_node.args[1] + if k_value != 1: + return None + + except RuntimeError: + return None + + return (input_hidden_states, topk_node) + + @staticmethod + def _find_output_and_routing_flavor(final_bmm: Node) -> Optional[Tuple[Node, bool]]: + """Find the output node and detect routing application method. + + Llama4 stacked MoE pattern traces forward: final_bmm -> view -> reshape -> sum [-> mul?] 
+ + Returns: + Tuple of (output_node, apply_routing_on_input) or None if not found + apply_routing_on_input is True if routing is applied to input, False if applied to output + """ + # Llama4 pattern: bmm -> view([-1, hidden]) -> reshape([num_experts, -1, hidden]) -> sum(dim=0) + output_view = None + for user in final_bmm.users: + if is_op(user, torch.ops.aten.view): + output_view = user + break + + if not output_view: + return None + + # Find reshape after view + reshape_node = None + for user in output_view.users: + if is_op(user, torch.ops.aten.reshape): + reshape_node = user + break + + if not reshape_node: + return None + + # Find sum after reshape + sum_node = None + for user in reshape_node.users: + if is_op(user, torch.ops.aten.sum): + sum_node = user + break + + if not sum_node: + return None + + # Detect routing application method: check if routing is applied after sum (OUTPUT) + apply_routing_on_input = True # Default for Llama4 (routing already applied before BMM) + output_node = sum_node + + for user in sum_node.users: + if is_op(user, torch.ops.aten.mul): + # Found multiplication after sum - routing is applied to OUTPUT + apply_routing_on_input = False + output_node = user + break + + return (output_node, apply_routing_on_input) + + @staticmethod + def _match_bmm_moe_pattern( + start_boundary: Node, + end_boundary: Node, + ): + """ + Match the BMM MoE pattern (ONE pattern per layer, not per expert). + + This BMM MoE pattern uses batch matrix multiply (BMM) operations for parallel expert + computation with pre-stacked weight tensors. + + Only matches patterns where topk uses k=1 (single expert per token). + + Supports TWO routing flavors: + + 1. INPUT-SIDE routing (most common): + - Pattern: mul(input, routing) -> bmm -> silu -> bmm -> sum + - Result: silu(input * routing_weight) - routing affects activation + - Routing multiplication happens BEFORE BMM operations + + 2. OUTPUT-SIDE routing (alternative): + - Pattern: bmm -> silu -> bmm -> sum -> mul(output, routing) + - Result: silu(input) * routing_weight) - routing scales output + - Routing multiplication happens AFTER sum + + The function auto-detects which flavor is present and returns metadata. 
+ + The BMM MoE pattern corresponds to: + repeated_input = repeat(input, [num_experts, 1]) + routing_weights = reshape(transpose(sigmoid(scatter(topk(router_logits))))) + routed_input = mul(repeated_input, routing_weights) # <-- INPUT-SIDE MULTIPLICATION + batched_input = view(routed_input, [num_experts, -1, hidden]) + gate_up = bmm(batched_input, gate_up_proj) # gate_up_proj: pre-stacked [num_experts, hidden, 2*inter] + gate, up = chunk(gate_up, 2) + output = bmm(up * silu(gate), down_proj) # down_proj: pre-stacked [num_experts, inter, hidden] + final_output = view(output, [-1, hidden]) + + Returns: + List of dicts, one per MoE layer found: + { + "input": Node, # Unmultiplied input tensor + "router_logits": Node, # Router logits tensor + "gate_up_weight": Node, # Stacked gate+up weights + "down_weight": Node, # Stacked down weights + "output": Node, # Output node to replace (sum or mul) + "topk": Node, # TopK node for routing + "apply_routing_on_input": bool, # True if routing on input, False if on output + } + """ + moe_layers = [] + + nodes = list(start_boundary.graph.nodes) + region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)] + + for node in region_nodes: + # Look for the final bmm (down_proj) - this is the BATCHED bmm + if not is_op(node, torch.ops.aten.bmm): + continue + + final_bmm = node + if not final_bmm.args or len(final_bmm.args) < 2: + continue + + # Step 1: Get down_proj weight + down_weight = final_bmm.args[1] + if not isinstance(down_weight, Node) or down_weight.op != "get_attr": + continue + + # Step 2: Find the first BMM (gate_up) by tracing back through chunk and mul + result = MatchBmmMoePattern._find_gate_up_bmm(final_bmm) + if result is None: + continue + first_bmm, gate_up_weight = result + + # Step 3: Get batched input and trace back to original input and routing + batched_input = first_bmm.args[0] + if not isinstance(batched_input, Node) or not is_op(batched_input, torch.ops.aten.view): + continue + + result = MatchBmmMoePattern._find_input_and_routing(batched_input) + if result is None: + continue + input_hidden_states, topk_node = result + + # Get router_logits for metadata + router_logits = topk_node.args[0] if topk_node.args else None + if not router_logits: + continue + + # Step 4: Find output node and detect routing application method + result = MatchBmmMoePattern._find_output_and_routing_flavor(final_bmm) + if result is None: + continue + output_node, apply_routing_on_input = result + + # Step 5: Add the matched layer + moe_layers.append( + { + "input": input_hidden_states, + "router_logits": router_logits, + "gate_up_weight": gate_up_weight, + "down_weight": down_weight, + "output": output_node, + "topk": topk_node, + "apply_routing_on_input": apply_routing_on_input, + } + ) + + return moe_layers + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + graph = gm.graph + + # Preprocessing: Identify boundary nodes (e.g. residual connections) in the graph. 
+ boundary_nodes = identify_regions_between_residuals(gm) + + num_moe_patterns = 0 + + for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]): + # Step 1: Match BMM MoE patterns (one pattern per MoE layer) + moe_layers = self._match_bmm_moe_pattern(start_boundary, end_boundary) + + if not moe_layers: + continue + + # Process each MoE layer + for layer_info in moe_layers: + input_hidden_states = layer_info["input"] + gate_up_weight = layer_info["gate_up_weight"] + down_weight = layer_info["down_weight"] + output_node = layer_info["output"] + topk_node = layer_info["topk"] + # Get routing application method from pattern matcher + # Default to True (apply on input) which is the common Llama4 pattern + input_routing = layer_info.get("apply_routing_on_input", True) + + # Step 2: Extract routing information + # selected_experts: topk indices [tokens, top_k] + # routing_weights: topk values [tokens, top_k] + selected_experts = None + routing_weights_node = None + + for user in topk_node.users: + if ( + user.op == "call_function" + and hasattr(user.target, "__name__") + and user.target.__name__ == "getitem" + ): + if len(user.args) >= 2: + if user.args[1] == 1: # indices + selected_experts = user + elif user.args[1] == 0: # values + # For the fused MoE op, we use topk values directly + # (shape: tokens x top_k), NOT the scattered version + routing_weights_node = user + + if not selected_experts: + continue + + if not routing_weights_node: + continue + + # Find scatter and sigmoid nodes - we need the sigmoid output for routing weights! + # The original pattern: sigmoid(scatter(topk_values)) normalizes routing to [0,1] + # The fused op must use the same normalized values, not raw topk logits. + scatter_node = None + sigmoid_node = None + + # Find scatter operation + for user in routing_weights_node.users: + if is_op(user, {torch.ops.aten.scatter, torch.ops.aten.scatter_}): + scatter_node = user + # Check if sigmoid exists after scatter (may have to.dtype in between) + current = user + for _ in range(2): # Allow up to 2 hops + for next_user in current.users: + if is_op(next_user, torch.ops.aten.sigmoid): + sigmoid_node = next_user + break + elif is_op(next_user, torch.ops.aten.to): + current = next_user + break + if sigmoid_node: + break + break + + if not scatter_node: + continue + + if not sigmoid_node: + continue + + # Extract normalized routing weights from the sigmoid output + # The sigmoid output has shape [tokens, num_experts] (scattered) + # We need to gather it back to [tokens, top_k] using selected_experts indices + # This gives us sigmoid(scatter(topk_values)) in the compact [tokens, top_k] format + graph = gm.graph + with graph.inserting_after(sigmoid_node): + # Create gather operation: routing_weights_normalized = sigmoid_output.gather(1, selected_experts) + routing_weights_normalized = graph.call_function( + torch.ops.aten.gather, + args=(sigmoid_node, 1, selected_experts), + ) + + # Use the normalized routing weights instead of raw topk values + routing_weights_node = routing_weights_normalized + + # Step 4: Apply MoE fusion + # If input_routing is True: kernel applies routing to input + # If input_routing is False: kernel applies routing to output + apply_routing_on_input = input_routing + + # Wrap stacked tensors in single-element lists for torch_moe unified interface + with graph.inserting_before(output_node): + # Create list nodes for stacked weights + w1_list_node = graph.call_function( + list, + args=([gate_up_weight],), + ) + w2_list_node = 
graph.call_function( + list, + args=([down_weight],), + ) + w3_list_node = graph.call_function( + list, + args=([],), # Empty list for stacked gated MLP + ) + + fused_moe_node = graph.call_function( + torch.ops.auto_deploy.torch_moe, + args=( + input_hidden_states, + selected_experts, + routing_weights_node, + w1_list_node, + w2_list_node, + w3_list_node, + ), + kwargs={ + "mlp_style": "gated_mlp", + "apply_routing_on_input": apply_routing_on_input, + }, + ) + + # Replace the output node with fused MoE + output_node.replace_all_uses_with(fused_moe_node) + graph.erase_node(output_node) + + # Clean up dead nodes + gm.graph.eliminate_dead_code() + + # Clean up dead inplace nodes in the region + while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary): + gm.graph.eliminate_dead_code() + + # Delete unused submodules/parameters + gm.delete_all_unused_submodules() + + num_moe_patterns += 1 + + info = TransformInfo( + skipped=False, num_matches=num_moe_patterns, is_clean=False, has_valid_shapes=False + ) + return gm, info + + def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int: """ Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters. diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 08441632c39..64698911ddb 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -6,10 +6,18 @@ from abc import ABC, abstractmethod from enum import Enum, IntEnum from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple - -if TYPE_CHECKING: - from ..transform.library.sharding import ShardingTransformConfig +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, +) import torch import torch.nn as nn @@ -32,6 +40,9 @@ modelopt_fp4_scale_to_cutlass_fp4_scale, ) +if TYPE_CHECKING: + from ..transform.library.sharding import ShardingTransformConfig + def validate_allreduce_strategy(v): """Convert string names like 'AUTO' to AllReduceStrategy enum. @@ -997,17 +1008,66 @@ def _insert_sharded_moe( allreduce_strategy: AllReduceStrategy, scale_names: Sequence[str] = (), ): - """Update the torch_moe node with sharded weight lists, + """Update the torch_moe node with sharded weight lists or stacked tensors, sharded `selected_experts` and `final_scales(router_logics)`. Add an all_reduce node after the moe node. + Handles both: + - Standard format: per-expert weight lists + - Stacked format: single-element lists containing stacked 3D tensors (Llama4 pattern) + NOTE: allreduce_strategy is MANDATORY. 
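+
+    Argument sketch (hypothetical node arguments): the standard format passes
+    ``w1_weight=[w1_e0, w1_e1, ...]`` with one 2-D tensor per expert, while the
+    stacked format passes ``w1_weight=[w3_w1_stacked]`` holding a single 3-D
+    tensor stacked along the expert dimension; the single-element-list / 3-D
+    check below is what dispatches the node to ``_insert_sharded_moe_stacked``.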
""" if allreduce_strategy is None: raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}") scale_names = list(scale_names) - num_experts = len(node.args[3]) + # Detect format: check if w1_weight is a single-element list with a 3D tensor (stacked format) + w1_weight_arg = node.args[3] + is_stacked = False + + # In FX graphs, the list might be a Node representing a list() call + if isinstance(w1_weight_arg, Node): + # Check if this is a list() call node + if w1_weight_arg.target is list and len(w1_weight_arg.args) > 0: + # Get the actual list content from the args + list_content = w1_weight_arg.args[0] + if isinstance(list_content, (list, tuple)) and len(list_content) == 1: + first_elem = list_content[0] + if isinstance(first_elem, Node) and first_elem.op == "get_attr": + try: + tensor = gm.get_parameter(first_elem.target) + is_stacked = tensor.ndim == 3 + except (AttributeError, KeyError): + pass + elif isinstance(first_elem, torch.Tensor): + is_stacked = first_elem.ndim == 3 + # Handle case where it's a direct Python list (not in FX graph context) + elif isinstance(w1_weight_arg, (list, tuple)) and len(w1_weight_arg) == 1: + first_elem = w1_weight_arg[0] + if isinstance(first_elem, Node) and first_elem.op == "get_attr": + try: + tensor = gm.get_parameter(first_elem.target) + is_stacked = tensor.ndim == 3 + except (AttributeError, KeyError): + pass + elif isinstance(first_elem, torch.Tensor): + is_stacked = first_elem.ndim == 3 + + if is_stacked: + # Use stacked tensor sharding logic (similar to _insert_sharded_moe_bmm) + _insert_sharded_moe_stacked(gm, node, rank, world_size, allreduce_strategy, scale_names) + return + + # Standard per-expert list sharding + # For FX graphs, get the list from the Node; for direct calls, use the list directly + if isinstance(w1_weight_arg, Node) and w1_weight_arg.target is list: + # Extract the list content from the list() call node + num_experts = len(w1_weight_arg.args[0]) if w1_weight_arg.args else 0 + elif isinstance(w1_weight_arg, (list, tuple)): + num_experts = len(w1_weight_arg) + else: + raise ValueError(f"Unexpected w1_weight format in node {node.name}: {type(w1_weight_arg)}") args = list(node.args) # -- Handle selected_experts and final_scales sharding -- @@ -1082,14 +1142,255 @@ def get_partition(lst, world_size, rank): dist_node.replace_input_with(dist_node, node) -def _slice_expert_dim(gm: GraphModule, tensor_node: Node, lo: int, hi: int) -> Node: - """Return tensor_node[lo:hi, ...] via aten.slice along dim 0.""" - with gm.graph.inserting_after(tensor_node): - # aten.slice.Tensor(self, dim, start, end, step) - return gm.graph.call_function( - torch.ops.aten.slice.Tensor, - args=(tensor_node, 0, lo, hi, 1), +def _slice_expert_dim( + gm: GraphModule, + tensor_node_or_tensor: Union[Node, torch.Tensor], + lo: int, + hi: int, +) -> Union[Node, torch.Tensor]: + """Slice expert weights along dim 0 and register load hook (simple version). + + This is the original simple slicing function used by MXFP4 EP sharding. + For parameters, it modifies them in-place and returns the same node. 
+ + Args: + gm: The graph module + tensor_node_or_tensor: Either a Node (from FX graph) or a Tensor + lo: Start index for slicing + hi: End index for slicing + + Returns: + Node or Tensor depending on input type + """ + # Handle raw tensor case + if isinstance(tensor_node_or_tensor, torch.Tensor): + return tensor_node_or_tensor[lo:hi] + + # Handle Node case + tensor_node = tensor_node_or_tensor + + if tensor_node.op != "get_attr": + # If not a parameter node, just add a runtime slice node after it + with gm.graph.inserting_after(tensor_node): + return gm.graph.call_function( + torch.ops.aten.slice.Tensor, + args=(tensor_node, 0, lo, hi, 1), + ) + + # Get the parameter + param_key = str(tensor_node.target) + modname, _, param_name = param_key.rpartition(".") + submod = gm.get_submodule(modname) if modname else gm + full_param = getattr(submod, param_name) + + # Slice the parameter + sliced_param = full_param[lo:hi].detach().clone() + sliced_shape = sliced_param.shape + + # Define slice function for load hook + def slice_expert_tensor(t: torch.Tensor) -> torch.Tensor: + return t[lo:hi] + + # Register load hook to slice during checkpoint loading + gm._register_load_state_dict_pre_hook( + partial( + _load_hook, + f_split=slice_expert_tensor, + param_key=param_key, + param_shape=sliced_shape, ) + ) + + # Replace the parameter with the sliced version + new_param = nn.Parameter(sliced_param, requires_grad=False) + setattr(submod, param_name, new_param) + + # Return the same node (it now points to the sliced parameter) + return tensor_node + + +def _transform_bmm_moe_weight_param( + gm: GraphModule, + param_node: Node, + lo: int, + hi: int, + swap_gate_up: bool = False, +) -> None: + """Transform a parameter for BMM MoE: slice experts, optionally swap gate/up, transpose. + + This modifies the parameter in-place and registers a load hook. + Does NOT create graph nodes - those should be created separately by the caller. 
+ + Args: + gm: Graph module + param_node: The get_attr node for the parameter + lo: Start index for expert slicing + hi: End index for expert slicing + swap_gate_up: If True, swap W1 and W3 (Llama4 -> TRT-LLM format) + """ + if param_node.op != "get_attr": + return # Only works on parameters + + param_key = str(param_node.target) + modname, _, param_name = param_key.rpartition(".") + submod = gm.get_submodule(modname) if modname else gm + full_param = getattr(submod, param_name) + + # Slice the parameter along expert dimension (dim 0) + sliced_param = full_param[lo:hi].detach().clone() + + # Swap W1 and W3 if needed (for gate_up weights) + # Llama4: (E, H, 2*I) with [W1, W3], TRT-LLM wants [W3, W1] + if swap_gate_up and sliced_param.ndim == 3: + intermediate_size = sliced_param.shape[2] // 2 + w1 = sliced_param[:, :, :intermediate_size] + w3 = sliced_param[:, :, intermediate_size:] + sliced_param = torch.cat([w3, w1], dim=2) + + # Transpose: Llama4 (E, H, X) -> TRT-LLM (E, X, H) + transposed_param = sliced_param.transpose(1, 2) + transposed_shape = transposed_param.shape + + # Define transformation function for load hook + def transform_tensor(t: torch.Tensor) -> torch.Tensor: + t_sliced = t[lo:hi] + if swap_gate_up and t_sliced.ndim == 3: + intermediate_size = t_sliced.shape[2] // 2 + w1 = t_sliced[:, :, :intermediate_size] + w3 = t_sliced[:, :, intermediate_size:] + t_sliced = torch.cat([w3, w1], dim=2) + return t_sliced.transpose(1, 2).contiguous() + + # Register load hook + gm._register_load_state_dict_pre_hook( + partial( + _load_hook, + f_split=transform_tensor, + param_key=param_key, + param_shape=transposed_shape, + ) + ) + + # Replace the parameter with the transformed version + new_param = nn.Parameter(transposed_param, requires_grad=False) + setattr(submod, param_name, new_param) + + +def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int: + """Helper to get the first dimension size of an argument (Node or Tensor).""" + if isinstance(arg, torch.Tensor): + return arg.shape[0] + if isinstance(arg, Node): + if arg.op == "get_attr": + # Traverse attributes to find the tensor + obj = gm + for atom in arg.target.split("."): + obj = getattr(obj, atom) + return obj.shape[0] + if "val" in arg.meta: + return arg.meta["val"].shape[0] + raise ValueError(f"Cannot determine shape[0] for {arg}") + + +def _insert_sharded_moe_stacked( + gm: GraphModule, + node: Node, + rank: int, + world_size: int, + allreduce_strategy: AllReduceStrategy, + scale_names: Sequence[str] = (), +): + """Update the torch_moe node with sliced stacked weight tensors, + sharded `selected_experts` and `final_scales(router_logics)`. + Add an all_reduce node after the moe node. + + For torch_moe with stacked tensor format (single-element lists containing 3D tensors). + + NOTE: allreduce_strategy is MANDATORY and must be explicitly provided. 
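+
+    Partition sketch (illustrative numbers): with num_experts=10 and
+    world_size=4, experts_per_rank = 10 // 4 = 2, so ranks 0-2 own experts
+    {0,1}, {2,3}, {4,5} and the last rank owns the remainder {6,7,8,9}; tokens
+    routed to experts outside a rank's slice have their final_scales entry
+    zeroed by the rank mask, and the per-rank partial outputs are combined by
+    the trailing all_reduce.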
+ """ + if allreduce_strategy is None: + raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}") + + # Extract the stacked tensors from single-element lists + # args[3] = w1_weight (Node representing list with one 3D tensor, or direct list) + # args[4] = w2_weight (Node representing list with one 3D tensor, or direct list) + + # Helper to extract tensor node from list (handles both Node and direct list) + def extract_tensor_from_list_arg(list_arg): + if isinstance(list_arg, Node) and list_arg.target is list: + # It's a list() call node - extract from its args + return list_arg.args[0][0] # args[0] is the list content, [0] is first element + elif isinstance(list_arg, (list, tuple)): + # Direct list + return list_arg[0] + else: + raise ValueError(f"Unexpected list format: {type(list_arg)}") + + w3_w1_tensor_node = extract_tensor_from_list_arg(node.args[3]) + w2_tensor_node = extract_tensor_from_list_arg(node.args[4]) + num_experts = _get_dim0_from_arg(gm, w3_w1_tensor_node) + + args = list(node.args) + + # -- Handle selected_experts and final_scales sharding -- + selected_experts = args[1] + final_scales = args[2] + + experts_per_rank = num_experts // world_size + + with gm.graph.inserting_before(node): + lower = experts_per_rank * rank + # selected_experts_local = selected_experts - low + selected_experts_local = gm.graph.create_node( + "call_function", operator.sub, args=(selected_experts, lower), kwargs={} + ) + + # For num_experts % world_size != 0 case, + # assign the last (num_experts % world_size) experts to the last rank + div_node = gm.graph.create_node( + "call_function", operator.floordiv, args=(selected_experts, experts_per_rank), kwargs={} + ) + + comp_op = torch.ge if rank == world_size - 1 else torch.eq + rank_mask = gm.graph.create_node("call_function", comp_op, args=(div_node, rank), kwargs={}) + + # final_scales_local = final_scales * rank_mask + final_scales_local = gm.graph.create_node( + "call_function", operator.mul, args=(final_scales, rank_mask), kwargs={} + ) + + # -- Transform expert weight parameters -- + local_lo, local_hi = _split_range_last_remainder(num_experts, world_size, rank) + + # Transform w3_w1_stacked: slice experts, swap [W1,W3]->[W3,W1], transpose (E,H,2I)->(E,2I,H) + if isinstance(w3_w1_tensor_node, Node): + _transform_bmm_moe_weight_param( + gm, w3_w1_tensor_node, local_lo, local_hi, swap_gate_up=True + ) + + # Transform w2_stacked: slice experts, transpose (E,I,H)->(E,H,I) + if isinstance(w2_tensor_node, Node): + _transform_bmm_moe_weight_param(gm, w2_tensor_node, local_lo, local_hi, swap_gate_up=False) + + # -- Update args (keep same lists/nodes, just with transformed parameters) -- + args[1] = selected_experts_local + args[2] = final_scales_local + # args[3] and args[4] stay the same - we modified the parameters in-place + + ad_logger.debug( + f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}." 
+ ) + + node.args = tuple(args) + + # -- add an all_reduce node -- + with gm.graph.inserting_after(node): + dist_node = gm.graph.call_function( + torch.ops.auto_deploy.torch_dist_all_reduce.default, + args=(node, allreduce_strategy.name), + ) + node.replace_all_uses_with(dist_node) + dist_node.replace_input_with(dist_node, node) def _split_range_last_remainder(n: int, world_size: int, rank: int): diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 6b7c79de051..8112889a870 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -157,3 +157,51 @@ def test_sharding_pattern_detection(world_size: int, num_experts: int): rank=0, world_size=world_size, ) + + +def test_llama4_stacked_moe_pattern_detection(): + """Minimal test: verify torch_moe with stacked format is detected for EP sharding.""" + # Create a simple graph with torch_moe node using stacked tensor format + gm = torch.fx.GraphModule({}, torch.fx.Graph()) + graph = gm.graph + + with graph.inserting_after(): + x = graph.placeholder("x") + selected_experts = graph.placeholder("selected_experts") + routing_weights = graph.placeholder("routing_weights") + w3_w1 = graph.placeholder("w3_w1_stacked") + w2 = graph.placeholder("w2_stacked") + + # Create single-element lists for stacked tensor format + w1_list = graph.call_function(list, args=([w3_w1],)) + w2_list = graph.call_function(list, args=([w2],)) + w3_list = graph.call_function(list, args=([],)) + + moe_node = graph.call_function( + torch.ops.auto_deploy.torch_moe, + args=(x, selected_experts, routing_weights, w1_list, w2_list, w3_list), + kwargs={"mlp_style": "gated_mlp", "apply_routing_on_input": True}, + ) + graph.output(moe_node) + + graph.lint() + gm.recompile() + + # Run pattern detection for EP + optimizer = InferenceOptimizer( + None, + { + "detect_sharding": { + "stage": "sharding", + "use_sharding_from_factory": False, + }, + }, + ) + optimizer.shared_config.local_rank = 0 + optimizer.shared_config.world_size = 2 + _ = optimizer(None, gm) + + # Verify torch_moe with stacked format is detected for EP sharding + detected = optimizer.shared_config.sharding_transform_container.ep_transforms + assert len(detected) == 1, f"Expected 1 EP transform, got {len(detected)}" + assert detected[0].target_node == moe_node.name diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py index da99e08bbe2..99fccfab304 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py @@ -1,7 +1,7 @@ import pytest import torch import torch.nn.functional as F -from _torch.helpers import reference_moe_torch +from _torch.helpers import reference_bmm_moe_torch, reference_moe_torch from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 @@ -62,6 +62,48 @@ def setup_moe_test(dtype, num_experts): ) +def setup_bmm_moe_test(dtype, num_experts): + """Setup for stacked MoE with topk=1 in TRT-LLM format.""" + SEQ_LEN = 8 + HIDDEN_SIZE = 64 + INTERMEDIATE_SIZE = 32 + NUM_EXPERTS = num_experts + TOP_K = 1 # 
Llama4 stacked pattern requires topk=1 + + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + x = torch.rand(SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1 + + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=torch.float32).cuda() + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + final_scales, selected_experts = torch.topk(routing_weights, TOP_K, dim=-1) + final_scales = final_scales / final_scales.sum(dim=-1, keepdim=True) + final_scales = final_scales.to(x.dtype) + + # TRT-LLM format: gate_up is (2*I, H), down is (H, I) + w3_w1_stacked_weight = torch.empty( + (NUM_EXPERTS, INTERMEDIATE_SIZE * 2, HIDDEN_SIZE), dtype=dtype + ).cuda() + w2_stacked_weight = torch.empty( + (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE), dtype=dtype + ).cuda() + + for expert_id in range(NUM_EXPERTS): + w31 = torch.rand(INTERMEDIATE_SIZE * 2, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1 + w2 = torch.rand(HIDDEN_SIZE, INTERMEDIATE_SIZE, dtype=dtype).cuda() * 0.1 + # TRT-LLM format: concat w3 and w1 along intermediate dim + w3_w1_stacked_weight.data[expert_id].copy_(w31) # (2*I, H) + w2_stacked_weight.data[expert_id].copy_(w2) # (H, I) + + return ( + x, + selected_experts, + final_scales, + w3_w1_stacked_weight, + w2_stacked_weight, + ) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_moe_op_run(dtype): num_experts = 3 @@ -109,6 +151,62 @@ def test_moe_op_run(dtype): torch.testing.assert_close(output_torch_moe, ref_output, rtol=1e-5, atol=1e-5) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_bmm_based_moe_op_run(dtype): + num_experts = 3 + ( + x, + selected_experts, + final_scales, + fused_w3_w1_stacked_weight, + fused_w2_weight, + ) = setup_bmm_moe_test(dtype, num_experts) + + with torch.inference_mode(): + x = final_scales * x + selected_experts = torch.ones_like(selected_experts) + # Use torch_moe with stacked tensor format (single-element lists) + output_torch_moe = torch.ops.auto_deploy.torch_moe( + x, + selected_experts, + final_scales, + [fused_w3_w1_stacked_weight], # Wrap in list for unified interface + [fused_w2_weight], # Wrap in list for unified interface + [], # Empty w3_weight list for stacked gated MLP + mlp_style="gated_mlp", + act_fn="silu", + apply_routing_on_input=True, + ) + output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused( + x, + selected_experts, + final_scales, + fused_w3_w1_stacked_weight, + fused_w2_weight, + ) + output_trt_fused_moe = torch.ops.auto_deploy.trtllm_moe_fused( + x, + selected_experts, + final_scales, + fused_w3_w1_stacked_weight, + fused_w2_weight, + ) + ref_output = reference_bmm_moe_torch( + x, + selected_experts, + final_scales, + fused_w3_w1_stacked_weight, + fused_w2_weight, + apply_routing_on_input=True, + ) + + torch.cuda.synchronize() + torch.testing.assert_close(output_trt_fused_moe, output_torch_fused_moe, rtol=5e-2, atol=5e-2) + torch.testing.assert_close(output_trt_fused_moe, ref_output, rtol=5e-2, atol=5e-2) + torch.testing.assert_close(output_torch_fused_moe, ref_output, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(output_torch_moe, ref_output, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support") def test_fp8_moe_op_run(dtype): diff --git a/tests/unittest/_torch/helpers.py b/tests/unittest/_torch/helpers.py index 163309cb1f4..ae7aa5b1c52 100644 --- a/tests/unittest/_torch/helpers.py +++ b/tests/unittest/_torch/helpers.py @@ -93,9 +93,12 
@@ def calc_woq_tolerence(x: torch.Tensor, weight_dtype: torch.dtype): return atol -def reference_moe_torch(x: torch.Tensor, selected_experts: torch.Tensor, - final_scales: torch.Tensor, num_experts: int, - weights: Dict[str, torch.Tensor]) -> torch.Tensor: +def reference_moe_torch(x: torch.Tensor, + selected_experts: torch.Tensor, + final_scales: torch.Tensor, + num_experts: int, + weights: Dict[str, torch.Tensor], + apply_routing_on_input: bool = False) -> torch.Tensor: # cast back to the input dtype results = torch.zeros_like(x) @@ -106,9 +109,63 @@ def reference_moe_torch(x: torch.Tensor, selected_experts: torch.Tensor, w2_weight = weights[f"{expert_id}.w2.weight"] w3_weight = weights[f"{expert_id}.w3.weight"] expert_inputs = x[batch_idx] + if apply_routing_on_input: + expert_inputs = expert_inputs * final_scales[batch_idx, nth_expert, + None] output = (F.silu(expert_inputs @ w1_weight.t()) * (expert_inputs @ w3_weight.t())) @ w2_weight.t() - results[batch_idx] += final_scales[batch_idx, nth_expert, None] * output + if not apply_routing_on_input: + output = output * final_scales[batch_idx, nth_expert, None] + results[batch_idx] += output + + return results.view_as(x) + + +def reference_bmm_moe_torch( + x: torch.Tensor, + selected_experts: torch.Tensor, + final_scales: torch.Tensor, + w3_w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, + apply_routing_on_input: bool = True) -> torch.Tensor: + """Reference for stacked MoE in TRT-LLM format. + + Args: + x: (seq_len, hidden_size) + selected_experts: (seq_len, topk) + final_scales: (seq_len, topk) + w3_w1_stacked_weight: (num_experts, 2*intermediate_size, hidden_size) - TRT-LLM format + w2_stacked_weight: (num_experts, hidden_size, intermediate_size) - TRT-LLM format + """ + num_experts = w3_w1_stacked_weight.shape[0] + intermediate_size = w3_w1_stacked_weight.shape[1] // 2 + results = torch.zeros_like(x) + + # Loop over experts (matches reference_moe_torch pattern) + for expert_id in range(num_experts): + batch_idx, nth_expert = torch.where(selected_experts == expert_id) + if len(batch_idx) == 0: + continue + + # Get weights for this expert (TRT-LLM format) + gate_up = w3_w1_stacked_weight[expert_id] # (2*I, H) + w3_weight = gate_up[:intermediate_size, :] # (I, H) + w1_weight = gate_up[intermediate_size:, :] # (I, H) + w2_weight = w2_stacked_weight[expert_id] # (H, I) + + expert_inputs = x[batch_idx] + if apply_routing_on_input: + expert_inputs = expert_inputs * final_scales[batch_idx, nth_expert, + None] + + # Gated MLP computation - TRT-LLM format uses .t() + output = (F.silu(expert_inputs @ w1_weight.t()) * + (expert_inputs @ w3_weight.t())) @ w2_weight.t() + + if not apply_routing_on_input: + output = output * final_scales[batch_idx, nth_expert, None] + + results[batch_idx] += output return results.view_as(x)