@@ -59,6 +59,30 @@ __forceinline__ __device__ float tanh_opt(float x)
#endif
}

template <typename T>
struct Relu2
{
static bool const kIsHeavy = false;

CUTLASS_HOST_DEVICE
T operator()(T threshold, T value) const
{
ReLu<T> relu_op;
multiplies<T> mul;
T val = relu_op(threshold, value);
return mul(val, val);
}

CUTLASS_HOST_DEVICE
T operator()(T value) const
{
ReLu<T> relu_op;
multiplies<T> mul;
T val = relu_op(value);
return mul(val, val);
}
};

} // namespace thread
} // namespace epilogue
} // namespace cutlass
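
For reference, this functor implements squared ReLU. A minimal PyTorch sketch of what the single-argument overload computes elementwise (illustrative only, not part of the PR):

import torch

def relu2_reference(x: torch.Tensor) -> torch.Tensor:
    # Squared ReLU: relu(x) * relu(x), matching the device-side Relu2 functor above.
    val = torch.relu(x)
    return val * val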
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
@@ -28,6 +28,7 @@ enum class ActivationType
Swiglu,
Geglu,
SwigluBias,
Relu2,
Identity,
InvalidType
};
@@ -954,6 +954,7 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(
case ActivationType::Identity: runGemm<cutlass_extensions::EpilogueOpDefault>(inputs, hopper_inputs); break;
case ActivationType::Swiglu: runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(inputs, hopper_inputs); break;
case ActivationType::Geglu: runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(inputs, hopper_inputs); break;
case ActivationType::Relu2: TLLM_THROW("Relu2 is not supported."); break;
case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
default: TLLM_THROW("Invalid activation type."); break;
}
@@ -2307,6 +2307,8 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
decltype(block_scaling_type)::value>, // Geglu
&doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
decltype(block_scaling_type)::value>, // SwigluBias
&doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::Relu2>,
decltype(block_scaling_type)::value>, // Relu2
&doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Identity>,
decltype(block_scaling_type)::value> // Identity
32 changes: 24 additions & 8 deletions cpp/tensorrt_llm/thop/moeOp.cpp
@@ -259,8 +259,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
torch::optional<torch::Tensor> const& out_tensor)
torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
{
Comment on lines +262 to 264

⚠️ Potential issue | 🔴 Critical

API change: ensure Python bindings/tests pass activation_type consistently.

New optional args added (activation_type, unpadded_hidden_size, num_valid_tokens, out_tensor). Please verify torch bindings and call sites are updated for ordering/defaults to avoid silent arg shifts.

Critical: Missing activation_type parameter in moe_op_cutlass.py call.

The C++ runMoe signature at lines 251-265 expects 26 parameters ending with (profile_ids, activation_type, unpadded_hidden_size, num_valid_tokens, out_tensor). However, the Python wrapper call at moe_op_cutlass.py:196-203 is missing activation_type, causing all subsequent arguments to shift left by one position:

Currently passes (incorrect): ..., min_latency_mode, self.gemm_tactics, unpadded_hidden_size, tuner_num_tokens, None

Should pass (correct): ..., min_latency_mode, self.gemm_tactics, activation_type, unpadded_hidden_size, tuner_num_tokens, None

The torch_custom_ops.py:241-248 call correctly includes all four new optional parameters. Update moe_op_cutlass.py line 202 to add activation_type before unpadded_hidden_size.

🤖 Prompt for AI Agents
In cpp/tensorrt_llm/thop/moeOp.cpp around lines 262-264 and moe_op_cutlass.py
around lines 196-203 (specifically line 202), the Python wrapper call is missing
the activation_type argument which shifts all subsequent parameters left; update
the call at moe_op_cutlass.py:202 to insert activation_type before
unpadded_hidden_size so the argument list matches the C++ signature (...,
min_latency_mode, self.gemm_tactics, activation_type, unpadded_hidden_size,
tuner_num_tokens, None), and ensure activation_type is obtained from the correct
source (e.g., self.activation_type or passed into the function) and has the
expected optional/torch-compatible type.
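
A self-contained Python sketch of the argument-shift failure mode described above (the function and values here are hypothetical stand-ins, not the real binding):

def run_moe(profile_ids, activation_type=None, unpadded_hidden_size=None,
            num_valid_tokens=None, out_tensor=None):
    # Hypothetical stand-in for the tail of the real 26-parameter runMoe binding.
    return dict(profile_ids=profile_ids, activation_type=activation_type,
                unpadded_hidden_size=unpadded_hidden_size,
                num_valid_tokens=num_valid_tokens, out_tensor=out_tensor)

# Omitting activation_type shifts every later argument left by one slot:
print(run_moe([1, 2], 4096, 128, None))     # 4096 silently lands in activation_type
# Passing activation_type explicitly keeps the arguments aligned:
print(run_moe([1, 2], 3, 4096, 128, None))  # 3 is a placeholder enum value, not the real Relu2 value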

std::lock_guard<std::mutex> lock(mMutex);
// Free the profile workspace to save memory
@@ -328,6 +328,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
"fc1_expert_weights and fc2_expert_weights must have the same number of experts.");

ActivationType base_activation_type = activation_type.has_value()
? static_cast<ActivationType>(activation_type.value())
: ActivationType::Swiglu;
if (mUseINT8WoqPerChannel)
{
// Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
@@ -337,8 +340,19 @@
}
else
{
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be fc2_expert_weights inter size.");
// TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
// "fc1_expert_weights inter size must be fc2_expert_weights inter size.");

if (isGatedActivation(base_activation_type))
{
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
}
else
{
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier,
"fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.");
}
}

int experts_per_token = token_selected_experts.sizes()[1];
@@ -375,7 +389,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
int const num_experts_on_rank = fc2_expert_weights.sizes()[0];
auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);
ActivationType base_activation_type = ActivationType::Swiglu;

if (swiglu_alpha.has_value())
{
CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);
@@ -474,8 +488,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
torch::optional<torch::Tensor> const& out_tensor)
torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
{
std::lock_guard<std::mutex> lock(mMutex);

@@ -541,7 +555,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
auto parallelism_config
= kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank);
ActivationType base_activation_type = ActivationType::Swiglu;
ActivationType base_activation_type = activation_type.has_value()
? static_cast<ActivationType>(activation_type.value())
: ActivationType::Swiglu;
if (swiglu_alpha.has_value())
Comment on lines +558 to 561
⚠️ Potential issue | 🔴 Critical

Same enum‑validation needed here.

Mirror the activation_type validation added in runMoe to prevent invalid values in min‑latency path.

Apply:

-        ActivationType base_activation_type = activation_type.has_value()
-            ? static_cast<ActivationType>(activation_type.value())
-            : ActivationType::Swiglu;
+        ActivationType base_activation_type = ActivationType::Swiglu;
+        if (activation_type.has_value()) {
+            auto act = static_cast<ActivationType>(activation_type.value());
+            switch (act) {
+            case ActivationType::Gelu:
+            case ActivationType::Relu:
+            case ActivationType::SiLu:
+            case ActivationType::Swiglu:
+            case ActivationType::Geglu:
+            case ActivationType::SwigluBias:
+            case ActivationType::Relu2:
+            case ActivationType::Identity:
+                base_activation_type = act; break;
+            default:
+                TORCH_CHECK(false, "Invalid activation_type value: ", activation_type.value());
+            }
+        }

{
CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);
191 changes: 180 additions & 11 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
@@ -1,5 +1,7 @@
import torch

from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType


@torch.library.custom_op("auto_deploy::trtllm_moe_fused", mutates_args=())
def trtllm_fused_moe(
@@ -8,6 +10,8 @@ def trtllm_fused_moe(
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
x_shape = x.shape
x = x.view(-1, x_shape[-1])
@@ -16,21 +20,38 @@
selected_experts = selected_experts.to(torch.int32)
quant_scales = []

# Determine activation type
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()

activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
# Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
if act_fn == "silu":
# activation_type = ActivationType.Silu
activation_type = ActivationType.Swiglu # need to fix this in trtllm
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
# For non-gated MLP with ReLU^2
if act_fn == "relu2":
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

return torch.ops.trtllm.fused_moe(
x,
selected_experts,
routing_weights,
w3_w1_stacked_weight,
None, # w3_w1_stacked_bias
w2_stacked_weight,
None, # w2_stacked_bias
x.dtype,
quant_scales,
tp_size=1,
tp_rank=0,
ep_size=1,
ep_rank=0,
enable_alltoall=False,
fc1_expert_weights=w3_w1_stacked_weight,
fc1_expert_biases=None,
fc2_expert_weights=w2_stacked_weight,
fc2_expert_biases=None,
output_dtype=x.dtype,
quant_scales=quant_scales,
activation_type=activation_type,
)[0].view(x_shape)


@@ -41,5 +62,153 @@ def trtllm_fused_moe(
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)


# Todo: refactor this repeating code block
Collaborator
We can use some TRTLLM kernels to do this quantization, e.g. torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor

def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear)."""
FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
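
A small usage sketch of the helper above with a per-tensor (scalar) scale; it assumes a PyTorch build with float8_e4m3fn support and that _quantize_fp8 is in scope:

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = torch.tensor(0.05, dtype=torch.float32)   # illustrative per-tensor scale
x_fp8 = _quantize_fp8(x, scale)                   # divide by scale, clamp to the FP8 range, cast
assert x_fp8.dtype == torch.float8_e4m3fn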


@torch.library.custom_op("auto_deploy::trtllm_quant_fp8moe_fused", mutates_args=())
Collaborator
nit: The op name looks a bit too long...

def trtllm_quant_fp8moe_fused(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights
w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights
w3_weight: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp
w1_input_scale: torch.Tensor, # [E] stacked input scales
w2_input_scale: torch.Tensor, # [E] stacked input scales
w3_input_scale: torch.Tensor, # [E] or unused
w1_weight_scale: torch.Tensor, # [E] stacked weight scales
w2_weight_scale: torch.Tensor, # [E] stacked weight scales
w3_weight_scale: torch.Tensor, # [E] or unused
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
"""
TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP.
Parameters:
x: BF16/FP16 input tensor of shape (B, H) or (B, S, H)
selected_experts: Expert indices (B*S, TOP_K)
routing_weights: Routing weights (B*S, TOP_K)
w1_weight: FP8 w1 weights [E, I, H]
w2_weight: FP8 w2 weights [E, H, I]
w3_weight: FP8 w3 weights [E, I, H] (for gated_mlp)
w1_input_scale: Input scales for w1 [E]
w2_input_scale: Input scales for w2 [E]
w3_input_scale: Input scales for w3 [E]
w1_weight_scale: Weight scales for w1 [E]
w2_weight_scale: Weight scales for w2 [E]
w3_weight_scale: Weight scales for w3 [E]
mlp_style: "gated_mlp" or "mlp"
act_fn: "silu" for gated_mlp, "relu2" for mlp
"""

# if mlp_style != "gated_mlp":
# raise NotImplementedError("FlashInfer FP8 MoE currently only supports gated_mlp")

# Store original shape and flatten to 2D
x_shape = x.shape
x2d = x.view(-1, x_shape[-1])
x_q_fp8 = _quantize_fp8(x2d, w1_input_scale)

# Scales are stored in float32
w1_weight_scale = w1_weight_scale.to(torch.float32)
w2_weight_scale = w2_weight_scale.to(torch.float32)
w1_input_scale = w1_input_scale.to(torch.float32)
w2_input_scale = w2_input_scale.to(torch.float32)

# Prepare quant_scales for TensorRT-LLM FP8 format:
# [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale]
# For gated MLP:
# - gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
# - gemm2_act_quant_scale: 1 / w2_input_scale
# - gemm2_dequant_scale: w2_weight_scale * w2_input_scale
# - gemm1_input_dequant_scale: w1_input_scale

# Compute combined scales
gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze()
gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32) # [E]
gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze()
gemm1_input_dequant = w1_input_scale.contiguous()
Comment on lines +121 to +141
⚠️ Potential issue | 🔴 Critical

Avoid broadcasting mismatch when quantizing inputs

x2d is [num_tokens, hidden_size], while w1_input_scale is documented and passed as [num_experts], so _quantize_fp8(x2d, w1_input_scale) throws “The size of tensor a (…) must match the size of tensor b (…)” as soon as you have more than one expert. The CUTLASS kernel expects the activations to be quantized already, but that scale must be per token (or scalar), not per expert. Please restructure the quantization so you either (a) accept pre-quantized inputs and skip this step, or (b) postpone activation quantization until after tokens are dispatched per expert so the shapes align. Right now the path hard-crashes for real FP8 configs.(docs.flashinfer.ai)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py around
lines 121-141, the call _quantize_fp8(x2d, w1_input_scale) fails when
w1_input_scale is per-expert ([num_experts]) while x2d is [num_tokens,
hidden_size]; remove or defer this quantization and instead perform activation
quantization after tokens are dispatched per-expert (or accept pre-quantized
inputs). Concretely: do not call _quantize_fp8 with a per-expert scale here;
either (A) trust callers to provide x2d already FP8 and skip quantization, or
(B) postpone quantizing activations until after the dispatcher splits tokens to
experts and apply _quantize_fp8 with the matching per-expert scale for each
expert’s token slice (or collapse scale to a scalar/per-token shape if
intended). Also ensure the prepared quant_scales (gemm1_dequant,
gemm2_act_quant, gemm2_dequant, gemm1_input_dequant) are computed from the same
scale convention you choose (per-expert vs scalar) and have shapes that match
the kernels.
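
A minimal runnable sketch of the shape problem and one per-token alternative (the shapes and the per-token scaling rule below are illustrative assumptions, not the kernel's required convention):

import torch

num_tokens, hidden_size, num_experts = 6, 8, 4
x2d = torch.randn(num_tokens, hidden_size)
per_expert_scale = torch.rand(num_experts)               # shape [E]
# x2d / per_expert_scale                                 # RuntimeError: [6, 8] vs [4] cannot broadcast

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
per_token_scale = x2d.abs().amax(dim=-1, keepdim=True) / FP8_MAX   # shape [T, 1] broadcasts cleanly
x_q = (x2d / per_token_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)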


assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D"
assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D"
quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant]

# Ensure contiguous tensors
selected_experts = selected_experts.contiguous()
routing_weights = routing_weights.contiguous()

# Todo: refactor this repeating code block

# Determine activation type
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()

activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
# Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
# For gated MLP, concatenate w1 and w3
# TensorRT-LLM expects [w3, w1] concatenated
w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous() # [E, 2*I, H]
fc1_expert_weights = w3_w1_stacked
if act_fn == "silu":
# activation_type = ActivationType.Silu
activation_type = ActivationType.Swiglu # need to fix this in trtllm
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
# For non-gated MLP with ReLU^2
fc1_expert_weights = w1_weight.contiguous()
if act_fn == "relu2":
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

# Note! Outputting Float8_e4m3fn directly is not currently supported
output = torch.ops.trtllm.fused_moe(
x_q_fp8,
selected_experts,
routing_weights,
fc1_expert_weights=fc1_expert_weights,
fc1_expert_biases=None,
fc2_expert_weights=w2_weight.contiguous(),
fc2_expert_biases=None,
output_dtype=x.dtype,
quant_scales=quant_scales,
activation_type=activation_type,
)

# Restore original shape
return output[0].view(x_shape)


@trtllm_quant_fp8moe_fused.register_fake
def trtllm_quant_fp8moe_fused_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_weight: torch.Tensor,
w2_weight: torch.Tensor,
w3_weight: torch.Tensor,
w1_input_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
w3_input_scale: torch.Tensor,
w1_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)