-
Notifications
You must be signed in to change notification settings - Fork 2.3k
[None][feat] Add bf16 trtllm-gen moe support through flashinfer. #12738
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,14 +63,23 @@ def get_moe_cls( | |
| return CutlassFusedMoE | ||
| return DenseGEMMFusedMoE | ||
| elif moe_backend.upper() == "TRTLLM": | ||
| if quant_config is not None and ( | ||
| quant_config.quant_mode.has_fp8_block_scales() | ||
| or quant_config.quant_mode.has_nvfp4() | ||
| or quant_config.quant_mode.has_w4a16_mxfp4() | ||
| or quant_config.quant_mode.has_w4a8_nvfp4_fp8() | ||
| or quant_config.quant_mode.has_w4a8_mxfp4_fp8() | ||
| or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()): | ||
| has_quant = quant_config is not None and quant_config.quant_mode.has_any_quant( | ||
| exclude_kv_cache=True) | ||
| if has_quant and (quant_config.quant_mode.has_fp8_block_scales() | ||
| or quant_config.quant_mode.has_nvfp4() | ||
| or quant_config.quant_mode.has_w4a16_mxfp4() | ||
| or quant_config.quant_mode.has_w4a8_nvfp4_fp8() | ||
| or quant_config.quant_mode.has_w4a8_mxfp4_fp8() | ||
| or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()): | ||
| return TRTLLMGenFusedMoE | ||
| if not has_quant and model_config.pretrained_config is not None and getattr( | ||
| model_config.pretrained_config, "torch_dtype", | ||
| None) == torch.bfloat16: | ||
| if TRTLLMGenFusedMoE._is_flashinfer_fused_moe_available(): | ||
| return TRTLLMGenFusedMoE | ||
| raise RuntimeError( | ||
| "TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE with " | ||
| "trtllm_bf16_moe support, but it is not available.") | ||
|
Comment on lines
+75
to
+82
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Honor call-site
Also applies to: 97-113 🤖 Prompt for AI Agents |
||
| else: | ||
| logger.warning( | ||
| "TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, and w4a8_mxfp4_mxfp8. " | ||
|
|
@@ -85,6 +94,25 @@ def get_moe_cls( | |
| raise ValueError(f"Unsupported moe backend: {moe_backend}") | ||
|
|
||
|
|
||
| def resolve_moe_cls( | ||
| model_config: ModelConfig, | ||
| routing_method: BaseMoeRoutingMethod, | ||
| dtype: Optional[torch.dtype], | ||
| override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]: | ||
| moe_cls = get_moe_cls(model_config, override_quant_config) | ||
|
|
||
| effective_quant_config = override_quant_config or model_config.quant_config | ||
| has_quant = (effective_quant_config is not None | ||
| and effective_quant_config.layer_quant_mode.has_any_quant( | ||
| exclude_kv_cache=True)) | ||
| if (moe_cls == TRTLLMGenFusedMoE and not has_quant | ||
| and not TRTLLMGenFusedMoE._supports_flashinfer_bf16_routing_method( | ||
| routing_method)): | ||
| return CutlassFusedMoE | ||
|
|
||
| return moe_cls | ||
|
|
||
|
|
||
| def create_moe_backend( | ||
| moe_cls: Type[MoE], | ||
| routing_method: BaseMoeRoutingMethod, | ||
|
|
@@ -379,7 +407,8 @@ def create_moe( | |
| pretrained_config, 'torch_dtype'): | ||
| dtype = pretrained_config.torch_dtype | ||
|
|
||
| moe_cls = get_moe_cls(model_config, override_quant_config) | ||
| moe_cls = resolve_moe_cls(model_config, routing_method, dtype, | ||
| override_quant_config) | ||
|
|
||
| enable_configurable_moe = os.environ.get("ENABLE_CONFIGURABLE_MOE", | ||
| "1") == "1" | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -38,10 +38,10 @@ | |||||||||||||||||
|
|
||||||||||||||||||
| # isort: off | ||||||||||||||||||
| from .quantization import ( | ||||||||||||||||||
| DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEBaseMethod, | ||||||||||||||||||
| NVFP4TRTLLMGenFusedMoEMethod, W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, | ||||||||||||||||||
| W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, | ||||||||||||||||||
| W4A16MXFP4TRTLLMGenFusedMoEMethod) | ||||||||||||||||||
| BF16TRTLLMGenFusedMoEMethod, DeepSeekFP8BlockScalesFusedMoEMethod, | ||||||||||||||||||
| NVFP4TRTLLMGenFusedMoEBaseMethod, NVFP4TRTLLMGenFusedMoEMethod, | ||||||||||||||||||
| W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, | ||||||||||||||||||
| W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, W4A16MXFP4TRTLLMGenFusedMoEMethod) | ||||||||||||||||||
| # isort: on | ||||||||||||||||||
| from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod, | ||||||||||||||||||
| DefaultMoeRoutingMethod) | ||||||||||||||||||
|
|
@@ -115,7 +115,8 @@ def can_implement( | |||||||||||||||||
| - W4A8_MXFP4_FP8 | ||||||||||||||||||
| - W4A8_MXFP4_MXFP8 | ||||||||||||||||||
|
|
||||||||||||||||||
| Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16. | ||||||||||||||||||
| Unquantized BF16 path is supported only with FlashInfer fused MoE backend. | ||||||||||||||||||
| Output dtype is hardcoded to bfloat16. | ||||||||||||||||||
|
|
||||||||||||||||||
| Args: | ||||||||||||||||||
| quant_algo: The quantization algorithm to check (None for unquantized) | ||||||||||||||||||
|
|
@@ -143,10 +144,16 @@ def can_implement( | |||||||||||||||||
| f"TRTLLMGenFusedMoE only supports bfloat16 activation, got {dtype_activation}" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # TRTLLMGenFusedMoE does NOT support unquantized mode | ||||||||||||||||||
| if quant_algo is None: | ||||||||||||||||||
| return _warn_and_return( | ||||||||||||||||||
| "TRTLLMGenFusedMoE does not support unquantized mode") | ||||||||||||||||||
| if swiglu_gptoss_style: | ||||||||||||||||||
| return _warn_and_return( | ||||||||||||||||||
| "TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters." | ||||||||||||||||||
| ) | ||||||||||||||||||
| if not cls._is_flashinfer_fused_moe_available(): | ||||||||||||||||||
| return _warn_and_return( | ||||||||||||||||||
| "TRTLLMGenFusedMoE unquantized BF16 path requires FlashInfer fused MoE " | ||||||||||||||||||
| "with trtllm_bf16_moe support.") | ||||||||||||||||||
| return True, None | ||||||||||||||||||
|
|
||||||||||||||||||
| # Check if quant_algo is supported | ||||||||||||||||||
| if quant_algo not in cls._SUPPORTED_QUANT_ALGOS: | ||||||||||||||||||
|
|
@@ -210,7 +217,14 @@ def __init__( | |||||||||||||||||
|
|
||||||||||||||||||
| assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE." | ||||||||||||||||||
|
|
||||||||||||||||||
| self.use_flashinfer = self._check_op_backend_support() | ||||||||||||||||||
| self.use_flashinfer = self._check_flashinfer_backend_support() | ||||||||||||||||||
| if (self.quant_config is None | ||||||||||||||||||
| or not self.quant_config.layer_quant_mode.has_any_quant( | ||||||||||||||||||
| exclude_kv_cache=True)) and not self.use_flashinfer: | ||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||
| "TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE. " | ||||||||||||||||||
| "Please install a FlashInfer build with trtllm_bf16_moe support." | ||||||||||||||||||
| ) | ||||||||||||||||||
| backend_name = "flashinfer" if self.use_flashinfer else "trtllm" | ||||||||||||||||||
| self.op_backend: MoEOpBackend = get_op_backend(backend_name) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -292,7 +306,43 @@ def _to_trtllm_gen_activation_type(self, | |||||||||||||||||
| else: | ||||||||||||||||||
| raise ValueError(f"Unsupported activation type: {activation_type}") | ||||||||||||||||||
|
|
||||||||||||||||||
| def _check_op_backend_support(self) -> bool: | ||||||||||||||||||
| @staticmethod | ||||||||||||||||||
| def _is_flashinfer_fused_moe_available() -> bool: | ||||||||||||||||||
| try: | ||||||||||||||||||
| from flashinfer.fused_moe import core as _core | ||||||||||||||||||
| except (ImportError, ModuleNotFoundError): | ||||||||||||||||||
| return False | ||||||||||||||||||
| return (hasattr(_core, "trtllm_bf16_moe") | ||||||||||||||||||
| and hasattr(_core, "trtllm_bf16_routed_moe")) | ||||||||||||||||||
|
|
||||||||||||||||||
| def _is_unquantized_path(self) -> bool: | ||||||||||||||||||
| return self.quant_config is None or not self.quant_config.layer_quant_mode.has_any_quant( | ||||||||||||||||||
| exclude_kv_cache=True) | ||||||||||||||||||
|
|
||||||||||||||||||
| @staticmethod | ||||||||||||||||||
| def _supports_flashinfer_bf16_routing_method( | ||||||||||||||||||
| routing_method: BaseMoeRoutingMethod, ) -> bool: | ||||||||||||||||||
| # FIXME: ban DeepSeekV3 FlashInfer trtllm_bf16_routed_moe() as it appears to have bug | ||||||||||||||||||
|
Comment on lines
+323
to
+325
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the hanging indent in Flake8 is already reporting E125 here, so the lint job will remain red until the continuation line is re-indented or the closing parenthesis is moved. Minimal fix `@staticmethod`
def _supports_flashinfer_bf16_routing_method(
- routing_method: BaseMoeRoutingMethod, ) -> bool:
+ routing_method: BaseMoeRoutingMethod,
+ ) -> bool:📝 Committable suggestion
Suggested change
🧰 Tools🪛 Flake8 (7.3.0)[error] 324-324: continuation line with same indent as next logical line (E125) 🤖 Prompt for AI Agents
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note: this will be addressed by flashinfer-ai/flashinfer#2911. |
||||||||||||||||||
| return not isinstance(routing_method, DeepSeekV3MoeRoutingMethod) | ||||||||||||||||||
|
|
||||||||||||||||||
| def _requires_separated_routing(self) -> bool: | ||||||||||||||||||
| """Whether this backend instance expects precomputed top-k routing.""" | ||||||||||||||||||
| # FIXME: ban FlashInfer BF16 MoE direct routing as it appears to have accuracy bug | ||||||||||||||||||
| return self.use_flashinfer and self._is_unquantized_path() | ||||||||||||||||||
|
|
||||||||||||||||||
| def _check_flashinfer_backend_support(self) -> bool: | ||||||||||||||||||
| # For BF16 (unquantized) path, we will use FlashInfer regardless whether | ||||||||||||||||||
| # env TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER=1 is set or not as it's the only way. | ||||||||||||||||||
| if self._is_unquantized_path(): | ||||||||||||||||||
| if not self._is_flashinfer_fused_moe_available(): | ||||||||||||||||||
| return False | ||||||||||||||||||
| if self.activation_type != ActivationType.Swiglu: | ||||||||||||||||||
| return False | ||||||||||||||||||
| if not self._supports_flashinfer_bf16_routing_method( | ||||||||||||||||||
| self.routing_method): | ||||||||||||||||||
| return False | ||||||||||||||||||
| return True | ||||||||||||||||||
|
|
||||||||||||||||||
| use_flashinfer = os.environ.get("TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER", | ||||||||||||||||||
| "0") | ||||||||||||||||||
| if use_flashinfer != "1": | ||||||||||||||||||
|
|
@@ -311,8 +361,6 @@ def _check_op_backend_support(self) -> bool: | |||||||||||||||||
| if type(quant_method) is NVFP4TRTLLMGenFusedMoEBaseMethod: | ||||||||||||||||||
| return True | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.quant_config is None: | ||||||||||||||||||
| return False | ||||||||||||||||||
| mode = self.quant_config.layer_quant_mode | ||||||||||||||||||
|
|
||||||||||||||||||
| # These quant modes are never supported via op backend | ||||||||||||||||||
|
|
@@ -365,7 +413,14 @@ def select_alltoall_method_type(self) -> AlltoallMethodType: | |||||||||||||||||
| return AlltoallMethodType.NVLinkOneSided | ||||||||||||||||||
|
|
||||||||||||||||||
| def _supports_load_balancer(self) -> bool: | ||||||||||||||||||
| """TRTLLMGenFusedMoE supports load balancer.""" | ||||||||||||||||||
| """Whether separated routing (top-k outside the kernel) is used. | ||||||||||||||||||
|
|
||||||||||||||||||
| ConfigurableMoE uses this flag to decide whether routing is separated | ||||||||||||||||||
| (top-k ids/scales computed outside backend) or fused inside the kernel. | ||||||||||||||||||
| BF16 FlashInfer path always requires separated routing. | ||||||||||||||||||
| """ | ||||||||||||||||||
| if self._requires_separated_routing(): | ||||||||||||||||||
| return True | ||||||||||||||||||
| return self.use_dp and self.parallel_size > 1 | ||||||||||||||||||
|
|
||||||||||||||||||
| @cached_property | ||||||||||||||||||
|
|
@@ -375,9 +430,17 @@ def enable_alltoall(self): | |||||||||||||||||
| return self.alltoall_method_type != AlltoallMethodType.NotEnabled | ||||||||||||||||||
|
|
||||||||||||||||||
| def _check_configs(self): | ||||||||||||||||||
| assert self.has_deepseek_fp8_block_scales \ | ||||||||||||||||||
| assert not self.has_any_quant \ | ||||||||||||||||||
| or self.has_deepseek_fp8_block_scales \ | ||||||||||||||||||
| or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \ | ||||||||||||||||||
| or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes." | ||||||||||||||||||
| or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, \ | ||||||||||||||||||
| "TRTLLMGenFusedMoE only supports bf16 (FlashInfer), fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes." | ||||||||||||||||||
|
|
||||||||||||||||||
| if not self.has_any_quant: | ||||||||||||||||||
| assert self.activation_type == ActivationType.Swiglu, \ | ||||||||||||||||||
| "TRTLLMGenFusedMoE BF16 path only supports Swiglu activation." | ||||||||||||||||||
| assert not self.bias and self.swiglu_alpha is None and self.swiglu_beta is None and self.swiglu_limit is None, \ | ||||||||||||||||||
| "TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters." | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None: | ||||||||||||||||||
| assert self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE supports bias/swiglu only for nvfp4 and mxfp4 variants." | ||||||||||||||||||
|
|
@@ -405,8 +468,7 @@ def _get_quant_method(self): | |||||||||||||||||
| f"Unsupported quantization method by TRTLLMGenFusedMoE: {self.quant_config.quant_mode}" | ||||||||||||||||||
| ) | ||||||||||||||||||
| else: | ||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||
| "TRTLLMGenFusedMoE doesn't support fp16/bf16/fp32 MoE.") | ||||||||||||||||||
| return BF16TRTLLMGenFusedMoEMethod() | ||||||||||||||||||
|
|
||||||||||||||||||
| def create_weights(self): | ||||||||||||||||||
| if self._weights_created: | ||||||||||||||||||
|
|
@@ -467,6 +529,8 @@ def quantize_input(self, x, post_quant_comm: bool = True): | |||||||||||||||||
| - scaling_vector_size is typically the group size for block-wise quantization | ||||||||||||||||||
| """ | ||||||||||||||||||
| x_sf = None | ||||||||||||||||||
| if not self.has_any_quant: | ||||||||||||||||||
| return x, x_sf | ||||||||||||||||||
| if self.has_w4a8_mxfp4_fp8: | ||||||||||||||||||
| pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1] | ||||||||||||||||||
| x = torch.nn.functional.pad(x, (0, pad_size)) | ||||||||||||||||||
|
|
@@ -526,7 +590,7 @@ def quantize_input(self, x, post_quant_comm: bool = True): | |||||||||||||||||
| return x, x_sf | ||||||||||||||||||
|
|
||||||||||||||||||
| def supports_moe_output_in_alltoall_workspace(self): | ||||||||||||||||||
| return True | ||||||||||||||||||
| return self.has_any_quant and not self.use_flashinfer | ||||||||||||||||||
|
|
||||||||||||||||||
| def run_moe( | ||||||||||||||||||
| self, | ||||||||||||||||||
|
|
@@ -542,8 +606,8 @@ def run_moe( | |||||||||||||||||
| Run MoE computation with TRTLLMGen backend. | ||||||||||||||||||
|
|
||||||||||||||||||
| This method encapsulates the core MoE computation logic, handling different | ||||||||||||||||||
| quantization schemes (fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, | ||||||||||||||||||
| w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8). | ||||||||||||||||||
| quantization schemes (bf16, fp8_block_scales, nvfp4, w4a16_mxfp4, | ||||||||||||||||||
| w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8). | ||||||||||||||||||
|
|
||||||||||||||||||
| Args: | ||||||||||||||||||
| # Standard MoE interface parameters: | ||||||||||||||||||
|
|
@@ -592,7 +656,37 @@ def run_moe( | |||||||||||||||||
| ) == 2, f"x_sf should be 2D tensor, got shape {x_sf.shape}" | ||||||||||||||||||
| x_sf = x_sf.flatten() | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.has_deepseek_fp8_block_scales: | ||||||||||||||||||
| if not self.has_any_quant: | ||||||||||||||||||
| result = self.op_backend.run_bf16_moe( | ||||||||||||||||||
| router_logits=router_logits, | ||||||||||||||||||
| routing_bias=routing_bias, | ||||||||||||||||||
| hidden_states=x, | ||||||||||||||||||
| gemm1_weights=self.w3_w1_weight, | ||||||||||||||||||
| gemm2_weights=self.w2_weight, | ||||||||||||||||||
| num_experts=self.num_slots, | ||||||||||||||||||
| top_k=top_k, | ||||||||||||||||||
| n_group=n_group, | ||||||||||||||||||
| topk_group=topk_group, | ||||||||||||||||||
| intermediate_size=self.intermediate_size_per_partition, | ||||||||||||||||||
| local_expert_offset=self.slot_start, | ||||||||||||||||||
| local_num_experts=self.expert_size_per_partition, | ||||||||||||||||||
| routed_scaling_factor=routed_scaling_factor, | ||||||||||||||||||
| routing_method_type=self.routing_method.routing_method_type, | ||||||||||||||||||
| topk_weights=token_final_scales, | ||||||||||||||||||
| topk_ids=token_selected_experts, | ||||||||||||||||||
| gated_act_type=self._to_trtllm_gen_activation_type( | ||||||||||||||||||
| self.activation_type), | ||||||||||||||||||
| output=moe_output, | ||||||||||||||||||
| use_shuffled_weight=getattr(self.quant_method, | ||||||||||||||||||
| "use_shuffled_weight", False), | ||||||||||||||||||
| weight_layout=getattr(self.quant_method, "weight_layout", 0), | ||||||||||||||||||
| do_finalize=do_finalize, | ||||||||||||||||||
| ) | ||||||||||||||||||
| if not do_finalize: | ||||||||||||||||||
| assert not self.reduce_results, "reduce_results must be False when do_finalize is False" | ||||||||||||||||||
| return result | ||||||||||||||||||
| final_hidden_states = result | ||||||||||||||||||
| elif self.has_deepseek_fp8_block_scales: | ||||||||||||||||||
| assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False" | ||||||||||||||||||
| # fp8_block_scale_moe_runner needs 2D shape for x_sf and only support SM100+ | ||||||||||||||||||
| if x_sf is None: | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Config freeze/skip flags are not safely restored across all paths.
Line 193 and Line 195 force
`_frozen=True` instead of restoring the original frozen state, and if `create_moe_backend(...)` throws, `skip_create_weights_in_init` is left mutated. This can leak state into subsequent layer construction. Suggested fix
Also applies to: 240-244
🤖 Prompt for AI Agents