Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions tests/quantization/test_compressed_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
CompressedTensorsConfig,
CompressedTensorsLinearMethod,
CompressedTensorsW4A4Fp4,
CompressedTensorsW4A4Mxfp4,
CompressedTensorsW4A8Fp8,
CompressedTensorsW4A16Fp4,
CompressedTensorsW8A8Fp8,
Expand Down Expand Up @@ -689,3 +690,31 @@ def check_model(model):
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4)
assert output


@pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
    reason="MXFP4 requires ampere or newer",
)
def test_compressed_tensors_mxfp4(vllm_runner):
    """Load an MXFP4 checkpoint and check the W4A4 MXFP4 scheme is selected.

    The first decoder layer's four linear projections must all use the
    compressed-tensors linear method with the ``CompressedTensorsW4A4Mxfp4``
    scheme and a group size of 32; a short greedy generation must then
    produce a non-empty result.
    """
    checkpoint = "nm-testing/TinyLlama-1.1B-Chat-v1.0-MXFP4"

    def check_model(model):
        first_layer = model.model.layers[0]
        attn, mlp = first_layer.self_attn, first_layer.mlp
        quantized_linears = (
            attn.qkv_proj,
            attn.o_proj,
            mlp.gate_up_proj,
            mlp.down_proj,
        )

        for linear in quantized_linears:
            assert isinstance(linear.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(linear.scheme, CompressedTensorsW4A4Mxfp4)

            # Every projection is expected to report the same group size.
            assert linear.scheme.group_size == 32

    with vllm_runner(checkpoint, enforce_eager=True) as llm:
        llm.apply_model(check_model)
        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
68 changes: 68 additions & 0 deletions vllm/model_executor/kernels/linear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@
XPUW4A8IntLinearKernel,
XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.mxfp4 import (
MxFp4LinearKernel,
MxFp4LinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mxfp4.flashinfer import (
FlashInferMxFp4LinearKernel,
)
from vllm.model_executor.kernels.linear.mxfp4.marlin import (
MarlinMxFp4LinearKernel,
)
from vllm.model_executor.kernels.linear.mxfp8 import (
Mxfp8LinearKernel,
Mxfp8LinearLayerConfig,
Expand Down Expand Up @@ -276,6 +286,13 @@
],
}

# Candidate MXFP4 linear kernels per platform. init_mxfp4_linear_kernel walks
# the list for the current platform front to back and returns the first usable
# entry, so list order is the selection priority. register_linear_kernel may
# append further entries at runtime.
_POSSIBLE_MXFP4_KERNELS: dict[PlatformEnum, list[type[MxFp4LinearKernel]]] = {
PlatformEnum.CUDA: [
FlashInferMxFp4LinearKernel,
MarlinMxFp4LinearKernel,
],
}

# TODO make all kernels inherit from MMLinearKernel
# then bound _KernelT only to MMLinearKernel
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel | MMLinearKernel)
Expand Down Expand Up @@ -570,6 +587,48 @@ def init_mxfp8_linear_kernel() -> Mxfp8LinearKernel:
)


def init_mxfp4_linear_kernel() -> MxFp4LinearKernel:
    """Select and instantiate the best MXFP4 linear kernel for the
    current platform.

    If ``VLLM_MXFP4_USE_MARLIN`` is set, the Marlin kernel is forced and any
    incompatibility is reported as an error. Otherwise the candidates
    registered for the current platform are tried in order, skipping any that
    are named in ``VLLM_DISABLED_KERNELS``, unsupported on this platform, or
    unable to implement the layer config.

    Returns:
        An instance of the first usable kernel class.

    Raises:
        ValueError: if the forced kernel is unusable, or if no candidate
            kernel is usable (the message lists the reason for each).
    """
    # One config is shared by the capability checks and the constructor so
    # that what gets validated is exactly what gets instantiated.
    config = MxFp4LinearLayerConfig()

    force_kernel: type[MxFp4LinearKernel] | None = None
    if envs.VLLM_MXFP4_USE_MARLIN:
        force_kernel = MarlinMxFp4LinearKernel

    if force_kernel is not None:
        is_supported, reason = force_kernel.is_supported()
        if is_supported:
            # Check the config here too; otherwise the constructor's assert
            # would be the only (and reason-less) line of defence.
            is_supported, reason = force_kernel.can_implement(config)
        if not is_supported:
            raise ValueError(
                f"Forced MXFP4 kernel {force_kernel.__name__} is not "
                f"supported: {reason}"
            )
        logger.info_once("Using %s for MXFP4 GEMM", force_kernel.__name__)
        return force_kernel(config)

    platform = current_platform._enum
    possible = _POSSIBLE_MXFP4_KERNELS.get(platform, [])

    failure_reasons = []
    for kernel_cls in possible:
        if kernel_cls.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f" {kernel_cls.__name__} disabled by environment variable"
            )
            continue

        is_supported, reason = kernel_cls.is_supported()
        if not is_supported:
            failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
            continue

        # Fall through to the next candidate instead of tripping the
        # constructor assertion when this kernel cannot handle the config.
        can_implement, reason = kernel_cls.can_implement(config)
        if not can_implement:
            failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
            continue

        logger.info_once("Using %s for MXFP4 GEMM", kernel_cls.__name__)
        return kernel_cls(config)

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "MXFP4 linear layer. Reasons: \n" + "\n".join(failure_reasons)
    )


