2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -836,7 +836,7 @@ steps:
- tests/models/multimodal
no_gpu: true
commands:
-  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
+  - "pip install git+https://github.com/TIGER-AI-Lab/Mantis.git || echo 'Mantis installation skipped (decord not available on CPU-only environment)'"
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py

- label: Multi-Modal Processor Test
91 changes: 91 additions & 0 deletions tests/kernels/quantization/test_scaled_mm_kernel_selection.py
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ScaledMM kernel selection logic (CPU-only)

Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
"""

import inspect
from abc import ABC

import pytest

from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel,
)

pytestmark = pytest.mark.cpu_test


def test_is_supported_is_abstract():
"""Test that is_supported() is properly defined as abstract."""
assert issubclass(ScaledMMLinearKernel, ABC)
assert hasattr(ScaledMMLinearKernel, "is_supported")


def test_cpu_kernel_implements_is_supported():
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
assert hasattr(CPUScaledMMLinearKernel, "is_supported"), (
"CPUScaledMMLinearKernel missing is_supported() method"
)
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction(
CPUScaledMMLinearKernel.is_supported
), "CPUScaledMMLinearKernel.is_supported() should be a classmethod"
# Verify it can be called as a classmethod
result, reason = CPUScaledMMLinearKernel.is_supported()
assert isinstance(result, bool), "is_supported() should return a bool"
assert reason is None or isinstance(reason, str), "reason should be str or None"


def test_aiter_kernel_implements_is_supported():
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
assert hasattr(AiterScaledMMLinearKernel, "is_supported"), (
"AiterScaledMMLinearKernel missing is_supported() method"
)
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
assert inspect.ismethod(
AiterScaledMMLinearKernel.is_supported
) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), (
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
)
# Verify it can be called as a classmethod
# (will return False on CPU, which is expected)
result, reason = AiterScaledMMLinearKernel.is_supported()
assert isinstance(result, bool), "is_supported() should return a bool"
assert reason is None or isinstance(reason, str), "reason should be str or None"
# On CPU, it should return False with a reason about requiring ROCm
# This validates the method works correctly even on non-ROCm platforms


def test_cpu_kernel_accepts_all_configs():
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
configs = [
ScaledMMLinearLayerConfig(
is_channelwise=False,
is_static_input_scheme=True,
input_symmetric=True,
),
ScaledMMLinearLayerConfig(
is_channelwise=True,
is_static_input_scheme=False,
input_symmetric=False,
),
]

for config in configs:
can_impl, reason = CPUScaledMMLinearKernel.can_implement(config)
assert can_impl, (
f"CPUScaledMMLinearKernel should accept config {config}: {reason}"
)
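The two per-kernel tests above repeat the same contract check. A possible consolidation (a sketch, not part of this PR, using only names the test module already imports) would parametrize it:

@pytest.mark.parametrize(
    "kernel_cls", [CPUScaledMMLinearKernel, AiterScaledMMLinearKernel]
)
def test_is_supported_contract(kernel_cls):
    # Same contract as above: callable on the class, returning
    # a (bool, str | None) tuple.
    ok, reason = kernel_cls.is_supported()
    assert isinstance(ok, bool)
    assert reason is None or isinstance(reason, str)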
vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
@@ -17,7 +17,9 @@ class ScaledMMLinearLayerConfig:
class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
-    def get_min_capability(cls) -> int:
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
raise NotImplementedError

@classmethod
@@ -35,6 +37,7 @@ def __init__(
azp_adj_param_name: str,
) -> None:
assert self.can_implement(c)
+        assert self.is_supported()
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
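The new interface splits kernel gating into two questions: is_supported() asks whether the kernel can run on the current platform and compute capability at all, while can_implement() asks whether it can handle a particular quantization config. A minimal hypothetical subclass (a sketch only; a real kernel also implements the weight-processing methods):

class ExampleScaledMMKernel(ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # Platform/capability gate, e.g. check current_platform here.
        return True, None

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        # Config gate, e.g. reject asymmetric input schemes here.
        return True, None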
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -27,7 +27,7 @@
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
-    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
+    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}
@@ -55,41 +55,25 @@ def choose_scaled_mm_linear_kernel(
type[ScaledMMLinearKernel]: Chosen kernel.
"""

-    if compute_capability is None:
-        _cc = current_platform.get_device_capability()
-        if _cc is not None:
-            compute_capability = _cc[0] * 10 + _cc[1]
-
failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
-            failure_reasons.append(
-                f" {kernel.__name__} disabled by environment variable"
-            )
+            failure_reasons.append(f"{kernel.__name__}: disabled by env var")
continue

-        # If the current platform uses compute_capability,
-        # make sure the kernel supports the compute cability.
-        if compute_capability is not None:
-            kernel_min_capability = kernel.get_min_capability()
-            if (
-                kernel_min_capability is not None
-                and kernel_min_capability > compute_capability
-            ):
-                failure_reasons.append(
-                    f"{kernel.__name__} requires capability "
-                    f"{kernel_min_capability}, current compute capability "
-                    f"is {compute_capability}"
-                )
-                continue
+        is_supported, reason = kernel.is_supported(compute_capability)
+        if not is_supported:
+            failure_reasons.append(f"{kernel.__name__}: {reason}")
+            continue

+        can_implement, reason = kernel.can_implement(config)
+        if not can_implement:
+            failure_reasons.append(f"{kernel.__name__}: {reason}")
+            continue

-        can_implement, failure_reason = kernel.can_implement(config)
-        if can_implement:
-            return kernel
-        else:
-            failure_reasons.append(
-                f" {kernel.__name__} cannot implement due to: {failure_reason}"
-            )
+        return kernel

raise ValueError(
"Failed to find a kernel that can implement the "
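With the capability logic moved into the kernels, the chooser is a plain priority scan. A hedged usage sketch (assuming choose_scaled_mm_linear_kernel and the config dataclass are re-exported from the scaled_mm package, as the test imports above suggest):

from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig,
    choose_scaled_mm_linear_kernel,
)

config = ScaledMMLinearLayerConfig(
    is_channelwise=True,
    is_static_input_scheme=False,
    input_symmetric=True,
)
# Returns the first kernel whose is_supported() and can_implement() both
# pass; raises ValueError with the collected failure_reasons otherwise.
kernel_cls = choose_scaled_mm_linear_kernel(config)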
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@@ -14,17 +14,21 @@

class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
-    def get_min_capability(cls) -> int:
-        return 90
-
-    @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "currently supported on non-ROCm platform.",
)
+        if compute_capability is None:
+            _cc = current_platform.get_device_capability()
+            if _cc is not None:
+                compute_capability = _cc.major * 10 + _cc.minor
+        if compute_capability is not None and compute_capability < 90:
+            return False, f"requires capability 90, got {compute_capability}"

try:
import aiter # noqa: F401 # deliberately attempt to import aiter
@@ -34,8 +38,8 @@ def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "installed on ROCm.",
)
-        # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
-        if not (rocm_aiter_ops.is_linear_enabled()):
+
+        if not rocm_aiter_ops.is_linear_enabled():
return (
False,
"AiterScaledMMLinearKernel is disabled. "
@@ -44,6 +48,10 @@ def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+ "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
)

+        return True, None
+
+    @classmethod
+    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric:
return (
False,
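The capability integer consulted above is simply major * 10 + minor, as the resolution code shows. For illustration only (not part of the diff):

def encode_capability(major: int, minor: int) -> int:
    # Mirrors `_cc.major * 10 + _cc.minor` in the kernels above.
    return major * 10 + minor

assert encode_capability(7, 5) == 75  # CutlassScaledMMLinearKernel's floor
assert encode_capability(9, 0) == 90  # AiterScaledMMLinearKernel's floor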
vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
@@ -19,14 +19,15 @@

class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
-    def get_min_capability(cls) -> int:
-        return 75
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_cpu():
+            return False, "Requires CPU."
+        return True, None

@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
-        if not current_platform.is_cpu():
-            return False, "CPUScaledMM requires running on CPU."
-
return True, None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
@@ -16,14 +16,21 @@

class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
-    def get_min_capability(cls) -> int:
-        return 75
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_cuda():
+            return False, "Requires CUDA."
+        if compute_capability is None:
+            _cc = current_platform.get_device_capability()
+            if _cc is not None:
+                compute_capability = _cc.major * 10 + _cc.minor
+        if compute_capability is not None and compute_capability < 75:
+            return False, f"requires capability 75, got {compute_capability}"
+        return True, None

@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
-        if not current_platform.is_cuda():
-            return False, "CutlassScaledMM requires running on CUDA."
-
return True, None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
@@ -4,39 +4,68 @@

import torch

+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa: E501
+    triton_scaled_mm,
+)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import current_platform

-from .cutlass import CutlassScaledMMLinearKernel
-from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig


-class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
-    def get_min_capability(cls) -> int:
-        return 75
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if current_platform.is_cuda_alike():
+            return True, None
+        return False, "Requires ROCm or CUDA."

@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
-        if current_platform.is_cpu():
-            return (
-                False,
-                "TritonScaledMMLinearKernel requires Triton which is not "
-                + "currently supported on CPU.",
-            )
if not c.input_symmetric:
-            return (
-                False,
-                "TritonScaledMMLinearKernel only supports symmetric " + "quantization.",
-            )
+            return False, "Only symmetric input is supported."
return True, None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        super().process_weights_after_loading(layer)
+        weight = getattr(layer, self.w_q_name)
+        replace_parameter(
+            layer,
+            self.w_q_name,
+            torch.nn.Parameter(weight.t().data, requires_grad=False),
+        )
+
+        # INPUT SCALE
+        if self.config.is_static_input_scheme:
+            input_scale = getattr(layer, self.i_s_name)
+            replace_parameter(
+                layer,
+                self.i_s_name,
+                torch.nn.Parameter(input_scale.max(), requires_grad=False),
+            )
+            setattr(layer, self.i_zp_name, None)
+        else:
+            setattr(layer, self.i_s_name, None)
+            setattr(layer, self.i_zp_name, None)
+
+        setattr(layer, self.azp_adj_name, None)

def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
-        return super().apply_weights(layer, x, bias)
+        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
+
+        x_q, x_s, x_zp = ops.scaled_int8_quant(
+            x.contiguous(), i_s, i_zp, symmetric=True
+        )
+
+        assert x_zp is None, "Triton kernel only supports symmetric quantization"
+
+        return triton_scaled_mm(
+            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
+        )
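For intuition, the new apply_weights() path quantizes the activations to int8, runs the integer GEMM, and rescales. A rough dense reference of the same arithmetic (a sketch assuming per-token scale_a and per-channel or per-tensor scale_b; not the real Triton kernel):

import torch

def reference_scaled_mm(x_q, w_q, scale_a, scale_b, out_dtype, bias=None):
    # (M, K) int8 @ (K, N) int8 accumulated in float, then rescaled.
    y = torch.matmul(x_q.float(), w_q.float()) * (scale_a * scale_b)
    if bias is not None:
        y = y + bias
    return y.to(out_dtype)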
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
@@ -17,11 +17,12 @@

class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
-    def get_min_capability(cls) -> int:
-        raise NotImplementedError(
-            "TPU platform does have a concept of compute capability, "
-            "this method should not be called."
-        )
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_tpu():
+            return False, "Requires TPU."
+        return True, None

@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: