-
-
Notifications
You must be signed in to change notification settings - Fork 15.7k
[WIP][FP8] ScaledMM refactor #19434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP][FP8] ScaledMM refactor #19434
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,6 +6,11 @@ | |||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| from vllm.model_executor.layers.quantization.utils import replace_parameter | ||||||||||||||||||||||||||||||||||
| from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( | ||||||||||||||||||||||||||||||||||
| convert_to_channelwise) | ||||||||||||||||||||||||||||||||||
| from vllm.platforms import current_platform | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||||||||||
| class ScaledMMLinearLayerConfig: | ||||||||||||||||||||||||||||||||||
|
|
@@ -17,9 +22,53 @@ class ScaledMMLinearLayerConfig: | |||||||||||||||||||||||||||||||||
| class ScaledMMLinearKernel(ABC): | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||
| @abstractmethod | ||||||||||||||||||||||||||||||||||
| def is_supported( | ||||||||||||||||||||||||||||||||||
| cls, | ||||||||||||||||||||||||||||||||||
| compute_capability: Optional[int] = None | ||||||||||||||||||||||||||||||||||
| ) -> Tuple[bool, Optional[str]]: | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| Returns true if this kernel is supported on the current platform. | ||||||||||||||||||||||||||||||||||
| By default, a kernel is supported if the min_capability is reached | ||||||||||||||||||||||||||||||||||
| (it still has to override the get_min_capability method). | ||||||||||||||||||||||||||||||||||
| Kernels can also override this method for custom support checking. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| return cls._current_capability_supported(compute_capability) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||
| def get_min_capability(cls) -> int: | ||||||||||||||||||||||||||||||||||
| raise NotImplementedError | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| :return: minimum capability required for this kernel. | ||||||||||||||||||||||||||||||||||
| Override is_supported if min_capability is irrelevant. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||||||||||||||||||
| "Either implement get_min_capability or override is_supported") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||
| def _current_capability_supported( | ||||||||||||||||||||||||||||||||||
| cls, | ||||||||||||||||||||||||||||||||||
| compute_capability: Optional[int] = None | ||||||||||||||||||||||||||||||||||
| ) -> Tuple[bool, Optional[str]]: | ||||||||||||||||||||||||||||||||||
| if compute_capability is None: | ||||||||||||||||||||||||||||||||||
| _cc = current_platform.get_device_capability() | ||||||||||||||||||||||||||||||||||
| if _cc is not None: | ||||||||||||||||||||||||||||||||||
| compute_capability = _cc.major * 10 + _cc.minor | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # If the current platform uses compute_capability, | ||||||||||||||||||||||||||||||||||
| # make sure the kernel supports the compute capability. | ||||||||||||||||||||||||||||||||||
| if compute_capability is None: | ||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||
| f"Cannot determine if kernel {cls.__name__} is supported on " | ||||||||||||||||||||||||||||||||||
| f"platform {current_platform} as compute capability is not " | ||||||||||||||||||||||||||||||||||
| f"supported. Please override is_supported or remove the " | ||||||||||||||||||||||||||||||||||
| f"kernel from the list of kernels for the platform.") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| kernel_min_capability = cls.get_min_capability() | ||||||||||||||||||||||||||||||||||
| if (kernel_min_capability > compute_capability): | ||||||||||||||||||||||||||||||||||
| return (False, | ||||||||||||||||||||||||||||||||||
| f"compute capability >={kernel_min_capability} required, " | ||||||||||||||||||||||||||||||||||
| f"{compute_capability} current") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return True, None | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||
| @abstractmethod | ||||||||||||||||||||||||||||||||||
|
|
@@ -31,6 +80,7 @@ def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str, | |||||||||||||||||||||||||||||||||
| w_s_param_name: str, i_s_param_name: str, | ||||||||||||||||||||||||||||||||||
| i_zp_param_name: str, azp_adj_param_name: str) -> None: | ||||||||||||||||||||||||||||||||||
| assert self.can_implement(c) | ||||||||||||||||||||||||||||||||||
| assert self.is_supported() | ||||||||||||||||||||||||||||||||||
| self.config = c | ||||||||||||||||||||||||||||||||||
| self.w_q_name = w_q_param_name | ||||||||||||||||||||||||||||||||||
| self.w_s_name = w_s_param_name | ||||||||||||||||||||||||||||||||||
|
|
@@ -53,14 +103,53 @@ def _get_weight_params( | |||||||||||||||||||||||||||||||||
| self, layer: torch.nn.Module) -> Tuple[ | ||||||||||||||||||||||||||||||||||
| torch.Tensor, # weight | ||||||||||||||||||||||||||||||||||
| torch.Tensor, # weight_scale | ||||||||||||||||||||||||||||||||||
| Optional[torch.Tensor], # input_scale, | ||||||||||||||||||||||||||||||||||
| Optional[torch.Tensor], # input_scale, | ||||||||||||||||||||||||||||||||||
| Optional[torch.Tensor], # input_zp | ||||||||||||||||||||||||||||||||||
| Optional[torch.Tensor], # azp_adj | ||||||||||||||||||||||||||||||||||
| ]: | ||||||||||||||||||||||||||||||||||
| return ( | ||||||||||||||||||||||||||||||||||
| getattr(layer, self.w_q_name), | ||||||||||||||||||||||||||||||||||
| getattr(layer, self.w_s_name), | ||||||||||||||||||||||||||||||||||
| getattr(layer, self.i_s_name), | ||||||||||||||||||||||||||||||||||
| getattr(layer, self.i_zp_name), | ||||||||||||||||||||||||||||||||||
| getattr(layer, self.azp_adj_name), | ||||||||||||||||||||||||||||||||||
| getattr(layer, self.i_s_name, None), | ||||||||||||||||||||||||||||||||||
| getattr(layer, self.i_zp_name, None), | ||||||||||||||||||||||||||||||||||
| getattr(layer, self.azp_adj_name, None), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def replace_parameter(self, layer: torch.nn.Module, name: str, | ||||||||||||||||||||||||||||||||||
| param: torch.nn.Parameter): | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| This utility can replace a parameter with the new value. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Call free util function | ||||||||||||||||||||||||||||||||||
| replace_parameter(layer, name, | ||||||||||||||||||||||||||||||||||
| torch.nn.Parameter(param.data, requires_grad=False)) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+118
to
+126
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def maybe_unfuse_weight_scale(self, layer: torch.nn.Module, | ||||||||||||||||||||||||||||||||||
| weight_scale_param: torch.nn.Parameter): | ||||||||||||||||||||||||||||||||||
| # If we have a fused module (QKV, MLP) with per tensor scales (thus N | ||||||||||||||||||||||||||||||||||
| # scales being passed to the kernel), convert to the per-channel case. | ||||||||||||||||||||||||||||||||||
| is_fused_module = len(layer.logical_widths) > 1 | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if is_fused_module and not self.config.is_channelwise: | ||||||||||||||||||||||||||||||||||
| weight_scale_param = convert_to_channelwise( | ||||||||||||||||||||||||||||||||||
| weight_scale_param, layer.logical_widths) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return weight_scale_param | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def fuse_asymmetric_params( | ||||||||||||||||||||||||||||||||||
| self, input_scale_param: torch.nn.Parameter, | ||||||||||||||||||||||||||||||||||
| input_zp_param: torch.nn.Parameter | ||||||||||||||||||||||||||||||||||
| ) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: | ||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||
| # reconstruct the ranges | ||||||||||||||||||||||||||||||||||
| int8_traits = torch.iinfo(torch.int8) | ||||||||||||||||||||||||||||||||||
| azps = input_zp_param.to(dtype=torch.int32) | ||||||||||||||||||||||||||||||||||
| range_max = (input_scale_param * (int8_traits.max - azps)).max() | ||||||||||||||||||||||||||||||||||
| range_min = (input_scale_param * (int8_traits.min - azps)).min() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # AZP loaded as int8 but used as int32 | ||||||||||||||||||||||||||||||||||
| azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return scale, azp | ||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,9 @@ | |
|
|
||
| import torch | ||
|
|
||
| from vllm import _custom_ops as ops | ||
| from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa: E501 | ||
| triton_scaled_mm) | ||
| from vllm.platforms import current_platform | ||
|
|
||
| from .cutlass import CutlassScaledMMLinearKernel | ||
|
|
@@ -16,25 +19,46 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): | |
| def get_min_capability(cls) -> int: | ||
| return 75 | ||
|
|
||
| @classmethod | ||
| def is_supported( | ||
| cls, | ||
| compute_capability: Optional[int] = None | ||
| ) -> Tuple[bool, Optional[str]]: | ||
| if current_platform.is_rocm() or current_platform.is_cuda(): | ||
| return cls._current_capability_supported(compute_capability) | ||
|
|
||
| return False, "Triton scaled_mm requires running on ROCm or CUDA." | ||
|
|
||
| @classmethod | ||
| def can_implement( | ||
| cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: | ||
| if current_platform.is_cpu(): | ||
| return ( | ||
| False, | ||
| "TritonScaledMMLinearKernel requires Triton which is not " + | ||
| "currently supported on CPU.") | ||
| if not c.input_symmetric: | ||
| return (False, | ||
| "TritonScaledMMLinearKernel only supports symmetric " + | ||
| "quantization.") | ||
| return True, None | ||
|
|
||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||
| # TODO maybe this doesn't need to transpose the weight? | ||
| # Could also skip asymmetric-only paths | ||
|
Comment on lines
+42
to
+43
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| super().process_weights_after_loading(layer) | ||
|
|
||
| def apply_weights(self, | ||
| layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| bias: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
| return super().apply_weights(layer, x, bias) | ||
| w_q, w_s, i_s, _, _ = self._get_weight_params(layer) | ||
|
|
||
| # ops.scaled_int8_quant supports both dynamic and static quant: | ||
| # * dynamic, i_s is None and x_s computed from x. | ||
| # * static, i_s is scalar and x_s is i_s. | ||
|
|
||
| # Only symmetric supported in triton_scaled_mm | ||
| x_q, x_s, _ = ops.scaled_int8_quant(x, i_s, symmetric=True) | ||
|
|
||
| return triton_scaled_mm(x_q, | ||
| w_q, | ||
| scale_a=x_s, | ||
| scale_b=w_s, | ||
| out_dtype=x.dtype, | ||
| bias=bias) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This TODO should be addressed as part of this PR, or a new issue should be created to track it.