Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 68 additions & 17 deletions vllm/model_executor/kernels/linear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear.base import (
MMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision import (
MPLinearKernel,
MPLinearLayerConfig,
Expand Down Expand Up @@ -52,20 +55,25 @@
XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm import (
Fp8BlockScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterFp8BlockScaledMMKernel,
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cuda import (
CudaFp8BlockScaledMMKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
CutlassFp8BlockScaledMMKernel,
CutlassFP8ScaledMMLinearKernel,
CutlassInt8ScaledMMLinearKernel,
)
Expand All @@ -81,6 +89,7 @@
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
TritonFp8BlockScaledMMKernel,
TritonInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
Expand Down Expand Up @@ -124,6 +133,22 @@
],
}


# FP8 block-scaled (per-group activation scale) MM kernels, listed per
# platform in priority/performance order (first supported kernel wins).
_POSSIBLE_FP8_BLOCK_KERNELS: dict[
    PlatformEnum, list[type[Fp8BlockScaledMMLinearKernel]]
] = {
    PlatformEnum.CUDA: [
        CudaFp8BlockScaledMMKernel,
        CutlassFp8BlockScaledMMKernel,
        TritonFp8BlockScaledMMKernel,
    ],
    PlatformEnum.ROCM: [
        AiterFp8BlockScaledMMKernel,
        # Triton is the portable fallback when the AITER kernel is unavailable.
        TritonFp8BlockScaledMMKernel,
    ],
}

# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
PlatformEnum.CUDA: [
Expand All @@ -148,7 +173,7 @@
}

# Type variables used by the generic kernel-selection helpers below.
# _KernelT covers any scaled-MM kernel implementation; _KernelConfigT is
# bound to the common MMLinearLayerConfig base so both FP8/INT8 tensor-wise
# and block-scaled configs are accepted.
# NOTE: the diff residue carried both the old bound (ScaledMMLinearLayerConfig)
# and the new one; only the post-change binding is kept here.
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=MMLinearLayerConfig)


def is_supported_and_can_implement_kernel(
Expand Down Expand Up @@ -240,31 +265,57 @@ def init_fp8_linear_kernel(
activation_quant_key: QuantKey,
weight_quant_key: QuantKey,
out_dtype: torch.dtype,
force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
force_kernel: type[_KernelT] | None = None,
module_name: str | None = None,
) -> FP8ScaledMMLinearKernel:
) -> FP8ScaledMMLinearKernel | Fp8BlockScaledMMLinearKernel:
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
weight_quant_key=weight_quant_key,
activation_quant_key=activation_quant_key,
out_dtype=out_dtype,
)

kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel
)
if activation_quant_key.scale.group_shape.is_per_group():
kernel_type = choose_scaled_mm_linear_kernel(
config=scaled_mm_linear_kernel_config,
possible_kernels=_POSSIBLE_FP8_BLOCK_KERNELS, # type: ignore[misc]
force_kernel=force_kernel,
)
if module_name:
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
)

if module_name:
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
return kernel_type(
scaled_mm_linear_kernel_config,
)

return kernel_type(
scaled_mm_linear_kernel_config,
layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
)
else:
kernel_type = choose_scaled_mm_linear_kernel(
config=scaled_mm_linear_kernel_config,
possible_kernels=_POSSIBLE_FP8_KERNELS, # type: ignore[misc]
force_kernel=force_kernel,
)

if module_name:
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
)

return kernel_type(
scaled_mm_linear_kernel_config,
layer_param_names=[
"weight",
"weight_scale",
"input_scale",
"input_scale_ub",
],
)


def init_int8_linear_kernel(
Expand Down
Loading