def init_wfp8_a16_linear_kernel(
weight_quant_key: QuantKey,
activation_quant_key: QuantKey,
Expand Down Expand Up @@ -730,6 +789,10 @@ def register_linear_kernel(
if platform not in _POSSIBLE_NVFP4_KERNELS:
_POSSIBLE_NVFP4_KERNELS[platform] = []
_POSSIBLE_NVFP4_KERNELS[platform].append(kernel_class)
elif kernel_type == "mxfp4":
if platform not in _POSSIBLE_MXFP4_KERNELS:
_POSSIBLE_MXFP4_KERNELS[platform] = []
_POSSIBLE_MXFP4_KERNELS[platform].append(kernel_class)
else:
raise ValueError(f"Unrecognized kernel type: {kernel_type}")

Expand Down Expand Up @@ -777,6 +840,11 @@ def register_linear_kernel(
"init_mxfp8_linear_kernel",
"Mxfp8LinearKernel",
"Mxfp8LinearLayerConfig",
"init_mxfp4_linear_kernel",
"MxFp4LinearKernel",
"MxFp4LinearLayerConfig",
"FlashInferMxFp4LinearKernel",
"MarlinMxFp4LinearKernel",
"FlashInferCutlassMxfp8LinearKernel",
"MarlinMxfp8LinearKernel",
"XPUMxFp8LinearKernel",
Expand Down
12 changes: 12 additions & 0 deletions vllm/model_executor/kernels/linear/mxfp4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.model_executor.kernels.linear.mxfp4.base import (
MxFp4LinearKernel,
MxFp4LinearLayerConfig,
)

__all__ = [
"MxFp4LinearKernel",
"MxFp4LinearLayerConfig",
]
67 changes: 67 additions & 0 deletions vllm/model_executor/kernels/linear/mxfp4/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch


@dataclass
class MxFp4LinearLayerConfig:
    """Field-less configuration object for an MXFP4 linear layer.

    No options are needed because every MXFP4 layer has the same layout:
    weights are packed into uint8 (two FP4 values per byte) and accompanied
    by per-block weight scales with a group size of 32.
    """


class MxFp4LinearKernel(ABC):
    """Base class for MXFP4 quantized linear kernels.

    Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc).
    ``is_supported`` and ``can_implement`` are the hooks through which a
    subclass declares whether it is usable on the current platform and for a
    given layer config; the constructor asserts both.
    """

    def __init__(self, config: MxFp4LinearLayerConfig) -> None:
        """Store *config* after asserting the kernel is usable.

        Raises:
            AssertionError: if ``can_implement(config)`` or ``is_supported()``
                reports failure; the message carries the reason they returned.
        """
        # Same check order as before (config first, then platform); the
        # returned reason is now attached so a failure is self-explanatory.
        can_implement, reason = self.can_implement(config)
        assert can_implement, reason
        is_supported, reason = self.is_supported()
        assert is_supported, reason
        self.config = config

    @classmethod
    @abstractmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Return whether this kernel can run on the current platform.

        Args:
            compute_capability: Optional capability value; whether and how it
                is consulted is up to the subclass.

        Returns:
            ``(True, None)`` when usable, otherwise ``(False, reason)``.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Return whether this kernel can handle *config*.

        Returns:
            ``(True, None)`` when it can, otherwise ``(False, reason)``.
        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transform weights into the format required by this kernel.

        Called once after checkpoint weights have been loaded onto the
        device. Implementations should repack / swizzle / pad weights
        and scales in-place on *layer*.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized GEMM.

        Args:
            layer: Module holding the weights prepared by
                ``process_weights_after_loading``.
            x: Input activations.
            bias: Optional bias to add to the result.

        Returns:
            The output tensor of the linear layer.
        """
        raise NotImplementedError
74 changes: 74 additions & 0 deletions vllm/model_executor/kernels/linear/mxfp4/flashinfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import (
swizzle_mxfp4_scales,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutedsl

from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig

# Number of weight elements covered by one scale entry. Used below to recover
# K from the scale tensor's second dimension and passed to the GEMM as its
# block size.
_MXFP4_GROUP_SIZE = 32


class FlashInferMxFp4LinearKernel(MxFp4LinearKernel):
    """MXFP4 W4A4 GEMM via FlashInfer's ``cute-dsl`` backend (SM100+)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Require device capability 100 and FlashInfer's CuTe DSL support.

        ``compute_capability`` is accepted for interface compatibility but is
        not consulted; the current platform is queried directly.
        """
        usable = current_platform.has_device_capability(100) and (
            has_flashinfer_cutedsl()
        )
        if not usable:
            return False, "FlashInfer + >=sm_100 (Blackwell) required"
        return True, None

    @classmethod
    def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every config; there is nothing to reject."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace ``layer.weight_scale`` with its swizzled form.

        Only the scales are rewritten; ``layer.weight`` is left untouched.
        """
        scales = layer.weight_scale
        num_rows, num_scale_cols = scales.shape
        num_cols = num_scale_cols * _MXFP4_GROUP_SIZE

        swizzled = swizzle_mxfp4_scales(scales.data, num_rows, num_cols)

        # swizzle pads N to the next multiple of 128 for CUTLASS tiling, so
        # the result is reshaped against the rounded-up row count
        # (ceil-division written as negated floor-division).
        padded_rows = -(-num_rows // 128) * 128
        layer.weight_scale = Parameter(
            swizzled.reshape(padded_rows, -1),
            requires_grad=False,
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x* to MXFP4, run the FP4 GEMM and add *bias* if given.

        The input is flattened to 2-D for the GEMM and the result is viewed
        back to the input's leading dimensions with
        ``layer.output_size_per_partition`` as the last dimension.
        """
        from vllm.utils.flashinfer import (
            flashinfer_mxfp4_quantize,
            flashinfer_scaled_fp4_mm,
        )

        leading_dims = x.shape[:-1]
        flat_x = x.reshape(-1, x.shape[-1])

        act_fp4, act_scale = flashinfer_mxfp4_quantize(flat_x)
        result = flashinfer_scaled_fp4_mm(
            act_fp4,
            layer.weight,
            act_scale,
            layer.weight_scale,
            alpha=None,
            out_dtype=x.dtype,
            backend="cute-dsl",
            block_size=_MXFP4_GROUP_SIZE,
            use_nvfp4=False,
        )

        if bias is not None:
            result = result + bias
        return result.view(leading_dims + (layer.output_size_per_partition,))
52 changes: 52 additions & 0 deletions vllm/model_executor/kernels/linear/mxfp4/marlin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig


class MarlinMxFp4LinearKernel(MxFp4LinearKernel):
    """MXFP4 linear kernel that delegates to the Marlin FP4 helpers.

    The helpers are imported inside each method rather than at module level.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether the Marlin FP4 path is available.

        ``compute_capability`` is accepted for interface compatibility but is
        not consulted here.
        """
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
            is_fp4_marlin_supported,
        )

        if not is_fp4_marlin_supported():
            return False, "Marlin FP4 not available"
        return True, None

    @classmethod
    def can_implement(cls, c: MxFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every config; there is nothing to reject."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Hand *layer* to ``prepare_fp4_layer_for_marlin`` for preparation."""
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
            prepare_fp4_layer_for_marlin,
        )

        prepare_fp4_layer_for_marlin(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Marlin FP4 linear op on *x* using *layer*'s tensors.

        No global weight scale is passed (``weight_global_scale=None``).
        """
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
            apply_fp4_marlin_linear,
        )

        marlin_args = dict(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_global_scale=None,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
        return apply_fp4_marlin_linear(**marlin_args)
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@
CompressedTensors24,
CompressedTensorsScheme,
CompressedTensorsW4A4Fp4,
CompressedTensorsW4A4Mxfp4,
CompressedTensorsW4A8Fp8,
CompressedTensorsW4A8Int,
CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Mxfp4,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A8Mxfp8,
Expand Down Expand Up @@ -625,7 +625,7 @@ def _get_scheme_from_parts(
return CompressedTensorsW4A16Fp4()

if self._is_mxfp4(weight_quant):
return CompressedTensorsW4A16Mxfp4()
return CompressedTensorsW4A4Mxfp4()

if self._is_mxfp8(weight_quant):
return CompressedTensorsW8A8Mxfp8()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, moe):
super().__init__(moe)
self.group_size = 32
self.mxfp4_backend = Mxfp4MoeBackend.MARLIN
# use cutlass if supported, otherwise fallback to marlin for weight-only FP4
self.use_cutlass_mxfp4 = CutlassExpertsMxfp4._supports_current_device()
self.experts_cls: type[mk.FusedMoEExperts]
if self.use_cutlass_mxfp4:
Expand Down
Loading
Loading