Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
n = b.shape[1]

if current_platform.is_rocm():
# TODO(luka) remove this once all uses go through the ScaledMMKernel path
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This TODO should be addressed as part of this PR, or a new issue should be created to track it.

triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Type
from typing import Optional

import torch
import triton
Expand Down Expand Up @@ -126,7 +126,7 @@ def triton_scaled_mm(input: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
block_size_m: int = 32,
block_size_n: int = 32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
from vllm.platforms import current_platform


@dataclass
class ScaledMMLinearLayerConfig:
Expand All @@ -17,9 +22,53 @@ class ScaledMMLinearLayerConfig:
class ScaledMMLinearKernel(ABC):

@classmethod
@abstractmethod
def is_supported(
    cls,
    compute_capability: Optional[int] = None
) -> Tuple[bool, Optional[str]]:
    """
    Check whether this kernel can run on the current platform.

    :param compute_capability: capability encoded as major * 10 + minor
        (e.g. 90 for SM 9.0); if None, queried from the platform by
        _current_capability_supported.
    :return: (supported, reason) — reason is None when supported,
        otherwise a human-readable explanation.

    By default a kernel is supported when the platform capability
    reaches get_min_capability(); subclasses may override this method
    for custom support checks.
    NOTE(review): @abstractmethod combined with a usable default body
    forces every concrete subclass to override even when the default
    logic suffices — confirm the decorator is intended.
    """
    return cls._current_capability_supported(compute_capability)

@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError
"""
:return: minimum capability required for this kernel.
Override is_supported if min_capability is irrelevant.
"""
raise NotImplementedError(
"Either implement get_min_capability or override is_supported")

@classmethod
def _current_capability_supported(
    cls,
    compute_capability: Optional[int] = None
) -> Tuple[bool, Optional[str]]:
    """
    Shared helper for is_supported: check this kernel's minimum compute
    capability against the given (or detected) capability.

    :param compute_capability: capability as major * 10 + minor; if None,
        queried from current_platform.
    :return: (supported, reason) — reason is None when supported.
    :raises ValueError: if no capability can be determined for the
        current platform (the kernel should then override is_supported).
    """
    # Detect the capability from the platform when the caller did not
    # supply one; get_device_capability() may return None (e.g. on
    # platforms without the notion of a compute capability).
    if compute_capability is None:
        _cc = current_platform.get_device_capability()
        if _cc is not None:
            compute_capability = _cc.major * 10 + _cc.minor

    # If the current platform uses compute_capability,
    # make sure the kernel supports the compute capability.
    if compute_capability is None:
        raise ValueError(
            f"Cannot determine if kernel {cls.__name__} is supported on "
            f"platform {current_platform} as compute capability is not "
            f"supported. Please override is_supported or remove the "
            f"kernel from the list of kernels for the platform.")

    # Compare against the subclass-declared minimum capability.
    kernel_min_capability = cls.get_min_capability()
    if (kernel_min_capability > compute_capability):
        return (False,
                f"compute capability >={kernel_min_capability} required, "
                f"{compute_capability} current")

    return True, None

@classmethod
@abstractmethod
Expand All @@ -31,6 +80,7 @@ def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
w_s_param_name: str, i_s_param_name: str,
i_zp_param_name: str, azp_adj_param_name: str) -> None:
assert self.can_implement(c)
assert self.is_supported()
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
Expand All @@ -53,14 +103,53 @@ def _get_weight_params(
self, layer: torch.nn.Module) -> Tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale,
Optional[torch.Tensor], # input_scale,
Optional[torch.Tensor], # input_zp
Optional[torch.Tensor], # azp_adj
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.i_s_name),
getattr(layer, self.i_zp_name),
getattr(layer, self.azp_adj_name),
getattr(layer, self.i_s_name, None),
getattr(layer, self.i_zp_name, None),
getattr(layer, self.azp_adj_name, None),
)

def replace_parameter(self, layer: torch.nn.Module, name: str,
                      param: torch.Tensor):
    """
    Replace the attribute ``name`` on ``layer`` with a frozen
    torch.nn.Parameter wrapping ``param``'s data.

    ``param`` may be a plain torch.Tensor (callers pass derived tensors
    such as ``weight_param.t()`` or ``input_scale_param.max()``) or an
    existing torch.nn.Parameter; ``.data`` is valid on both. The hint
    was previously torch.nn.Parameter, which was narrower than actual
    usage.
    """
    # Call free util function
    replace_parameter(layer, name,
                      torch.nn.Parameter(param.data, requires_grad=False))
Comment on lines +118 to +126
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The param argument is type-hinted as torch.nn.Parameter, but this method is called with torch.Tensor arguments in CutlassScaledMMLinearKernel (e.g., weight_param.t(), input_scale_param.max(), results from maybe_unfuse_weight_scale and fuse_asymmetric_params). Calling .data on a plain torch.Tensor will raise an AttributeError. To fix this, the param argument should expect a torch.Tensor and the internal torch.nn.Parameter creation should use this tensor directly. The docstring should clarify that the input tensor is the data for the new parameter.

Suggested change
def replace_parameter(self, layer: torch.nn.Module, name: str,
param: torch.nn.Parameter):
"""
This utility can replace a parameter with the new value.
"""
# Call free util function
replace_parameter(layer, name,
torch.nn.Parameter(param.data, requires_grad=False))
def replace_parameter(self, layer: torch.nn.Module, name: str,
new_tensor_data: torch.Tensor):
"""This utility can replace a parameter. The new_tensor_data
will become the .data of the new torch.nn.Parameter."""
replace_parameter(layer, name,
torch.nn.Parameter(new_tensor_data, requires_grad=False))


def maybe_unfuse_weight_scale(self, layer: torch.nn.Module,
                              weight_scale_param: torch.nn.Parameter):
    """
    Return a weight scale suitable for the kernel: fused modules
    (QKV, MLP) carry N per-tensor scales, which are expanded to the
    per-channel form unless the config is already channelwise.
    """
    fused = len(layer.logical_widths) > 1
    if fused and not self.config.is_channelwise:
        # Expand the N per-tensor scales to one scale per channel.
        return convert_to_channelwise(weight_scale_param,
                                      layer.logical_widths)
    # Already per-channel (or a single logical module): pass through.
    return weight_scale_param

def fuse_asymmetric_params(
    self, input_scale_param: torch.nn.Parameter,
    input_zp_param: torch.nn.Parameter
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Fuse per-channel asymmetric quantization params (scale, zero point)
    into a single per-tensor (scale, azp) pair covering the combined
    int8 range of all channels.

    :param input_scale_param: per-channel input scales.
    :param input_zp_param: per-channel zero points, stored as int8.
    :return: (scale, azp) as plain tensors — the values are computed by
        tensor arithmetic, so the previous ``torch.nn.Parameter`` return
        hint was incorrect.
    """
    # reconstruct the ranges
    int8_traits = torch.iinfo(torch.int8)
    azps = input_zp_param.to(dtype=torch.int32)
    range_max = (input_scale_param * (int8_traits.max - azps)).max()
    range_min = (input_scale_param * (int8_traits.min - azps)).min()

    scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)

    # AZP loaded as int8 but used as int32
    azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)

    return scale, azp
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,6 @@ def choose_scaled_mm_linear_kernel(
Type[ScaledMMLinearKernel]: Chosen kernel.
"""

if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc[0] * 10 + _cc[1]

failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
Expand All @@ -60,25 +55,21 @@ def choose_scaled_mm_linear_kernel(
f' {kernel.__name__} disabled by environment variable')
continue

# If the current platform uses compute_capability,
# make sure the kernel supports the compute capability.
if compute_capability is not None:
kernel_min_capability = kernel.get_min_capability()
if (kernel_min_capability is not None
and kernel_min_capability > compute_capability):
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel_min_capability}, current compute capability "
f"is {compute_capability}")
continue
is_supported, reason = kernel.is_supported(compute_capability)
if not is_supported:
failure_reasons.append(
f' {kernel.__name__} not supported: {reason}')
continue

can_implement, failure_reason = kernel.can_implement(config)
if can_implement:
return kernel
else:
if not can_implement:
failure_reasons.append(
f' {kernel.__name__} cannot implement due to: {failure_reason}'
)
f' {kernel.__name__} cannot implement given config ({config}): '
f'{failure_reason}')
continue

# Kernel enabled, supported, and can implement scheme!
return kernel

raise ValueError(
"Failed to find a kernel that can implement the "\
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
from vllm.platforms import current_platform

from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
Expand All @@ -21,65 +18,53 @@ def get_min_capability(cls) -> int:
return 75

@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:

if (not current_platform.is_cuda() and not current_platform.is_cpu()):
def is_supported(
cls,
compute_capability: Optional[int] = None
) -> Tuple[bool, Optional[str]]:
# Cutlass is also supported on CPU
if current_platform.is_cpu():
return True, ""

if not current_platform.is_cuda():
return False, "CutlassScaledMM requires running on CUDA or CPU."

# Defer to compute-capability-based support determination
return super().is_supported(compute_capability)

@classmethod
def can_implement(
    cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
    """Cutlass scaled_mm can implement every scheme config; platform
    checks are handled separately by is_supported."""
    # All schemes supported
    return True, None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Cutlass kernels need transposed weight.
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer, self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False))
weight_param = getattr(layer, self.w_q_name)
self.replace_parameter(layer, self.w_q_name, weight_param.t())

# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale,
layer.logical_widths)
replace_parameter(
layer, self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False))
w_scale_param = getattr(layer, self.w_s_name)
w_scale_param = self.maybe_unfuse_weight_scale(layer, w_scale_param)
self.replace_parameter(layer, self.w_s_name, w_scale_param)

# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
input_scale_param = getattr(layer, self.i_s_name)

if self.config.input_symmetric:
replace_parameter(
layer, self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False))
self.replace_parameter(layer, self.i_s_name,
input_scale_param.max())
setattr(layer, self.i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)

# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()

scale = (range_max - range_min) / (int8_traits.max -
int8_traits.min)
replace_parameter(
layer, self.i_s_name,
torch.nn.Parameter(scale, requires_grad=False))

# AZP loaded as int8 but used as int32
azp = (int8_traits.min -
range_min / scale).to(dtype=torch.int32)
replace_parameter(layer, self.i_zp_name,
torch.nn.Parameter(azp, requires_grad=False))

i_scale, i_zp = self.fuse_asymmetric_params(
input_scale_param, input_zero_point)
self.replace_parameter(layer, self.i_s_name, i_scale)
self.replace_parameter(layer, self.i_zp_name, i_zp)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
Expand Down Expand Up @@ -110,11 +95,8 @@ def apply_weights(self,
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
x_q, x_s, x_zp = ops.scaled_int8_quant(x,
i_s,
i_zp,
symmetric=symmetric)
sym = self.config.input_symmetric
x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=sym)

if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa: E501
triton_scaled_mm)
from vllm.platforms import current_platform

from .cutlass import CutlassScaledMMLinearKernel
Expand All @@ -16,25 +19,46 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
def get_min_capability(cls) -> int:
return 75

@classmethod
def is_supported(
    cls,
    compute_capability: Optional[int] = None
) -> Tuple[bool, Optional[str]]:
    """Triton scaled_mm runs on ROCm or CUDA only; on those platforms,
    support is decided by the compute-capability check."""
    # Guard clause: reject unsupported platforms up front.
    if not (current_platform.is_rocm() or current_platform.is_cuda()):
        return False, "Triton scaled_mm requires running on ROCm or CUDA."

    return cls._current_capability_supported(compute_capability)

@classmethod
def can_implement(
    cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
    """
    Check whether the Triton kernel can implement the given scheme
    config: symmetric input quantization only, and not on CPU.

    NOTE(review): this span comes from a diff; the CPU branch may be a
    removed line fused into the paste (platform checks appear to have
    moved to is_supported) — confirm against the merged file.
    """
    if current_platform.is_cpu():
        return (
            False,
            "TritonScaledMMLinearKernel requires Triton which is not " +
            "currently supported on CPU.")
    if not c.input_symmetric:
        return (False,
                "TritonScaledMMLinearKernel only supports symmetric " +
                "quantization.")
    return True, None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# TODO maybe this doesn't need to transpose the weight?
# Could also skip asymmetric-only paths
Comment on lines +42 to +43
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This TODO should be addressed as part of this PR, or a new issue should be created to track it.

super().process_weights_after_loading(layer)

def apply_weights(self,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Quantize the activations and run the Triton scaled matmul.

    :param layer: module holding the quantized weight and scales.
    :param x: input activations (quantized symmetrically to int8 here).
    :param bias: optional bias added by the matmul kernel.
    :return: output in ``x.dtype``.

    The stale ``return super().apply_weights(layer, x, bias)`` line left
    over from the old implementation has been removed — it made the
    body below unreachable.
    """
    # azp params are unused: triton_scaled_mm is symmetric-only.
    w_q, w_s, i_s, _, _ = self._get_weight_params(layer)

    # ops.scaled_int8_quant supports both dynamic and static quant:
    # * dynamic, i_s is None and x_s computed from x.
    # * static, i_s is scalar and x_s is i_s.

    # Only symmetric supported in triton_scaled_mm
    x_q, x_s, _ = ops.scaled_int8_quant(x, i_s, symmetric=True)

    return triton_scaled_mm(x_q,
                            w_q,
                            scale_a=x_s,
                            scale_b=w_s,
                            out_dtype=x.dtype,
                            bias=bias)
Loading