@@ -23,6 +23,7 @@
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout

from tensorrt_llm._torch.auto_deploy.enums import ActivationFunction
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import TritonEPRouter

except Exception as _e:
@@ -122,7 +123,9 @@ def _run_mxfp4_mlp_core(
)

act = FusedActivation(
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (float(alpha), float(limit)), 2
FnSpecs(ActivationFunction.SWIGLU, swiglu_fn, ("alpha", "limit")),
(float(alpha), float(limit)),
2,
)

# gate_up (with SWiGLU fused)
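
For orientation, below is a minimal PyTorch sketch of a clamped SwiGLU with alpha/limit of the kind this mxfp4 path applies. The actual swiglu_fn is provided by triton_kernels and is not shown in this diff, so the gate/up split and clamping details here are assumptions rather than the kernel's exact semantics.

import torch


def swiglu_clamped_reference(gate_up: torch.Tensor, alpha: float, limit: float) -> torch.Tensor:
    # Assumption: the projection output holds a gate half and a linear ("up") half
    # along the last dimension; the fused kernel may interleave them differently.
    gate, up = gate_up.chunk(2, dim=-1)
    gate = gate.clamp(max=limit)               # cap the gate pre-activation
    up = up.clamp(min=-limit, max=limit)       # cap the linear branch
    glu = gate * torch.sigmoid(alpha * gate)   # SiLU with slope alpha
    return (up + 1.0) * glu
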
556 changes: 400 additions & 156 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py

Large diffs are not rendered by default.

@@ -14,6 +14,13 @@
import triton
import triton.language as tl

from tensorrt_llm._torch.auto_deploy.enums import (
ActivationFunction,
MLPStyle,
act_fn_from_str,
mlp_style_from_str,
)

from ...utils.logger import ad_logger


@@ -605,11 +612,12 @@ def triton_fused_moe(
act_fn: str = "relu2",
) -> torch.Tensor:
"""Triton unquantized MoE with 2-layer MLP and ReLU^2 activation."""
# Convert string parameters to enums
mlp_style_enum = mlp_style_from_str(mlp_style)
act_fn_enum = act_fn_from_str(act_fn)

mlp_style = mlp_style.lower()
act_fn = act_fn.lower()
assert mlp_style == "mlp", "Triton backend only supports mlp style."
assert act_fn == "relu2", "Triton backend only supports relu2 activation."
assert mlp_style_enum == MLPStyle.MLP, "Triton backend only supports mlp style."
assert act_fn_enum == ActivationFunction.RELU2, "Triton backend only supports relu2 activation."

x_shape = x.shape
x2d = x.view(-1, x_shape[-1])
@@ -636,6 +644,8 @@
routing_weights: torch.Tensor,
w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "mlp",
act_fn: str = "relu2",
) -> torch.Tensor:
return torch.empty_like(x)

@@ -661,13 +671,20 @@ def triton_quant_fp8_moe(
w1_weight_scale: torch.Tensor, # [E] stacked weight scales
w2_weight_scale: torch.Tensor, # [E] stacked weight scales
w3_weight_scale: torch.Tensor, # unused
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
mlp_style: str = "mlp",
act_fn: str = "relu2",
) -> torch.Tensor:
"""Triton FP8 W8A8 MoE with 2-layer MLP and ReLU^2 activation."""
if mlp_style != "mlp":
# Convert string parameters to enums
mlp_style_enum = mlp_style_from_str(mlp_style)
act_fn_enum = act_fn_from_str(act_fn)

if mlp_style_enum != MLPStyle.MLP:
raise NotImplementedError("triton_quant_fp8_moe currently supports mlp_style=='mlp' only")

if act_fn_enum != ActivationFunction.RELU2:
raise NotImplementedError("triton_quant_fp8_moe currently supports act_fn=='relu2' only")

x_shape = x.shape
x2d = x.view(-1, x_shape[-1])

@@ -760,7 +777,7 @@ def triton_quant_fp8_moe(
w1_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
mlp_style: str = "mlp",
act_fn: str = "relu2",
) -> torch.Tensor:
return torch.empty_like(x)
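
For reference, the constraint asserted above (MLPStyle.MLP with ActivationFunction.RELU2) corresponds to a 2-layer per-expert MLP with a squared-ReLU activation. A dense PyTorch sketch of that computation follows; tensor names are illustrative and this makes no claim to match the Triton kernel's memory layout or performance.

import torch


def relu2_mlp_moe_reference(
    x: torch.Tensor,                 # [T, H] flattened tokens
    selected_experts: torch.Tensor,  # [T, K] expert indices chosen by the router
    routing_weights: torch.Tensor,   # [T, K] routing weights for those experts
    w1_stacked: torch.Tensor,        # [E, I, H] per-expert up projection
    w2_stacked: torch.Tensor,        # [E, H, I] per-expert down projection
) -> torch.Tensor:
    out = torch.zeros_like(x)
    for e in range(w1_stacked.shape[0]):
        token_idx, k_idx = (selected_experts == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w1_stacked[e].t()   # up projection
        h = torch.relu(h) ** 2                 # ReLU^2 activation
        y = h @ w2_stacked[e].t()              # down projection
        contribution = y * routing_weights[token_idx, k_idx, None]
        out.index_add_(0, token_idx, contribution.to(out.dtype))
    return out
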
105 changes: 55 additions & 50 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
@@ -16,6 +16,12 @@

import torch

from tensorrt_llm._torch.auto_deploy.enums import (
ActivationFunction,
MLPStyle,
act_fn_from_str,
mlp_style_from_str,
)
from tensorrt_llm._torch.utils import ActivationType


@@ -36,25 +42,26 @@ def trtllm_moe_fused(
selected_experts = selected_experts.to(torch.int32)
quant_scales = []

# Determine activation type
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()
# Convert string parameters to enums
mlp_style_enum = mlp_style_from_str(mlp_style)
act_fn_enum = act_fn_from_str(act_fn)

# Determine activation type
activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
if mlp_style_enum == MLPStyle.GATED_MLP:
# Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
if act_fn == "silu":
if act_fn_enum == ActivationFunction.SILU:
activation_type = ActivationType.Swiglu
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
raise ValueError(f"Unsupported activation '{act_fn_enum.value}' for gated_mlp.")
elif mlp_style_enum == MLPStyle.MLP:
# For non-gated MLP with ReLU^2
if act_fn == "relu2":
if act_fn_enum == ActivationFunction.RELU2:
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
raise ValueError(f"Unsupported activation '{act_fn_enum.value}' for mlp.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
raise ValueError(f"Unknown mlp_style '{mlp_style_enum.value}'.")

return torch.ops.trtllm.fused_moe(
x,
@@ -93,22 +100,19 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


def _validate_mlp_style_and_act_fn(mlp_style: str, act_fn: str) -> None:
def _validate_mlp_style_and_act_fn(mlp_style: MLPStyle, act_fn: ActivationFunction) -> None:
"""Validate that mlp_style and act_fn combination is supported for TRT-LLM."""
supported_combinations = {
"gated_mlp": ["silu"],
"mlp": ["relu2"],
MLPStyle.GATED_MLP: [ActivationFunction.SILU],
MLPStyle.MLP: [ActivationFunction.RELU2],
}
supported_act_fns = [
act_fn for act_fn_list in supported_combinations.values() for act_fn in act_fn_list
]
assert mlp_style in supported_combinations.keys(), (
f"Unknown mlp_style '{mlp_style}'. Use {supported_combinations.keys()}."
)
assert act_fn in supported_act_fns, f"Unknown act_fn '{act_fn}'. Use {supported_act_fns}."
assert act_fn in supported_combinations[mlp_style], (
f"Unsupported combination: mlp_style='{mlp_style}', act_fn='{act_fn}'. "
f"Supported combinations: {supported_combinations}"
)

if act_fn not in supported_combinations[mlp_style]:
supported = [a.value for a in supported_combinations[mlp_style]]
raise ValueError(
f"Unsupported combination: mlp_style='{mlp_style.value}', act_fn='{act_fn.value}'. "
f"Supported activations for {mlp_style.value}: {supported}"
)


@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused", mutates_args=())
@@ -149,17 +153,19 @@ def trtllm_quant_fp8_moe_fused(
gemm1_dequant: Precomputed gemm1 dequant scale [E]
gemm2_act_quant: Precomputed gemm2 act quant scale [1]
gemm2_dequant: Precomputed gemm2 dequant scale [E]
mlp_style: "gated_mlp" or "mlp"
act_fn: "silu" for gated_mlp, "relu2" for mlp
mlp_style: MLPStyle.GATED_MLP or MLPStyle.MLP
act_fn: ActivationFunction.SILU for gated_mlp, ActivationFunction.RELU2 for mlp

Non-Gated MLP:
activation_fn(expert_inputs @ w1_expert.t())@ w2_expert.t()

Gated MLP:
activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t()
"""

_validate_mlp_style_and_act_fn(mlp_style, act_fn)
# Convert string parameters to enums and validate
mlp_style_enum = mlp_style_from_str(mlp_style)
act_fn_enum = act_fn_from_str(act_fn)
_validate_mlp_style_and_act_fn(mlp_style_enum, act_fn_enum)

# Store original shape and flatten to 2D
x_shape = x.shape
@@ -187,31 +193,26 @@
selected_experts = selected_experts.int().contiguous()
routing_weights = routing_weights.contiguous()

# Todo: refactor this repeating code block

# Determine activation type
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()

activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
if mlp_style_enum == MLPStyle.GATED_MLP:
# Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
# For gated MLP, concatenate w1 and w3 as [w3, w1]
w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous() # [E, 2*I, H]
fc1_expert_weights = w3_w1_stacked
if act_fn == "silu":
if act_fn_enum == ActivationFunction.SILU:
activation_type = ActivationType.Swiglu
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
raise ValueError(f"Unsupported activation '{act_fn_enum.value}' for gated_mlp.")
elif mlp_style_enum == MLPStyle.MLP:
# For non-gated MLP with ReLU^2
fc1_expert_weights = w1_weight.contiguous()
if act_fn == "relu2":
if act_fn_enum == ActivationFunction.RELU2:
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
raise ValueError(f"Unsupported activation '{act_fn_enum.value}' for mlp.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
raise ValueError(f"Unknown mlp_style '{mlp_style_enum.value}'.")

# Note! Outputting Float8_e4m3fn directly is not currently supported
output = torch.ops.trtllm.fused_moe(
@@ -251,7 +252,9 @@ def trtllm_quant_fp8_moe_fused_fake(
mlp_style: str,
act_fn: str,
) -> torch.Tensor:
_validate_mlp_style_and_act_fn(mlp_style, act_fn)
mlp_style_enum = mlp_style_from_str(mlp_style)
act_fn_enum = act_fn_from_str(act_fn)
_validate_mlp_style_and_act_fn(mlp_style_enum, act_fn_enum)
return torch.empty_like(x)


@@ -285,22 +288,24 @@ def trtllm_quant_nvfp4_moe_fused(

"""
NVFP4_BLOCK_SIZE = 16
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()

# Convert string parameters to enums
mlp_style_enum = mlp_style_from_str(mlp_style)
act_fn_enum = act_fn_from_str(act_fn)

activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
if act_fn == "silu":
if mlp_style_enum == MLPStyle.GATED_MLP:
if act_fn_enum == ActivationFunction.SILU:
activation_type = ActivationType.Swiglu
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
if act_fn == "relu2":
raise ValueError(f"Unsupported activation '{act_fn_enum.value}' for gated_mlp.")
elif mlp_style_enum == MLPStyle.MLP:
if act_fn_enum == ActivationFunction.RELU2:
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
raise ValueError(f"Unsupported activation '{act_fn_enum.value}' for mlp.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
raise ValueError(f"Unknown mlp_style '{mlp_style_enum.value}'.")

# quant_scales is described by this code:
# https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
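
The docstring of trtllm_quant_fp8_moe_fused spells out the gated and non-gated formulas; a compact unquantized per-expert reference is sketched below. This is illustrative only and not the fused kernel.

import torch
import torch.nn.functional as F


def gated_mlp_expert(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
    # mlp_style="gated_mlp", act_fn="silu":
    #   y = (silu(x @ w1.T) * (x @ w3.T)) @ w2.T
    return (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()


def mlp_expert(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # mlp_style="mlp", act_fn="relu2":
    #   y = (relu(x @ w1.T) ** 2) @ w2.T
    return (torch.relu(x @ w1.t()) ** 2) @ w2.t()
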
95 changes: 95 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/enums.py
@@ -0,0 +1,95 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Shared enums for AutoDeploy.
"""

from enum import Enum


class MLPStyle(Enum):
"""MLP style for MoE layers."""

GATED_MLP = "gated_mlp" # Mixtral/DeepSeek/Llama4-style: y = W2(act(W1 x) * (W3 x))
MLP = "mlp" # NemotronH-style 2-layer: y = W_down(act(W_up x))


class ActivationFunction(Enum):
"""Activation functions for MoE layers."""

SILU = "silu" # SiLU activation
RELU2 = "relu2" # ReLU then square
SWIGLU = "swiglu" # SwiGLU activation


class WeightsFormat(Enum):
"""Weight tensor organization for MoE layers."""

PER_EXPERT = "per_expert" # Separate weight tensors per expert in lists
STACKED = "stacked" # All expert weights stacked in single tensors


class WeightsFusion(Enum):
"""Weight tensor ordering and storage for gated MLP layers."""

GATE_UP_DOWN = "w1_w2_w3_separate" # w1, w2, w3 stored separately (matches parameter order)
GATEUP_DOWN = (
"w1w3_w2" # w1 and w3 concatenated as [w1, w3], w2 separate (Llama4 native format)
)
UPGATE_DOWN = (
"w3w1_w2" # w3 and w1 concatenated as [w3, w1], w2 separate
# (TRT-LLM format, Llama4 weights swapped during load)
)


def mlp_style_from_str(s: str) -> MLPStyle:
"""Convert string to MLPStyle enum."""
s = s.lower()
for style in MLPStyle:
if style.value == s:
return style
valid_values = [style.value for style in MLPStyle]
raise ValueError(f"Unknown mlp_style '{s}'. Valid values: {valid_values}")


def act_fn_from_str(s: str) -> ActivationFunction:
"""Convert string to ActivationFunction enum."""
s = s.lower()
for act in ActivationFunction:
if act.value == s:
return act
valid_values = [act.value for act in ActivationFunction]
raise ValueError(f"Unknown act_fn '{s}'. Valid values: {valid_values}")


def weights_format_from_str(s: str) -> WeightsFormat:
"""Convert string to WeightsFormat enum."""
s = s.lower()
for fmt in WeightsFormat:
if fmt.value == s:
return fmt
valid_values = [fmt.value for fmt in WeightsFormat]
raise ValueError(f"Unknown weights_format '{s}'. Valid values: {valid_values}")


def weights_fusion_from_str(s: str) -> WeightsFusion:
"""Convert string to WeightsFusion enum."""
s = s.lower()
for fusion in WeightsFusion:
if fusion.value == s:
return fusion
valid_values = [fusion.value for fusion in WeightsFusion]
raise ValueError(f"Unknown weights_fusion '{s}'. Valid values: {valid_values}")
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/auto_deploy/models/patches/deepseek.py
@@ -135,9 +135,9 @@ def deepseek_v3_moe(self, hidden_states):
hidden_states,
selected_experts,
routing_weights,
w1_weight=[expert.gate_proj.weight for expert in self.experts],
w2_weight=[expert.down_proj.weight for expert in self.experts],
w3_weight=[expert.up_proj.weight for expert in self.experts],
weights_1=[expert.gate_proj.weight for expert in self.experts],
weights_2=[expert.down_proj.weight for expert in self.experts],
weights_3=[expert.up_proj.weight for expert in self.experts],
)

if self.config.n_shared_experts is not None: