-
-
Notifications
You must be signed in to change notification settings - Fork 17.9k
Enable B12x backend for non-gated MoEs (like Nemotron) #43328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,7 +20,6 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.platforms import current_platform | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.utils.flashinfer import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| flashinfer_b12x_fused_moe, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| flashinfer_convert_sf_to_mma_layout, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_flashinfer_b12x_moe, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -42,6 +41,11 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Only NVFP4 (kNvfp4Static/kNvfp4Dynamic) quantization is supported. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _ACTIVATION_MAP: dict[MoEActivation, str] = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| MoEActivation.SILU: "silu", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| MoEActivation.RELU2_NO_MUL: "relu2", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| moe_config: FusedMoEConfig, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -55,6 +59,30 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.num_local_experts = moe_config.num_local_experts | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.ep_rank = moe_config.moe_parallel_config.ep_rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Shape params for B12xMoEWrapper construction. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.global_num_experts = moe_config.num_experts | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.topk = moe_config.experts_per_token | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.hidden_dim = moe_config.hidden_dim | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.intermediate_size_per_partition = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| moe_config.intermediate_size_per_partition | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.max_num_tokens = moe_config.max_num_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.local_expert_offset = self.ep_rank * self.num_local_experts | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| activation = moe_config.activation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if activation not in self._ACTIVATION_MAP: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"FlashInferB12xExperts does not support " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"activation {activation!r}. " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Supported: {list(self._ACTIVATION_MAP.keys())}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._activation_str = self._ACTIVATION_MAP[activation] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Lazily created on first apply() call. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._wrapper: Any | None = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.w1_sf_mma: torch.Tensor | None = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.w2_sf_mma: torch.Tensor | None = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Normalise block scales to absorb the per-expert weight global scale | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # (w_gs). vLLM's NVFP4 convention stores: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -124,7 +152,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _supports_no_act_and_mul() -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _supports_quant_scheme( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -135,7 +163,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _supports_activation(activation: MoEActivation) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return activation == MoEActivation.SILU | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return activation in (MoEActivation.SILU, MoEActivation.RELU2_NO_MUL) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -167,13 +195,29 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def expects_unquantized_inputs(self) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # b12x_fused_moe expects BF16 hidden states and performs its own FP4 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # B12xMoEWrapper expects BF16 hidden states and performs its own FP4 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # quantization internally. Returning True prevents the modular kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # from pre-quantizing activations, which would produce an FP4-packed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # tensor with size(-1)=k//2 and break the scale-factor conversion that | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # expects size(-1)=k. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # from pre-quantizing activations. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _ensure_wrapper(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Lazily create B12xMoEWrapper on first use.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._wrapper is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from flashinfer.fused_moe import B12xMoEWrapper | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._wrapper = B12xMoEWrapper( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_experts=self.global_num_experts, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| top_k=self.topk, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_size=self.hidden_dim, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| intermediate_size=self.intermediate_size_per_partition, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_cuda_graph=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_num_tokens=self.max_num_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_local_experts=self.num_local_experts, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| activation=self._activation_str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def apply( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -201,23 +245,22 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.a2_gscale is not None, ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "a2_gscale must not be None for FlashInferB12xExperts" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.w1_sf_mma is not None and self.w2_sf_mma is not None, ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "process_weights_after_loading must run before FlashInferB12xExperts.apply" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| top_k = topk_ids.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._ensure_wrapper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| flashinfer_b12x_fused_moe( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = self._wrapper.run( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Check failure on line 254 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x=hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_selected_experts=topk_ids.to(torch.int32), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_final_scales=topk_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| w1_weight=w1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| w1_weight_sf=self.w1_sf_mma, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| w1_alpha=self.g1_alphas, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fc2_input_scale=self.a2_gscale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| w2_weight=w2, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| w2_weight_sf=self.w2_sf_mma, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| w2_alpha=self.g2_alphas, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_experts=global_num_experts, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| top_k=top_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_local_experts=self.num_local_experts, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_dtype=self.out_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output=output, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_selected_experts=topk_ids.to(torch.int32), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_final_scales=topk_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output.copy_(result) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+254
to
+266
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
Anytype hint is used in the__init__method (line 82), but it has not been imported from thetypingmodule. This will cause aNameErrorat runtime when the class is instantiated. Please add the missing import.