[NVFP4] NVFP4 MOE emulation fallback for H100/MI300/MI350, standardize TritonExperts usage for OCP MX emulation
#35737
Status: Open
fxmarty-amd wants to merge 60 commits into vllm-project:main (base: main) from fxmarty-amd:upstream-nvfp4-simulation-support-moe
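Context for readers (a summary, not part of the PR description): "emulation" here means the MoE matmuls run in BF16 through `TritonExperts`, while the numerics of the low-precision format are reproduced by a quantize-dequantize (QDQ) round trip on weights and activations. Below is a minimal sketch of NVFP4-style block QDQ, assuming the standard E2M1 value grid and keeping the per-block scales in full precision (real NVFP4 also quantizes the scales to FP8; vLLM's actual helpers are `moe_kernel_quantize_input` and `dequantize_to_dtype`):

```python
import torch

# Non-negative values representable in FP4 E2M1; sign is handled separately.
E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_nvfp4_qdq(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """QDQ round trip: scale each block so its absmax maps to 6.0 (the
    largest E2M1 magnitude), snap every element to the nearest grid point,
    then rescale. Assumes x.numel() is divisible by block_size."""
    grid = E2M1_GRID.to(x.device, torch.float32)
    blocks = x.float().reshape(-1, block_size)
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 6.0
    scaled = blocks / scale
    # Index of the nearest representable magnitude; sign re-applied after.
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    snapped = grid[idx] * torch.sign(scaled)
    return (snapped * scale).reshape(x.shape).to(x.dtype)
```

Per the title, the goal is a functional fallback on H100/MI300/MI350, which lack native NVFP4 MoE kernels, not a performance path.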
Commits (60, all by fxmarty-amd):
b313689  fix issues with nvfp4 dense emulation in vllm (squash)
bc6ff39  address comments
14bc668  nvfp4 moe emulation support
a11d131  Merge branch 'upstream-nvfp4-simulation-support-rocm' into upstream-n…
95c6a4a  wip use TritonExperts
5a2cf8c  wip cleanup
0ea8f82  wip cleanup
d99373e  wip cleanup
7a5f2ba  fix activation quantization
457f9df  address comment
01b4dce  enable test on non-blackwell devices
1d6c770  Merge branch 'main' into upstream-nvfp4-simulation-support-moe
cf189ef  cleanup
6db0c7b  Merge branch 'main-upstream' into upstream-nvfp4-simulation-support-rocm
ec1f4b8  address comment
309cefb  Merge branch 'upstream-nvfp4-simulation-support-rocm' into upstream-n…
cca5040  fix
9007357  Merge branch 'main' into upstream-nvfp4-simulation-support-rocm
e7d72f5  address bowen's comments
e3a8ebd  Merge branch 'upstream-nvfp4-simulation-support-rocm' into upstream-n…
311d47d  linting
74e6eec  Merge branch 'upstream-nvfp4-simulation-support-rocm' into upstream-n…
bf46483  use a single global scale for a2 in MOE, following flashinfer default…
0b47522  do not modify test_blackwell_moe
4a5c5c1  fix test and typo
6ed0611  fix typo
80a37f6  Merge branch 'upstream-nvfp4-simulation-support-rocm' into upstream-n…
35c88a8  simplify test
d439e80  remove outdated comment
2d9e65c  Merge branch 'main' into upstream-nvfp4-simulation-support-rocm
c6791f7  address Michael's comments
1fa136e  Merge branch 'upstream-nvfp4-simulation-support-rocm' into upstream-n…
56dd2bf  Merge branch 'main' into upstream-nvfp4-simulation-support-rocm
ad93d2a  linting
0d788d8  Merge branch 'upstream-nvfp4-simulation-support-rocm' into upstream-n…
e8a596f  Update vllm/model_executor/layers/quantization/compressed_tensors/sch…
c6adfe8  Update vllm/model_executor/layers/quantization/compressed_tensors/sch…
e36296a  move unsupported reasons warning in is_backend_supported
33f118f  Merge branch 'upstream-nvfp4-simulation-support-rocm' of https://gith…
44aadca  fix input
3f36269  Merge branch 'upstream-nvfp4-simulation-support-rocm' into upstream-n…
911b316  addres Michael's comments
90a54e3  simulation -> emulation
74b9212  linting
d930b84  Merge branch 'main' into upstream-nvfp4-simulation-support-rocm
24ec4ce  pre-commit passes locally and should not take 50min
58439aa  Merge branch 'upstream-nvfp4-simulation-support-rocm' into upstream-n…
8e61be3  Merge branch 'main' into upstream-nvfp4-simulation-support-moe
f2204ce  refactor OCP MX MOE emulation and address comment about moe_kernel_qu…
ca07f68  move to experts subfolder
223c275  simplifications
d8e9283  linting
1e1d139  Merge branch 'main' into upstream-nvfp4-simulation-support-moe
3663f59  fix quant_dtype
adfb9da  precise comment about maybe_roundup_sizes
9513361  add Qwen3-30B-A3B-NVFP4, Qwen3.5-35B-A3B-MXFP4-TP2 to gfx942 tests
c06e387  Merge branch 'main' into upstream-nvfp4-simulation-support-moe
df32bf3  Merge branch 'upstream-nvfp4-simulation-support-moe' of https://githu…
4e7ab24  Merge branch 'main' into upstream-nvfp4-simulation-support-moe
bfc4f90  address comment
vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py (new file: 164 additions, 0 deletions)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
NVFP4 quantization emulation for MoE.

This file implements NVFP4 emulation for NVFP4 MoE in case the hardware used
does not natively support NVFP4 MoE.

Weights are dequantized on the fly during each forward pass; we fall back to
calling `TritonExperts` in BF16, and a fake NVFP4 quantize-dequantize is
applied on `a13` and `a2`.
"""

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
    dequantize_to_dtype,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kNvfp4Dynamic,
    kNvfp4Static,
)

logger = init_logger(__name__)


class Nvfp4QuantizationEmulationTritonExperts(TritonExperts):
    """
    Extension of TritonExperts to support emulated NVFP4 MoE experts.

    It may be used for NVFP4 models when the device does not have
    native support for this dtype.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)
        logger.warning_once(
            "Using Nvfp4QuantizationEmulationTritonExperts MoE backend. This will"
            " dequantize weights on the fly and may be slower than native"
            " quantized MoE. Consider using a device with native quantization"
            " support (e.g. NVIDIA Blackwell) for better performance."
        )

        # `TritonExperts.apply` expects pre-dequantized weights, so stash the
        # weight scales here and clear them from the config; dequantization is
        # handled in `apply` below.
        self.w1_scale_val = self.quant_config.w1_scale
        self.w2_scale_val = self.quant_config.w2_scale

        self.quant_config._w1.scale = None
        self.quant_config._w2.scale = None

        self.quantization_emulation = True

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        return "nvfp4"

    @property
    def expects_unquantized_inputs(self) -> bool:
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """
        Apply emulated quantized MoE computation.

        This dequantizes the weights on the fly and calls `TritonExperts.apply`
        with activation quantization support.
        """
        # For NVFP4, weights are packed two FP4 values per uint8:
        # w1 shape: [num_experts, 2*intermediate_size, hidden_size//2]
        # w2 shape: [num_experts, hidden_size, intermediate_size//2]
        assert w1.dtype == torch.uint8
        assert w2.dtype == torch.uint8

        # Dequantize w1 from packed NVFP4 to fp16/bf16.
        w13_global_scale = self.quant_config.g1_alphas

        w1_dequant = dequantize_to_dtype(
            tensor_fp4=w1,
            tensor_sf=self.w1_scale_val,
            global_scale=w13_global_scale,
            dtype=hidden_states.dtype,
            block_size=16,
            swizzle=False,
        )

        # Dequantize w2 from packed NVFP4 to fp16/bf16.
        w2_global_scale = self.quant_config.g2_alphas

        w2_dequant = dequantize_to_dtype(
            tensor_fp4=w2,
            tensor_sf=self.w2_scale_val,
            global_scale=w2_global_scale,
            dtype=hidden_states.dtype,
            block_size=16,
            swizzle=False,
        )

        # Fake NVFP4 quantize-dequantize on the input activations (`a13`).
        hidden_states, _ = moe_kernel_quantize_input(
            A=hidden_states,
            A_scale=self.quant_config.a1_gscale,
            quant_dtype="nvfp4",
            per_act_token_quant=False,
            quantization_emulation=True,
        )

        # Quantization/dequantization of the intermediate activations (`a2`)
        # is deferred to `moe_kernel_quantize_input` in `TritonExperts.apply`.
        super().apply(
            output=output,
            hidden_states=hidden_states,
            w1=w1_dequant,
            w2=w2_dequant,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            a1q_scale=None,
            a2_scale=self.quant_config.a2_gscale,
            workspace13=workspace13,
            workspace2=workspace2,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
```
vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py (new file: 186 additions, 0 deletions)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
OCP MX quantization emulation for MoE.

This file implements OCP MX (MXFP4/MXFP6) emulation for MoE in case the
hardware used does not natively support OCP MX MoE.

Weights are dequantized on the fly during each forward pass; we fall back to
calling `TritonExperts` in BF16, and a fake OCP MX quantize-dequantize is
applied on activations via `moe_kernel_quantize_input`.
"""

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
    OCP_MX_Scheme,
)

logger = init_logger(__name__)


class OCP_MXQuantizationEmulationTritonExperts(TritonExperts):
    """
    Extension of TritonExperts to support emulated OCP MX MoE experts.

    It may be used for OCP MX (MXFP4/MXFP6) models when the device does not
    have native support for these dtypes.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)
        logger.warning_once(
            "Using OCP_MXQuantizationEmulationTritonExperts MoE backend. This"
            " will dequantize weights on the fly and may be slower than native"
            " quantized MoE. Consider using a device with native OCP MX"
            " quantization support for better performance."
        )

        self.ocp_mx_scheme = quant_config.ocp_mx_scheme
        assert self.ocp_mx_scheme is not None, (
            "ocp_mx_scheme must be set in quant_config for"
            " OCP_MXQuantizationEmulationTritonExperts"
        )

        # `TritonExperts.apply` expects pre-dequantized weights, so stash the
        # weight scales here and clear them from the config; dequantization is
        # handled in `apply` below.
        self.w1_scale_val = self.quant_config.w1_scale
        self.w2_scale_val = self.quant_config.w2_scale

        self.quant_config._w1.scale = None
        self.quant_config._w2.scale = None

        self.quantization_emulation = True

        if self.ocp_mx_scheme in {
            OCP_MX_Scheme.w_mxfp4_a_mxfp4,
        }:
            # Weights have to be dequantized for mxfp4 emulation.
            self._quant_dtype = "mxfp4"
        elif self.ocp_mx_scheme in [
            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
            OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2,
            OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3,
        ]:
            self._quant_dtype = "mxfp6"
        elif self.ocp_mx_scheme in [
            OCP_MX_Scheme.w_mxfp4_a_fp8,
            OCP_MX_Scheme.w_mxfp6_e3m2_a_fp8,
        ]:
            # TODO: double check this one
            self._quant_dtype = "mxfp8"
        else:
            raise NotImplementedError(
                f"Unsupported ocp_mx_scheme={self.ocp_mx_scheme} for"
                " OCP_MXQuantizationEmulationTritonExperts"
            )

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        return self._quant_dtype

    @property
    def expects_unquantized_inputs(self) -> bool:
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key,
        activation_key,
    ) -> bool:
        # This class is used for emulation only - the oracle selects it
        # directly rather than via quant scheme matching.
        return True

    def _dequantize_weights(
        self,
        w: torch.Tensor,
        w_scale: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Dequantize weights based on the OCP MX scheme."""
        if self.ocp_mx_scheme.startswith("w_mxfp4"):  # type: ignore[union-attr]
            return dequant_mxfp4(w, w_scale, dtype)
        elif self.ocp_mx_scheme.startswith("w_mxfp6_e3m2"):  # type: ignore[union-attr]
            return dequant_mxfp6(w, w_scale, quant_dtype="fp6_e3m2", float_dtype=dtype)
        elif self.ocp_mx_scheme.startswith("w_mxfp6_e2m3"):  # type: ignore[union-attr]
            return dequant_mxfp6(w, w_scale, quant_dtype="fp6_e2m3", float_dtype=dtype)
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={self.ocp_mx_scheme}")

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """
        Apply emulated quantized MoE computation.

        This dequantizes the weights on the fly and calls `TritonExperts.apply`
        with activation quantization support.
        """
        assert w1.dtype == torch.uint8
        assert w2.dtype == torch.uint8

        # Dequantize w1 and w2 from packed OCP MX format to bf16/fp16.
        w1_dequant = self._dequantize_weights(w1, self.w1_scale_val, hidden_states.dtype)
        w2_dequant = self._dequantize_weights(w2, self.w2_scale_val, hidden_states.dtype)

        # Apply fake quantize-dequantize on the input activations, as required
        # by the OCP MX scheme.
        hidden_states, _ = moe_kernel_quantize_input(
            A=hidden_states,
            A_scale=None,
            quant_dtype=self.quant_config.quant_dtype,
            per_act_token_quant=False,
            ocp_mx_scheme=self.ocp_mx_scheme,
            quantization_emulation=True,
        )

        # Quantization/dequantization of the intermediate activations is
        # deferred to `moe_kernel_quantize_input` in `TritonExperts.apply`.
        super().apply(
            output=output,
            hidden_states=hidden_states,
            w1=w1_dequant,
            w2=w2_dequant,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            a1q_scale=None,
            a2_scale=None,
            workspace13=workspace13,
            workspace2=workspace2,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
```
Review comment:
For context, this is running on vllm/.buildkite/test-amd.yaml, lines 2695 to 2715 in 8d825b8, and I ran it successfully locally.