DO NOT REVIEW YET #8954
@@ -28,6 +28,7 @@ enum class ActivationType
     Swiglu,
     Geglu,
     SwigluBias,
+    Relu2,
     Identity,
     InvalidType
 };
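For context, Relu2 here is the squared-ReLU activation wired up for the non-gated MLP expert path later in this diff (act_fn="relu2", described as ReLU^2). A reference formulation, offered as a sketch rather than the kernel implementation, is:

import torch

def relu2(x: torch.Tensor) -> torch.Tensor:
    # Squared ReLU applied to the FC1 output of a non-gated MLP expert: relu(x) ** 2.
    return torch.relu(x).square()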

@@ -259,8 +259,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
         int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
         bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
-        torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
-        torch::optional<torch::Tensor> const& out_tensor)
+        torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
+        torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
     {
         std::lock_guard<std::mutex> lock(mMutex);
         // Free the profile workspace to save memory

@@ -328,6 +328,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
             "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");

+        ActivationType base_activation_type = activation_type.has_value()
+            ? static_cast<ActivationType>(activation_type.value())
+            : ActivationType::Swiglu;
         if (mUseINT8WoqPerChannel)
         {
             // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:

@@ -337,8 +340,19 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
         else
         {
-            TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
-                "fc1_expert_weights inter size must be fc2_expert_weights inter size.");
+            // TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
+            //     "fc1_expert_weights inter size must be fc2_expert_weights inter size.");
+
+            if (isGatedActivation(base_activation_type))
+            {
+                TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
+                    "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+            }
+            else
+            {
+                TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier,
+                    "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.");
+            }
         }

         int experts_per_token = token_selected_experts.sizes()[1];

@@ -375,7 +389,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         int const num_experts_on_rank = fc2_expert_weights.sizes()[0];
         auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
         auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);
-        ActivationType base_activation_type = ActivationType::Swiglu;

         if (swiglu_alpha.has_value())
         {
             CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);

@@ -474,8 +488,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
         int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
         bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
-        torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
-        torch::optional<torch::Tensor> const& out_tensor)
+        torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
+        torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
     {
         std::lock_guard<std::mutex> lock(mMutex);

@@ -541,7 +555,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
         auto parallelism_config
             = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank);
-        ActivationType base_activation_type = ActivationType::Swiglu;
+        ActivationType base_activation_type = activation_type.has_value()
+            ? static_cast<ActivationType>(activation_type.value())
+            : ActivationType::Swiglu;
         if (swiglu_alpha.has_value())
         {
             CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);

Contributor comment on lines +558 to 561: Same enum-validation needed here. Mirror the activation_type validation added in runMoe to prevent invalid values in min-latency path. Apply:

-    ActivationType base_activation_type = activation_type.has_value()
-        ? static_cast<ActivationType>(activation_type.value())
-        : ActivationType::Swiglu;
+    ActivationType base_activation_type = ActivationType::Swiglu;
+    if (activation_type.has_value()) {
+        auto act = static_cast<ActivationType>(activation_type.value());
+        switch (act) {
+        case ActivationType::Gelu:
+        case ActivationType::Relu:
+        case ActivationType::SiLu:
+        case ActivationType::Swiglu:
+        case ActivationType::Geglu:
+        case ActivationType::SwigluBias:
+        case ActivationType::Relu2:
+        case ActivationType::Identity:
+            base_activation_type = act; break;
+        default:
+            TORCH_CHECK(false, "Invalid activation_type value: ", activation_type.value());
+        }
+    }

@@ -1,5 +1,7 @@
 import torch

+from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType
+

 @torch.library.custom_op("auto_deploy::trtllm_moe_fused", mutates_args=())
 def trtllm_fused_moe(

@@ -8,6 +10,8 @@ def trtllm_fused_moe(
     routing_weights: torch.Tensor,
     w3_w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
 ) -> torch.Tensor:
     x_shape = x.shape
     x = x.view(-1, x_shape[-1])

@@ -16,21 +20,38 @@ def trtllm_fused_moe(
     selected_experts = selected_experts.to(torch.int32)
     quant_scales = []

+    # Determine activation type
+    mlp_style = mlp_style.lower()
+    act_fn = act_fn.lower()
+
+    activation_type = ActivationType.Swiglu
+    if mlp_style == "gated_mlp":
+        # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
+        if act_fn == "silu":
+            # activation_type = ActivationType.Silu
+            activation_type = ActivationType.Swiglu  # need to fix this in trtllm
+        else:
+            raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
+    elif mlp_style == "mlp":
+        # For non-gated MLP with ReLU^2
+        if act_fn == "relu2":
+            activation_type = ActivationType.Relu2
+        else:
+            raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
+    else:
+        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+
     return torch.ops.trtllm.fused_moe(
         x,
         selected_experts,
         routing_weights,
-        w3_w1_stacked_weight,
-        None,  # w3_w1_stacked_bias
-        w2_stacked_weight,
-        None,  # w2_stacked_bias
-        x.dtype,
-        quant_scales,
-        tp_size=1,
-        tp_rank=0,
-        ep_size=1,
-        ep_rank=0,
-        enable_alltoall=False,
+        fc1_expert_weights=w3_w1_stacked_weight,
+        fc1_expert_biases=None,
+        fc2_expert_weights=w2_stacked_weight,
+        fc2_expert_biases=None,
+        output_dtype=x.dtype,
+        quant_scales=quant_scales,
+        activation_type=activation_type,
     )[0].view(x_shape)
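For context, a usage sketch of the op registered above (not code from the PR): the argument order follows the visible signature fragments (x, selected_experts, routing_weights, stacked weights, then mlp_style/act_fn), and all shapes, dtypes, and the CUDA device are illustrative assumptions.

import torch

# Hypothetical sizes: E=8 experts, H=64 hidden, I=128 intermediate, top-2 routing, 4 tokens.
x = torch.randn(4, 64, dtype=torch.bfloat16, device="cuda")
selected_experts = torch.randint(0, 8, (4, 2), device="cuda")
routing_weights = torch.rand(4, 2, dtype=torch.float32, device="cuda")
w1_stacked = torch.randn(8, 128, 64, dtype=torch.bfloat16, device="cuda")  # [E, I, H]; no gating branch
w2_stacked = torch.randn(8, 64, 128, dtype=torch.bfloat16, device="cuda")  # [E, H, I]

# Non-gated MLP experts using the new squared-ReLU (Relu2) path.
out = torch.ops.auto_deploy.trtllm_moe_fused(
    x, selected_experts, routing_weights, w1_stacked, w2_stacked,
    mlp_style="mlp", act_fn="relu2",
)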

@@ -41,5 +62,153 @@ def trtllm_fused_moe(
     routing_weights: torch.Tensor,
     w3_w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
 ) -> torch.Tensor:
     return torch.empty_like(x)
+
+
+# Todo: refactor this repeating code block

Collaborator comment: We can use some TRTLLM kernels to do this quantization, e.g.

+def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+    """Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear)."""
+    FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
+    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
+    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
+
+
+@torch.library.custom_op("auto_deploy::trtllm_quant_fp8moe_fused", mutates_args=())

Collaborator comment: nit: The op name looks a bit too long...

+def trtllm_quant_fp8moe_fused(
+    x: torch.Tensor,
+    selected_experts: torch.Tensor,
+    routing_weights: torch.Tensor,
+    w1_weight: torch.Tensor,  # [E, I, H] stacked FP8 weights
+    w2_weight: torch.Tensor,  # [E, H, I] stacked FP8 weights
+    w3_weight: torch.Tensor,  # [E, I, H] for gated_mlp, unused for mlp
+    w1_input_scale: torch.Tensor,  # [E] stacked input scales
+    w2_input_scale: torch.Tensor,  # [E] stacked input scales
+    w3_input_scale: torch.Tensor,  # [E] or unused
+    w1_weight_scale: torch.Tensor,  # [E] stacked weight scales
+    w2_weight_scale: torch.Tensor,  # [E] stacked weight scales
+    w3_weight_scale: torch.Tensor,  # [E] or unused
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
+) -> torch.Tensor:
+    """
+    TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP.
+    Parameters:
+        x: BF16/FP16 input tensor of shape (B, H) or (B, S, H)
+        selected_experts: Expert indices (B*S, TOP_K)
+        routing_weights: Routing weights (B*S, TOP_K)
+        w1_weight: FP8 w1 weights [E, I, H]
+        w2_weight: FP8 w2 weights [E, H, I]
+        w3_weight: FP8 w3 weights [E, I, H] (for gated_mlp)
+        w1_input_scale: Input scales for w1 [E]
+        w2_input_scale: Input scales for w2 [E]
+        w3_input_scale: Input scales for w3 [E]
+        w1_weight_scale: Weight scales for w1 [E]
+        w2_weight_scale: Weight scales for w2 [E]
+        w3_weight_scale: Weight scales for w3 [E]
+        mlp_style: "gated_mlp" or "mlp"
+        act_fn: "silu" for gated_mlp, "relu2" for mlp
+    """
+    # if mlp_style != "gated_mlp":
+    #     raise NotImplementedError("FlashInfer FP8 MoE currently only supports gated_mlp")
+
+    # Store original shape and flatten to 2D
+    x_shape = x.shape
+    x2d = x.view(-1, x_shape[-1])
+    x_q_fp8 = _quantize_fp8(x2d, w1_input_scale)
+
+    # Scales are stored in float32
+    w1_weight_scale = w1_weight_scale.to(torch.float32)
+    w2_weight_scale = w2_weight_scale.to(torch.float32)
+    w1_input_scale = w1_input_scale.to(torch.float32)
+    w2_input_scale = w2_input_scale.to(torch.float32)
+
+    # Prepare quant_scales for TensorRT-LLM FP8 format:
+    # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale]
+    # For gated MLP:
+    #   - gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
+    #   - gemm2_act_quant_scale: 1 / w2_input_scale
+    #   - gemm2_dequant_scale: w2_weight_scale * w2_input_scale
+    #   - gemm1_input_dequant_scale: w1_input_scale
+
+    # Compute combined scales
+    gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze()
+    gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32)  # [E]
+    gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze()
+    gemm1_input_dequant = w1_input_scale.contiguous()

Contributor comment on lines +121 to +141: Avoid broadcasting mismatch when quantizing inputs.
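As an illustration of the concern only (a sketch, not the reviewer's suggested fix): the per-expert [E] input scale can be collapsed to a single scalar before quantizing, so the division against the [tokens, hidden] activations stays broadcast-safe. The helper name below is hypothetical:

import torch

def quantize_input_fp8(x2d: torch.Tensor, w1_input_scale: torch.Tensor) -> torch.Tensor:
    # Sketch: assumes w1_input_scale is a per-expert [E] tensor as documented above.
    # Reducing it to one scalar keeps the division broadcast-safe for a [T, H] input.
    fp8_min = torch.finfo(torch.float8_e4m3fn).min
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = w1_input_scale.to(torch.float32).max()
    return (x2d / scale).clamp(fp8_min, fp8_max).to(torch.float8_e4m3fn)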

+    assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D"
+    assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D"
+    quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant]
+
+    # Ensure contiguous tensors
+    selected_experts = selected_experts.contiguous()
+    routing_weights = routing_weights.contiguous()
+
+    # Todo: refactor this repeating code block
+
+    # Determine activation type
+    mlp_style = mlp_style.lower()
+    act_fn = act_fn.lower()
+
+    activation_type = ActivationType.Swiglu
+    if mlp_style == "gated_mlp":
+        # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
+        # For gated MLP, concatenate w1 and w3
+        # TensorRT-LLM expects [w3, w1] concatenated
+        w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous()  # [E, 2*I, H]
+        fc1_expert_weights = w3_w1_stacked
+        if act_fn == "silu":
+            # activation_type = ActivationType.Silu
+            activation_type = ActivationType.Swiglu  # need to fix this in trtllm
+        else:
+            raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
+    elif mlp_style == "mlp":
+        # For non-gated MLP with ReLU^2
+        fc1_expert_weights = w1_weight.contiguous()
+        if act_fn == "relu2":
+            activation_type = ActivationType.Relu2
+        else:
+            raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
+    else:
+        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+
+    # Note! Outputting Float8_e4m3fn directly is not currently supported
+    output = torch.ops.trtllm.fused_moe(
+        x_q_fp8,
+        selected_experts,
+        routing_weights,
+        fc1_expert_weights=fc1_expert_weights,
+        fc1_expert_biases=None,
+        fc2_expert_weights=w2_weight.contiguous(),
+        fc2_expert_biases=None,
+        output_dtype=x.dtype,
+        quant_scales=quant_scales,
+        activation_type=activation_type,
+    )
+
+    # Restore original shape
+    return output[0].view(x_shape)
+
+
+@trtllm_quant_fp8moe_fused.register_fake
+def trtllm_quant_fp8moe_fused_fake(
+    x: torch.Tensor,
+    selected_experts: torch.Tensor,
+    routing_weights: torch.Tensor,
+    w1_weight: torch.Tensor,
+    w2_weight: torch.Tensor,
+    w3_weight: torch.Tensor,
+    w1_input_scale: torch.Tensor,
+    w2_input_scale: torch.Tensor,
+    w3_input_scale: torch.Tensor,
+    w1_weight_scale: torch.Tensor,
+    w2_weight_scale: torch.Tensor,
+    w3_weight_scale: torch.Tensor,
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
+) -> torch.Tensor:
+    return torch.empty_like(x)

Reviewer comment:

API change: ensure Python bindings/tests pass activation_type consistently.
New optional args added (activation_type, unpadded_hidden_size, num_valid_tokens, out_tensor). Please verify torch bindings and call sites are updated for ordering/defaults to avoid silent arg shifts.
Critical: Missing activation_type parameter in moe_op_cutlass.py call.

The C++ runMoe signature at line 251-265 expects 26 parameters ending with (profile_ids, activation_type, unpadded_hidden_size, num_valid_tokens, out_tensor). However, the Python wrapper call at moe_op_cutlass.py:196-203 is missing activation_type, causing all subsequent arguments to shift left by one position.

Currently passes (incorrect):
    ..., min_latency_mode, self.gemm_tactics, unpadded_hidden_size, tuner_num_tokens, None

Should pass (correct):
    ..., min_latency_mode, self.gemm_tactics, activation_type, unpadded_hidden_size, tuner_num_tokens, None

The torch_custom_ops.py:241-248 call correctly includes all four new optional parameters. Update moe_op_cutlass.py line 202 to add activation_type before unpadded_hidden_size.