-
-
Notifications
You must be signed in to change notification settings - Fork 17.6k
[MXFP4] Support for linear layers + compressed-tensors integration #41664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
baafb30
update
dsikka 8e28f85
update
dsikka f935818
use linear kernel abstraction
dsikka fadc701
add test models
dsikka aade555
Merge branch 'main' into ct_mxfp4
dsikka e50f724
fix padded mx linear
kylesayrs eda7aa3
update dummy shape
kylesayrs 8efe81c
Merge branch 'main' into ct_mxfp4
dsikka 13d2a97
Merge branch 'main' into ct_mxfp4
dsikka fca2eed
Apply suggestions from code review
dsikka ece7377
Merge branch 'main' into ct_mxfp4
dsikka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from vllm.model_executor.kernels.linear.mxfp4.base import ( | ||
| MxFp4LinearKernel, | ||
| MxFp4LinearLayerConfig, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "MxFp4LinearKernel", | ||
| "MxFp4LinearLayerConfig", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from dataclasses import dataclass | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| @dataclass | ||
| class MxFp4LinearLayerConfig: | ||
| """Configuration for an MXFP4 linear layer. | ||
|
|
||
| All MXFP4 layers share the same structure: packed uint8 weights (2 FP4 values per | ||
| byte) and per-block weight scales (group size 32). | ||
| """ | ||
|
|
||
| pass | ||
|
|
||
|
|
||
| class MxFp4LinearKernel(ABC): | ||
| """Base class for MXFP4 quantized linear kernels. | ||
|
|
||
| Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc). | ||
| The kernel selection mechanism iterates over registered subclasses in | ||
| priority order,calling ``is_supported`` and ``can_implement`` to find the best | ||
| match for the current hardware. | ||
| """ | ||
|
|
||
| def __init__(self, config: MxFp4LinearLayerConfig) -> None: | ||
| assert self.can_implement(config)[0] | ||
| assert self.is_supported()[0] | ||
| self.config = config | ||
|
|
||
| @classmethod | ||
| @abstractmethod | ||
| def is_supported( | ||
| cls, compute_capability: int | None = None | ||
| ) -> tuple[bool, str | None]: | ||
| """Return whether this kernel can run on the current platform.""" | ||
| raise NotImplementedError | ||
|
|
||
| @classmethod | ||
| @abstractmethod | ||
| def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: | ||
| """Return whether this kernel can handle *config*.""" | ||
| raise NotImplementedError | ||
|
|
||
| @abstractmethod | ||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||
| """Transform weights into the format required by this kernel. | ||
|
|
||
| Called once after checkpoint weights have been loaded onto the | ||
| device. Implementations should repack / swizzle / pad weights | ||
| and scales in-place on *layer*. | ||
| """ | ||
| raise NotImplementedError | ||
|
|
||
| @abstractmethod | ||
| def apply_weights( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| bias: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| """Run the quantized GEMM.""" | ||
| raise NotImplementedError |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import torch | ||
| from torch.nn.parameter import Parameter | ||
|
|
||
| from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( | ||
| swizzle_mxfp4_scales, | ||
| ) | ||
| from vllm.platforms import current_platform | ||
| from vllm.utils.flashinfer import has_flashinfer_cutedsl | ||
|
|
||
| from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig | ||
|
|
||
| _MXFP4_GROUP_SIZE = 32 | ||
|
|
||
|
|
||
| class FlashInferMxFp4LinearKernel(MxFp4LinearKernel): | ||
| """MXFP4 W4A4 GEMM via FlashInfer CUTLASS (SM100+).""" | ||
|
|
||
| @classmethod | ||
| def is_supported( | ||
| cls, compute_capability: int | None = None | ||
| ) -> tuple[bool, str | None]: | ||
| if current_platform.has_device_capability(100) and has_flashinfer_cutedsl(): | ||
| return True, None | ||
| return False, "FlashInfer + >=sm_100 (Blackwell) required" | ||
|
|
||
| @classmethod | ||
| def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: | ||
| return True, None | ||
|
|
||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||
| N, scale_K = layer.weight_scale.shape | ||
| K = scale_K * _MXFP4_GROUP_SIZE | ||
|
|
||
| # swizzle pads N to the next multiple of 128 for CUTLASS tiling | ||
| padded_N = ((N + 127) // 128) * 128 | ||
| layer.weight_scale = Parameter( | ||
| swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(padded_N, -1), | ||
| requires_grad=False, | ||
| ) | ||
|
|
||
| def apply_weights( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| bias: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| from vllm.utils.flashinfer import ( | ||
| flashinfer_mxfp4_quantize, | ||
| flashinfer_scaled_fp4_mm, | ||
| ) | ||
|
|
||
| weight = layer.weight | ||
| out_shape = x.shape[:-1] + (layer.output_size_per_partition,) | ||
| x_2d = x.reshape(-1, x.shape[-1]) | ||
|
|
||
| x_fp4, x_scale = flashinfer_mxfp4_quantize(x_2d) | ||
| out = flashinfer_scaled_fp4_mm( | ||
| x_fp4, | ||
| weight, | ||
| x_scale, | ||
| layer.weight_scale, | ||
| alpha=None, | ||
| out_dtype=x.dtype, | ||
| backend="cute-dsl", | ||
| block_size=_MXFP4_GROUP_SIZE, | ||
| use_nvfp4=False, | ||
| ) | ||
|
|
||
| if bias is not None: | ||
| out = out + bias | ||
| return out.view(out_shape) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import torch | ||
|
|
||
| from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig | ||
|
|
||
|
|
||
| class MarlinMxFp4LinearKernel(MxFp4LinearKernel): | ||
| @classmethod | ||
| def is_supported( | ||
| cls, compute_capability: int | None = None | ||
| ) -> tuple[bool, str | None]: | ||
| from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( | ||
| is_fp4_marlin_supported, | ||
| ) | ||
|
|
||
| if is_fp4_marlin_supported(): | ||
| return True, None | ||
| return False, "Marlin FP4 not available" | ||
|
|
||
| @classmethod | ||
| def can_implement(cls, c: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: | ||
| return True, None | ||
|
|
||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||
| from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( | ||
| prepare_fp4_layer_for_marlin, | ||
| ) | ||
|
|
||
| prepare_fp4_layer_for_marlin(layer) | ||
|
|
||
| def apply_weights( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| bias: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( | ||
| apply_fp4_marlin_linear, | ||
| ) | ||
|
|
||
| return apply_fp4_marlin_linear( | ||
| input=x, | ||
| weight=layer.weight, | ||
| weight_scale=layer.weight_scale, | ||
| weight_global_scale=None, | ||
| workspace=layer.workspace, | ||
| size_n=layer.output_size_per_partition, | ||
| size_k=layer.input_size_per_partition, | ||
| bias=bias, | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.