Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_FP4BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
Expand Down Expand Up @@ -1016,9 +1016,9 @@ def _get_or_set_default() -> str:
in ("true", "1")
),
# Whether to use aiter fusion shared experts ops.
# By default is disabled.
# Enabled by default for better MoE performance (matching ATOM defaults).
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "False").lower()
os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower()
in ("true", "1")
),
# Whether to use aiter triton kernels for gemm ops.
Expand Down
122 changes: 119 additions & 3 deletions vllm/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
Expand All @@ -25,13 +26,25 @@
mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend,
)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
prepare_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
PackedvLLMParameter,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform

logger = init_logger(__name__)
Expand Down Expand Up @@ -72,9 +85,21 @@
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
if current_platform.is_rocm() and rocm_aiter_ops.is_enabled():

Check failure on line 88 in vllm/model_executor/layers/quantization/mxfp4.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "current_platform" is not defined [name-defined]

Check failure on line 88 in vllm/model_executor/layers/quantization/mxfp4.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "current_platform" is not defined [name-defined]

Check failure on line 88 in vllm/model_executor/layers/quantization/mxfp4.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (F821)

vllm/model_executor/layers/quantization/mxfp4.py:88:16: F821 Undefined name `current_platform`
logger.info_once(
"Using AITER MXFP4 linear method on ROCm.",
scope="local",
)
return Mxfp4LinearMethod()
if current_platform.is_cuda():

Check failure on line 94 in vllm/model_executor/layers/quantization/mxfp4.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "current_platform" is not defined [name-defined]

Check failure on line 94 in vllm/model_executor/layers/quantization/mxfp4.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "current_platform" is not defined [name-defined]

Check failure on line 94 in vllm/model_executor/layers/quantization/mxfp4.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (F821)

vllm/model_executor/layers/quantization/mxfp4.py:94:16: F821 Undefined name `current_platform`
logger.info_once(
"Using Marlin MXFP4 linear method on CUDA.",
scope="local",
)
return Mxfp4LinearMethod()
logger.debug_once(
"MXFP4 linear layer is not implemented - falling back to "
"UnquantizedLinearMethod.",
"MXFP4 linear layer is not supported on this platform "
"- falling back to UnquantizedLinearMethod.",
scope="local",
)
return UnquantizedLinearMethod()
Expand All @@ -93,6 +118,97 @@
return True


class Mxfp4LinearMethod(LinearMethodBase):
    """MXFP4 quantized linear method.

    On ROCm with AITER enabled: uses AITER's Triton FP4 GEMM
    (``gemm_afp4wfp4``) with dynamic activation quantization, following the
    same kernel path as ATOM.
    Otherwise (CUDA): uses the Marlin FP4 kernel.
    """

    # Number of consecutive input elements that share one weight scale.
    MXFP4_BLOCK_SIZE = 32

    @staticmethod
    def _use_rocm_aiter() -> bool:
        """Return True when the ROCm AITER kernel path should be used.

        Single source of truth for the backend choice so that
        ``process_weights_after_loading`` and ``apply`` can never disagree.
        """
        return current_platform.is_rocm() and rocm_aiter_ops.is_enabled()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Allocate and register the packed FP4 weight and its block scales.

        Args:
            layer: Module that receives the ``weight`` and ``weight_scale``
                parameters plus the partition-size bookkeeping attributes.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output widths of the logical sub-layers
                held by this partition; their sum is the weight's row count.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: Unused here; both tensors are stored as ``uint8``.
            **extra_weight_attrs: Only ``weight_loader`` is read; it is
                forwarded to both parameters.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        output_size_per_partition = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Two FP4 values are packed into each uint8 along the input dim,
        # hence the halved column count and packed_factor=2.
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # One uint8 scale per MXFP4_BLOCK_SIZE input elements.
        # NOTE(review): the floor divisions assume input_size_per_partition
        # is a multiple of MXFP4_BLOCK_SIZE - confirm callers guarantee this.
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.MXFP4_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded parameters into the active backend's layout.

        Args:
            layer: Module whose ``weight`` / ``weight_scale`` were created by
                ``create_weights`` and have been filled by the weight loader.
        """
        # Both backends start by re-wrapping the weight as a plain, frozen
        # Parameter.
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

        if self._use_rocm_aiter():
            # rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt (vllm/_aiter_ops.py)
            # transposes `weight_scale` before handing it to the
            # gemm_afp4wfp4 kernel. Store the scale pre-transposed so that
            # transpose cancels out and the kernel receives the original
            # [N, K/32] layout.
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data.T.contiguous(), requires_grad=False
            )
        else:
            prepare_fp4_layer_for_marlin(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MXFP4 linear layer on ``x``.

        Args:
            layer: Module prepared by ``process_weights_after_loading``.
            x: Input activations.
            bias: Optional bias added to the GEMM result.

        Returns:
            The layer output tensor.
        """
        if not self._use_rocm_aiter():
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_global_scale=None,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        # NOTE(review): the output dtype is hard-coded to bfloat16 regardless
        # of x.dtype - confirm this is intended for non-bf16 models.
        out = rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt(
            x, layer.weight, layer.weight_scale, torch.bfloat16
        )
        # The AITER GEMM call is not given the bias, so add it afterwards.
        if bias is not None:
            out = out + bias
        return out


class Mxfp4MoEMethod(FusedMoEMethodBase):
"""MXFP4 MoE quantization method."""

Expand Down
Loading