Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 50 additions & 10 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
MXFP8_VALUE_DTYPE,
Mxfp8LinearBackend,
Mxfp8LinearOp,
swizzle_mxfp8_scale,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
apply_nvfp4_linear,
Expand Down Expand Up @@ -1689,9 +1690,9 @@ def __init__(self, quant_config: ModelOptMxFp8Config) -> None:
"Dynamic quantization is not supported."
)

backend: Mxfp8LinearBackend = Mxfp8LinearBackend.EMULATION
self.mxfp8_linear_op = Mxfp8LinearOp(backend=backend)
logger.info_once("Using %s backend for MXFP8 GEMM", backend.value)
self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS
self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend)
logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value)

def create_weights(
self,
Expand Down Expand Up @@ -1749,7 +1750,38 @@ def create_weights(
)
layer.register_parameter("weight_scale", weight_scale)

def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None:
    """Finalize weights with plain row-major 2D scales (emulation backend).

    Trims any checkpoint padding from the scale tensor so it matches the
    weight's [N, K // MXFP8_BLOCK_SIZE] shape; no swizzling is applied.
    """
    w = layer.weight.data  # [N, K]
    out_features, in_features = w.shape
    n_scale_cols = in_features // MXFP8_BLOCK_SIZE

    # Checkpoint scales may be padded; keep only the region matching the weight.
    trimmed_scale = layer.weight_scale.data[:out_features, :n_scale_cols]

    layer.weight = Parameter(w.contiguous(), requires_grad=False)
    layer.weight_scale = Parameter(trimmed_scale.contiguous(), requires_grad=False)

def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None:
    """Finalize weights with 1D swizzled scales (FlashInfer CUTLASS backend).

    Trims checkpoint padding from the 2D scale tensor, then swizzles it into
    the flattened F8_128x4 layout required by the CUTLASS MXFP8 GEMM.
    """
    w = layer.weight.data  # [N, K]
    out_features, in_features = w.shape

    # Trim potential checkpoint padding from the 2D scales before swizzling.
    n_scale_cols = in_features // MXFP8_BLOCK_SIZE
    scale_2d = layer.weight_scale.data[:out_features, :n_scale_cols].contiguous()

    # Swizzle into the 1D layout expected by the kernel.
    swizzled = swizzle_mxfp8_scale(scale_2d, M=out_features, K=in_features)

    layer.weight = Parameter(w.contiguous(), requires_grad=False)
    layer.weight_scale = Parameter(swizzled.contiguous(), requires_grad=False)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Validate weight tensor
if layer.weight.ndim != 2:
raise ValueError(
f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D "
Expand All @@ -1763,15 +1795,23 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
f"quantized with MXFP8."
)

weight = layer.weight.data # [N, K]
N, K = weight.shape
scale_k = K // MXFP8_BLOCK_SIZE
# Validate weight scale tensor (should be 2D, not swizzled)
assert layer.weight_scale.ndim == 2, (
f"MXFP8 weight scale must be 2D, got {layer.weight_scale.ndim}D"
)
assert layer.weight_scale.dtype == MXFP8_SCALE_DTYPE, (
f"MXFP8 weight scale must be {MXFP8_SCALE_DTYPE},"
f" got {layer.weight_scale.dtype}"
)

# Slice weight_scale to match weight dimensions (handles padding)
weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
if self.backend == Mxfp8LinearBackend.EMULATION:
# Swizzled layout is not used
self._process_weights_after_loading_scale_2d(layer)
return

layer.weight = Parameter(weight.contiguous(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
# Swizzled layout is required for Flashinfer CUTLASS
self._process_weights_after_loading_scale_1d(layer)

def apply(
self,
Expand Down
104 changes: 103 additions & 1 deletion vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import torch

from vllm.logger import init_logger
from vllm.utils import flashinfer as vllm_flashinfer
from vllm.utils.torch_utils import direct_register_custom_op

logger = init_logger(__name__)


class Mxfp8LinearBackend(Enum):
    """Selects the implementation used by Mxfp8LinearOp for MXFP8 GEMM."""

    # Reference path: dequantize the weight and run a standard torch matmul.
    EMULATION = "emulation"
    # FlashInfer CUTLASS mm_mxfp8 kernel (expects swizzled 1D weight scales).
    FLASHINFER_CUTLASS = "flashinfer-cutlass"


# MXFP8 constants
Expand All @@ -21,6 +23,30 @@ class Mxfp8LinearBackend(Enum):
MXFP8_BLOCK_SIZE = 32


def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
    """Rearrange row-major 2D MXFP8 scales into the flattened F8_128x4 layout.

    The [M, K // MXFP8_BLOCK_SIZE] scale tensor is zero-padded up to whole
    128-row x 4-column tiles, the row-group and K-tile axes are interleaved,
    and the result is returned as a contiguous 1D tensor.
    """
    vec_size = MXFP8_BLOCK_SIZE  # one scale covers 32 values along K
    tile_k = vec_size * 4  # 128 values of K per scale tile

    m_tiles = -(-M // 128)  # ceil(M / 128)
    k_tiles = -(-K // tile_k)  # ceil(K / 128)

    # Zero-padded buffer covering whole tiles.
    padded = torch.zeros(
        (m_tiles * 128, k_tiles * 4), dtype=sf.dtype, device=sf.device
    )
    padded[:M, : K // vec_size] = sf

    # (m_tile, 4, 32, k_tile, 4): swap the 4-row-group and k-tile axes,
    # then flatten into the 1D swizzled layout.
    tiled = padded.view(m_tiles, 4, 32, k_tiles, 4).transpose(1, 3)
    return tiled.contiguous().view(-1)


def _mxfp8_e4m3_quantize_impl(
x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -108,7 +134,7 @@ def __init__(self, backend: Mxfp8LinearBackend):

self.backend = backend

def apply(
def _apply_emulation(
self,
input: torch.Tensor,
weight: torch.Tensor,
Expand All @@ -132,3 +158,79 @@ def apply(

output = torch.nn.functional.linear(input, weight_bf16, bias)
return output.to(out_dtype)

def _apply_flashinfer_cutlass(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
N, K = weight.shape

input_shape = input.shape
input_2d = input.view(-1, K)
M_orig = input_2d.shape[0]

# Minimum dimension size for F8_128x4 block scaling layout
min_dim = 128

assert min_dim <= K, (
f"mm_mxfp8 requires K >= {min_dim}, got K={K}. "
f"in_features is too small for mm_mxfp8."
)
assert K % MXFP8_BLOCK_SIZE == 0, (
f"mm_mxfp8 requires K to be divisible by {MXFP8_BLOCK_SIZE}, got K={K}."
)
assert min_dim <= N, (
f"mm_mxfp8 requires N >= {min_dim}, got N={N}. "
f"out_features is too small for mm_mxfp8."
)
Comment on lines +179 to +189
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

These hard assertions on min_dim=128 for K and N will cause vLLM to crash if the model contains any linear layers with dimensions smaller than 128 (e.g., router gates or small projection layers). Instead of crashing, the implementation should detect unsupported shapes and fall back to the EMULATION backend for those specific layers. Note that this requires ensuring the weight scales are processed correctly (not swizzled) for the fallback backend during the weight loading phase.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer not to fall back to emulation and instead raise an error.

The emulation has lower performance compared to cutlass and users may not notice that fallback was triggered.

The ModelOpt MXFP8 support is new, changes to backend selection logic can be added later as needed.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assertions can be raised before kernel execution — for example, during post-weight-processing, if we recognize that the model has incompatible shapes for the given kernel backend. If the kernel's apply method is the only point of failure/check, the error would surface much later, only when the kernel is actually invoked.

Copy link
Copy Markdown
Contributor Author

@danisereb danisereb Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If @mgoin merges his PR #34664 first (Marlin MXFP8 GEMM) I will align my PR to his.
In that case I'll add a select_mxfp8_linear_backend function that will select cutlass / marlin / emulation backend (fallback to marlin if cutlass is not supported).

Maybe an assert should be used only if the user uses an env-var to force cutlass MXFP8 GEMM (or follow similar logic to existing NVFP4 / FP8 "select_*_backend").


M_padded = ((M_orig + min_dim - 1) // min_dim) * min_dim
if M_padded != M_orig:
pad_rows = M_padded - M_orig
input_2d = torch.nn.functional.pad(input_2d, (0, 0, 0, pad_rows))

input_mxfp8, input_scale = mxfp8_e4m3_quantize(
input_2d,
is_sf_swizzled_layout=True, # Swizzled for best accuracy
)

if not weight.is_contiguous():
weight = weight.contiguous()

output = vllm_flashinfer.mm_mxfp8(
input_mxfp8,
weight.t(),
input_scale,
weight_scale,
out_dtype=out_dtype,
backend="cutlass",
)

if M_padded != M_orig:
output = output[:M_orig, :]

if bias is not None:
output = output + bias

output_shape = (*input_shape[:-1], N)
return output.view(output_shape)

def apply(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run the MXFP8 linear op via the backend chosen at construction time."""
    if self.backend != Mxfp8LinearBackend.EMULATION:
        # Only two backends exist; anything else is a programming error.
        assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
        return self._apply_flashinfer_cutlass(
            input, weight, weight_scale, out_dtype, bias
        )
    return self._apply_emulation(input, weight, weight_scale, out_dtype, bias)
77 changes: 77 additions & 0 deletions vllm/utils/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,83 @@ def flashinfer_nvfp4_quantize_fake(
rounded_m, rounded_n, dtype=torch.uint8, device=a.device
)

@torch.library.custom_op(
    "vllm::mm_mxfp8",
    mutates_args=[],
    device_types="cuda",
)
def mm_mxfp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    """Custom-op wrapper around FlashInfer's ``mm_mxfp8`` MXFP8 GEMM.

    Computes ``A @ B`` with per-block MXFP8 scales ``A_scale`` / ``B_scale``
    and returns a tensor of ``out_dtype``; per the fake registration,
    ``A`` is [m, k] and ``B`` is [k, n], producing an [m, n] output.
    """
    # Imported lazily so importing this module does not require flashinfer.
    from flashinfer import mm_mxfp8 as mm_mxfp8_

    return mm_mxfp8_(
        A,
        B,
        A_scale,
        B_scale,
        out=None,
        out_dtype=out_dtype,
        backend=backend,
    )

@torch.library.register_fake(
    "vllm::mm_mxfp8",
)
def mm_mxfp8_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    """Fake (meta) implementation of ``vllm::mm_mxfp8`` for tracing/compile.

    Only allocates an empty output of the correct shape and dtype; it does
    not compute anything.
    """
    # A is [m, k], B is [k, n] -> output [m, n]
    return torch.empty(A.shape[0], B.shape[1], dtype=out_dtype, device=A.device)


def flashinfer_mm_mxfp8(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    """MXFP8 MM helper - mirrors flashinfer_scaled_fp4_mm API.

    Takes non-transposed weights and handles transpose internally.

    CRITICAL: mm_mxfp8 CUTLASS kernel requires SWIZZLED 1D scales for optimal
    performance and accuracy. Both input and weight scales should be in
    swizzled format from FlashInfer's mxfp8_quantize(is_sf_swizzled_layout=True).
    """
    # a shape [M, K]
    # b shape [N, K] (non-transposed weight; transposed to [K, N] below)
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[1]  # K dimension must match

    if block_scale_b.ndim != 1:
        raise ValueError(
            "mm_mxfp8 expects 1D swizzled weight scales for CUTLASS; "
            f"got shape={tuple(block_scale_b.shape)}"
        )

    # Output tensor [M, N]
    return mm_mxfp8(
        a,
        b.t(),  # Transpose weight: [N, K] -> [K, N]
        block_scale_a,
        block_scale_b,
        out_dtype,
        backend=backend,
    )


def flashinfer_scaled_fp4_mm(
a: torch.Tensor,
Expand Down